implement sink and stream with tokio::spawn

This commit is contained in:
cel 🌸 2024-12-22 18:58:28 +00:00
parent 595d165479
commit 6385e43e8c
11 changed files with 336 additions and 101 deletions

View File

@ -1,6 +1,12 @@
use std::{pin::pin, sync::Arc, task::Poll};
use std::{
borrow::Borrow,
future::Future,
pin::pin,
sync::Arc,
task::{ready, Poll},
};
use futures::{Sink, Stream, StreamExt};
use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt};
use jid::ParseError;
use rsasl::config::SASLConfig;
use stanza::{
@ -8,9 +14,11 @@ use stanza::{
sasl::Mechanisms,
stream::{Feature, Features},
};
use tokio::sync::Mutex;
use crate::{
connection::{Tls, Unencrypted},
jabber_stream::bound_stream::BoundJabberStream,
Connection, Error, JabberStream, Result, JID,
};
@ -56,7 +64,7 @@ impl JabberClient {
}
}
pub(crate) fn inner(self) -> Result<JabberStream<Tls>> {
pub(crate) fn inner(self) -> Result<BoundJabberStream<Tls>> {
match self.connection {
ConnectionState::Disconnected => return Err(Error::Disconnected),
ConnectionState::Connecting(_connecting) => return Err(Error::Connecting),
@ -64,21 +72,137 @@ impl JabberClient {
}
}
pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> {
match &mut self.connection {
ConnectionState::Disconnected => return Err(Error::Disconnected),
ConnectionState::Connecting(_connecting) => return Err(Error::Connecting),
ConnectionState::Connected(jabber_stream) => {
Ok(jabber_stream.send_stanza(stanza).await?)
}
}
// pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> {
// match &mut self.connection {
// ConnectionState::Disconnected => return Err(Error::Disconnected),
// ConnectionState::Connecting(_connecting) => return Err(Error::Connecting),
// ConnectionState::Connected(jabber_stream) => {
// Ok(jabber_stream.send_stanza(stanza).await?)
// }
// }
// }
}
impl Sink<Stanza> for JabberClient {
type Error = Error;
fn poll_ready(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::result::Result<(), Self::Error>> {
self.get_mut().connection.poll_ready_unpin(cx)
}
fn start_send(
self: std::pin::Pin<&mut Self>,
item: Stanza,
) -> std::result::Result<(), Self::Error> {
self.get_mut().connection.start_send_unpin(item)
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::result::Result<(), Self::Error>> {
self.get_mut().connection.poll_flush_unpin(cx)
}
fn poll_close(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::result::Result<(), Self::Error>> {
self.get_mut().connection.poll_flush_unpin(cx)
}
}
impl Stream for JabberClient {
type Item = Result<Stanza>;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
self.get_mut().connection.poll_next_unpin(cx)
}
}
pub enum ConnectionState {
Disconnected,
Connecting(Connecting),
Connected(JabberStream<Tls>),
Connected(BoundJabberStream<Tls>),
}
impl Sink<Stanza> for ConnectionState {
type Error = Error;
fn poll_ready(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::result::Result<(), Self::Error>> {
match self.get_mut() {
ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)),
ConnectionState::Connecting(_connecting) => Poll::Pending,
ConnectionState::Connected(bound_jabber_stream) => {
bound_jabber_stream.poll_ready_unpin(cx)
}
}
}
fn start_send(
self: std::pin::Pin<&mut Self>,
item: Stanza,
) -> std::result::Result<(), Self::Error> {
match self.get_mut() {
ConnectionState::Disconnected => Err(Error::Disconnected),
ConnectionState::Connecting(_connecting) => Err(Error::Connecting),
ConnectionState::Connected(bound_jabber_stream) => {
bound_jabber_stream.start_send_unpin(item)
}
}
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::result::Result<(), Self::Error>> {
match self.get_mut() {
ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)),
ConnectionState::Connecting(_connecting) => Poll::Pending,
ConnectionState::Connected(bound_jabber_stream) => {
bound_jabber_stream.poll_flush_unpin(cx)
}
}
}
fn poll_close(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::result::Result<(), Self::Error>> {
match self.get_mut() {
ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)),
ConnectionState::Connecting(_connecting) => Poll::Pending,
ConnectionState::Connected(bound_jabber_stream) => {
bound_jabber_stream.poll_close_unpin(cx)
}
}
}
}
impl Stream for ConnectionState {
type Item = Result<Stanza>;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
match self.get_mut() {
ConnectionState::Disconnected => Poll::Ready(Some(Err(Error::Disconnected))),
ConnectionState::Connecting(_connecting) => Poll::Pending,
ConnectionState::Connected(bound_jabber_stream) => {
bound_jabber_stream.poll_next_unpin(cx)
}
}
}
}
impl ConnectionState {
@ -150,7 +274,9 @@ impl ConnectionState {
))
}
Connecting::Bind(jabber_stream) => {
self = ConnectionState::Connected(jabber_stream.bind(jid).await?)
self = ConnectionState::Connected(
jabber_stream.bind(jid).await?.to_bound_jabber(),
)
}
},
connected => return Ok(connected),
@ -194,11 +320,20 @@ pub enum InsecureConnecting {
#[cfg(test)]
mod tests {
use std::time::Duration;
use std::{sync::Arc, time::Duration};
use super::JabberClient;
use futures::{SinkExt, StreamExt};
use stanza::{
client::{
iq::{Iq, IqType, Query},
Stanza,
},
xep_0199::Ping,
};
use test_log::test;
use tokio::time::sleep;
use tokio::{sync::Mutex, time::sleep};
use tracing::info;
#[test(tokio::test)]
async fn login() {
@ -206,4 +341,50 @@ mod tests {
client.connect().await.unwrap();
sleep(Duration::from_secs(5)).await
}
#[test(tokio::test)]
async fn ping_parallel() {
let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap();
client.connect().await.unwrap();
sleep(Duration::from_secs(5)).await;
let jid = client.jid.clone();
let server = client.server.clone();
let mut client = Arc::new(Mutex::new(client));
tokio::join!(
async {
let mut client = client.lock().await;
client
.send(Stanza::Iq(Iq {
from: Some(jid.clone()),
id: "c2s1".to_string(),
to: Some(server.clone().try_into().unwrap()),
r#type: IqType::Get,
lang: None,
query: Some(Query::Ping(Ping)),
errors: Vec::new(),
}))
.await;
},
async {
let mut client = client.lock().await;
client
.send(Stanza::Iq(Iq {
from: Some(jid.clone()),
id: "c2s2".to_string(),
to: Some(server.clone().try_into().unwrap()),
r#type: IqType::Get,
lang: None,
query: Some(Query::Ping(Ping)),
errors: Vec::new(),
}))
.await;
},
async {
while let Some(stanza) = client.lock().await.next().await {
info!("{:#?}", stanza);
}
}
);
}
}

View File

@ -5,6 +5,7 @@ use rsasl::mechname::MechanismNameError;
use stanza::client::error::Error as ClientError;
use stanza::sasl::Failure;
use stanza::stream::Error as StreamError;
use tokio::task::JoinError;
#[derive(Debug)]
pub enum Error {
@ -28,6 +29,7 @@ pub enum Error {
MissingError,
Disconnected,
Connecting,
JoinError(JoinError),
}
#[derive(Debug)]
@ -42,6 +44,12 @@ impl From<rsasl::prelude::SASLError> for Error {
}
}
impl From<JoinError> for Error {
fn from(e: JoinError) -> Self {
Self::JoinError(e)
}
}
impl From<peanuts::DeserializeError> for Error {
fn from(e: peanuts::DeserializeError) -> Self {
Error::Deserialization(e)

View File

@ -27,7 +27,7 @@ pub mod bound_stream;
// open stream (streams started)
pub struct JabberStream<S> {
reader: Reader<ReadHalf<S>>,
writer: Writer<WriteHalf<S>>,
pub(crate) writer: Writer<WriteHalf<S>>,
}
impl<S> JabberStream<S>
@ -368,12 +368,12 @@ mod tests {
async fn sink() {
let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap();
client.connect().await.unwrap();
let stream = client.inner().unwrap();
let sink = sink::unfold(stream, |mut stream, stanza: Stanza| async move {
stream.writer.write(&stanza).await?;
Ok::<JabberStream<Tls>, Error>(stream)
});
todo!()
// let stream = client.inner().unwrap();
// let sink = sink::unfold(stream, |mut stream, stanza: Stanza| async move {
// stream.writer.write(&stanza).await?;
// Ok::<JabberStream<Tls>, Error>(stream)
// });
// todo!()
// let _jabber = Connection::connect_user("test@blos.sm", "slayed".to_string())
// .await
// .unwrap()

View File

@ -1,70 +1,71 @@
use std::future::ready;
use std::pin::pin;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Poll;
use futures::ready;
use futures::FutureExt;
use futures::{sink, stream, Sink, Stream};
use peanuts::{Reader, Writer};
use pin_project::pin_project;
use stanza::client::Stanza;
use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use crate::Error;
use super::JabberStream;
#[pin_project]
pub struct BoundJabberStream<R, W, S>
pub struct BoundJabberStream<S>
where
R: Stream,
W: Sink<Stanza>,
S: AsyncWrite + AsyncRead + Unpin + Send,
{
reader: Arc<Mutex<Option<Reader<ReadHalf<S>>>>>,
writer: Arc<Mutex<Option<Writer<WriteHalf<S>>>>>,
stream: R,
sink: W,
reader: Arc<Mutex<Reader<ReadHalf<S>>>>,
writer: Arc<Mutex<Writer<WriteHalf<S>>>>,
write_handle: Option<JoinHandle<Result<(), Error>>>,
read_handle: Option<JoinHandle<Result<Stanza, Error>>>,
}
impl<R, W, S> BoundJabberStream<R, W, S>
impl<S> BoundJabberStream<S>
where
R: Stream,
W: Sink<Stanza>,
S: AsyncWrite + AsyncRead + Unpin + Send,
{
// TODO: look into biased mutex, to close stream ASAP
pub async fn close_stream(self) -> Result<JabberStream<S>, Error> {
if let Some(reader) = self.reader.lock().await.take() {
if let Some(writer) = self.writer.lock().await.take() {
// TODO: writer </stream:stream>
return Ok(JabberStream { reader, writer });
}
}
return Err(Error::StreamClosed);
}
// TODO: put into connection
// pub async fn close_stream(self) -> Result<JabberStream<S>, Error> {
// let reader = self.reader.lock().await.into_self();
// let writer = self.writer.lock().await.into_self();
// // TODO: writer </stream:stream>
// return Ok(JabberStream { reader, writer });
// }
}
pub trait JabberStreamTrait: AsyncWrite + AsyncRead + Unpin + Send {}
impl<R, W, S> Sink<Stanza> for BoundJabberStream<R, W, S>
impl<S> Sink<Stanza> for BoundJabberStream<S>
where
R: Stream,
W: Sink<Stanza> + Unpin,
S: AsyncWrite + AsyncRead + Unpin + Send,
S: AsyncWrite + AsyncRead + Unpin + Send + 'static,
{
type Error = <W as Sink<Stanza>>::Error;
type Error = Error;
fn poll_ready(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
let this = self.project();
pin!(this.sink).poll_ready(cx)
self.poll_flush(cx)
}
fn start_send(self: std::pin::Pin<&mut Self>, item: Stanza) -> Result<(), Self::Error> {
let this = self.project();
pin!(this.sink).start_send(item)
if let Some(_write_handle) = this.write_handle {
panic!("start_send called without poll_ready")
} else {
*this.write_handle = Some(tokio::spawn(write(this.writer.clone(), item)));
Ok(())
}
}
fn poll_flush(
@ -72,32 +73,55 @@ where
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
let this = self.project();
pin!(this.sink).poll_flush(cx)
Poll::Ready(if let Some(join_handle) = this.write_handle.as_mut() {
match ready!(join_handle.poll_unpin(cx)) {
Ok(state) => {
*this.write_handle = None;
state
}
Err(err) => {
*this.write_handle = None;
Err(err.into())
}
}
} else {
Ok(())
})
}
fn poll_close(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
let this = self.project();
pin!(this.sink).poll_close(cx)
self.poll_flush(cx)
}
}
impl<R, W, S> Stream for BoundJabberStream<R, W, S>
impl<S> Stream for BoundJabberStream<S>
where
R: Stream + Unpin,
W: Sink<Stanza>,
S: AsyncWrite + AsyncRead + Unpin + Send,
S: AsyncWrite + AsyncRead + Unpin + Send + 'static,
{
type Item = <R as Stream>::Item;
type Item = Result<Stanza, Error>;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let this = self.project();
pin!(this.stream).poll_next(cx)
loop {
if let Some(join_handle) = this.read_handle.as_mut() {
let stanza = ready!(join_handle.poll_unpin(cx));
if let Ok(item) = stanza {
*this.read_handle = None;
return Poll::Ready(Some(item));
} else if let Err(err) = stanza {
return Poll::Ready(Some(Err(err.into())));
}
} else {
*this.read_handle = Some(tokio::spawn(read(this.reader.clone())))
}
}
}
}
@ -105,49 +129,36 @@ impl<S> JabberStream<S>
where
S: AsyncWrite + AsyncRead + Unpin + Send,
{
pub fn to_bound_jabber(self) -> BoundJabberStream<impl Stream, impl Sink<Stanza>, S> {
let reader = Arc::new(Mutex::new(Some(self.reader)));
let writer = Arc::new(Mutex::new(Some(self.writer)));
let sink = sink::unfold(writer.clone(), |writer, s: Stanza| async move {
write(writer, s).await
});
let stream = stream::unfold(reader.clone(), |reader| async { read(reader).await });
pub fn to_bound_jabber(self) -> BoundJabberStream<S> {
let reader = Arc::new(Mutex::new(self.reader));
let writer = Arc::new(Mutex::new(self.writer));
BoundJabberStream {
sink,
stream,
writer,
reader,
write_handle: None,
read_handle: None,
}
}
}
pub async fn write<W: AsyncWrite + Unpin + Send>(
writer: Arc<Mutex<Option<Writer<WriteHalf<W>>>>>,
writer: Arc<Mutex<Writer<WriteHalf<W>>>>,
stanza: Stanza,
) -> Result<Arc<Mutex<Option<Writer<WriteHalf<W>>>>>, Error> {
) -> Result<(), Error> {
{
if let Some(writer) = writer.lock().await.as_mut() {
writer.write(&stanza).await?;
} else {
return Err(Error::StreamClosed);
}
let mut writer = writer.lock().await;
writer.write(&stanza).await?;
}
Ok(writer)
Ok(())
}
pub async fn read<R: AsyncRead + Unpin + Send>(
reader: Arc<Mutex<Option<Reader<ReadHalf<R>>>>>,
) -> Option<(
Result<Stanza, Error>,
Arc<Mutex<Option<Reader<ReadHalf<R>>>>>,
)> {
reader: Arc<Mutex<Reader<ReadHalf<R>>>>,
) -> Result<Stanza, Error> {
let stanza: Result<Stanza, Error>;
{
if let Some(reader) = reader.lock().await.as_mut() {
stanza = reader.read().await.map_err(|e| e.into());
} else {
stanza = Err(Error::StreamClosed)
};
let mut reader = reader.lock().await;
stanza = reader.read().await.map_err(|e| e.into());
}
Some((stanza, reader))
stanza
}

View File

@ -6,7 +6,7 @@ use peanuts::{
pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-bind";
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct Bind {
pub r#type: Option<BindType>,
}
@ -28,7 +28,7 @@ impl IntoElement for Bind {
}
}
#[derive(Clone)]
#[derive(Clone, Debug)]
pub enum BindType {
Resource(ResourceType),
Jid(FullJidType),
@ -56,7 +56,7 @@ impl IntoElement for BindType {
}
// minLength 8 maxLength 3071
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct FullJidType(pub JID);
impl FromElement for FullJidType {
@ -77,7 +77,7 @@ impl IntoElement for FullJidType {
}
// minLength 1 maxLength 1023
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct ResourceType(pub String);
impl FromElement for ResourceType {

View File

@ -9,10 +9,12 @@ use peanuts::{
use crate::{
bind::{self, Bind},
client::error::Error,
xep_0199::{self, Ping},
};
use super::XMLNS;
#[derive(Debug)]
pub struct Iq {
pub from: Option<JID>,
pub id: String,
@ -25,9 +27,10 @@ pub struct Iq {
pub errors: Vec<Error>,
}
#[derive(Clone)]
#[derive(Clone, Debug)]
pub enum Query {
Bind(Bind),
Ping(Ping),
Unsupported,
}
@ -35,6 +38,7 @@ impl FromElement for Query {
fn from_element(element: peanuts::Element) -> peanuts::element::DeserializeResult<Self> {
match element.identify() {
(Some(bind::XMLNS), "bind") => Ok(Query::Bind(Bind::from_element(element)?)),
(Some(xep_0199::XMLNS), "ping") => Ok(Query::Ping(Ping::from_element(element)?)),
_ => Ok(Query::Unsupported),
}
}
@ -44,6 +48,7 @@ impl IntoElement for Query {
fn builder(&self) -> peanuts::element::ElementBuilder {
match self {
Query::Bind(bind) => bind.builder(),
Query::Ping(ping) => ping.builder(),
// TODO: consider what to do if attempt to serialize unsupported
Query::Unsupported => todo!(),
}
@ -88,7 +93,7 @@ impl IntoElement for Iq {
}
}
#[derive(Copy, Clone, PartialEq, Eq)]
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum IqType {
Error,
Get,

View File

@ -8,6 +8,7 @@ use peanuts::{
use super::XMLNS;
#[derive(Debug)]
pub struct Message {
from: Option<JID>,
id: Option<String>,
@ -69,7 +70,7 @@ impl IntoElement for Message {
}
}
#[derive(Default, PartialEq, Eq, Copy, Clone)]
#[derive(Default, PartialEq, Eq, Copy, Clone, Debug)]
pub enum MessageType {
Chat,
Error,
@ -106,7 +107,7 @@ impl ToString for MessageType {
}
}
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct Body {
lang: Option<String>,
body: Option<String>,
@ -132,7 +133,7 @@ impl IntoElement for Body {
}
}
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct Subject {
lang: Option<String>,
subject: Option<String>,
@ -158,7 +159,7 @@ impl IntoElement for Subject {
}
}
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct Thread {
parent: Option<String>,
thread: Option<String>,

View File

@ -15,6 +15,7 @@ pub mod presence;
pub const XMLNS: &str = "jabber:client";
#[derive(Debug)]
pub enum Stanza {
Message(Message),
Presence(Presence),

View File

@ -8,6 +8,7 @@ use peanuts::{
use super::{error::Error, XMLNS};
#[derive(Debug)]
pub struct Presence {
from: Option<JID>,
id: Option<String>,
@ -70,7 +71,7 @@ impl IntoElement for Presence {
pub enum Other {}
#[derive(Copy, Clone)]
#[derive(Copy, Clone, Debug)]
pub enum PresenceType {
Error,
Probe,
@ -112,7 +113,7 @@ impl ToString for PresenceType {
}
}
#[derive(Copy, Clone)]
#[derive(Copy, Clone, Debug)]
pub enum Show {
Away,
Chat,
@ -160,7 +161,7 @@ impl ToString for Show {
}
}
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct Status {
lang: Option<String>,
status: String1024,
@ -188,7 +189,7 @@ impl IntoElement for Status {
// TODO: enforce?
/// minLength 1 maxLength 1024
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct String1024(pub String);
impl FromStr for String1024 {
@ -206,7 +207,7 @@ impl ToString for String1024 {
}
// xs:byte
#[derive(Clone, Copy)]
#[derive(Clone, Copy, Debug)]
pub struct Priority(pub i8);
impl FromElement for Priority {

View File

@ -7,5 +7,6 @@ pub mod stanza_error;
pub mod starttls;
pub mod stream;
pub mod stream_error;
pub mod xep_0199;
pub static XML_VERSION: VersionInfo = VersionInfo::One;

26
stanza/src/xep_0199.rs Normal file
View File

@ -0,0 +1,26 @@
use peanuts::{
element::{FromElement, IntoElement},
Element,
};
pub const XMLNS: &str = "urn:xmpp:ping";
#[derive(Clone, Copy, Debug)]
pub struct Ping;
impl FromElement for Ping {
fn from_element(element: peanuts::Element) -> peanuts::element::DeserializeResult<Self> {
element.check_name("ping")?;
element.check_namespace(XMLNS)?;
element.no_more_content()?;
Ok(Ping)
}
}
impl IntoElement for Ping {
fn builder(&self) -> peanuts::element::ElementBuilder {
Element::builder("ping", Some(XMLNS))
}
}