From 6385e43e8ca467e53c6a705a932016c5af75c3a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?cel=20=F0=9F=8C=B8?= Date: Sun, 22 Dec 2024 18:58:28 +0000 Subject: [PATCH] implement sink and stream with tokio::spawn --- jabber/src/client.rs | 211 +++++++++++++++++++++-- jabber/src/error.rs | 8 + jabber/src/jabber_stream.rs | 14 +- jabber/src/jabber_stream/bound_stream.rs | 139 ++++++++------- stanza/src/bind.rs | 8 +- stanza/src/client/iq.rs | 9 +- stanza/src/client/message.rs | 9 +- stanza/src/client/mod.rs | 1 + stanza/src/client/presence.rs | 11 +- stanza/src/lib.rs | 1 + stanza/src/xep_0199.rs | 26 +++ 11 files changed, 336 insertions(+), 101 deletions(-) create mode 100644 stanza/src/xep_0199.rs diff --git a/jabber/src/client.rs b/jabber/src/client.rs index c6cab07..32b8f6e 100644 --- a/jabber/src/client.rs +++ b/jabber/src/client.rs @@ -1,6 +1,12 @@ -use std::{pin::pin, sync::Arc, task::Poll}; +use std::{ + borrow::Borrow, + future::Future, + pin::pin, + sync::Arc, + task::{ready, Poll}, +}; -use futures::{Sink, Stream, StreamExt}; +use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt}; use jid::ParseError; use rsasl::config::SASLConfig; use stanza::{ @@ -8,9 +14,11 @@ use stanza::{ sasl::Mechanisms, stream::{Feature, Features}, }; +use tokio::sync::Mutex; use crate::{ connection::{Tls, Unencrypted}, + jabber_stream::bound_stream::BoundJabberStream, Connection, Error, JabberStream, Result, JID, }; @@ -56,7 +64,7 @@ impl JabberClient { } } - pub(crate) fn inner(self) -> Result> { + pub(crate) fn inner(self) -> Result> { match self.connection { ConnectionState::Disconnected => return Err(Error::Disconnected), ConnectionState::Connecting(_connecting) => return Err(Error::Connecting), @@ -64,21 +72,137 @@ impl JabberClient { } } - pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> { - match &mut self.connection { - ConnectionState::Disconnected => return Err(Error::Disconnected), - ConnectionState::Connecting(_connecting) => return Err(Error::Connecting), - ConnectionState::Connected(jabber_stream) => { - Ok(jabber_stream.send_stanza(stanza).await?) - } - } + // pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> { + // match &mut self.connection { + // ConnectionState::Disconnected => return Err(Error::Disconnected), + // ConnectionState::Connecting(_connecting) => return Err(Error::Connecting), + // ConnectionState::Connected(jabber_stream) => { + // Ok(jabber_stream.send_stanza(stanza).await?) + // } + // } + // } +} + +impl Sink for JabberClient { + type Error = Error; + + fn poll_ready( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.get_mut().connection.poll_ready_unpin(cx) + } + + fn start_send( + self: std::pin::Pin<&mut Self>, + item: Stanza, + ) -> std::result::Result<(), Self::Error> { + self.get_mut().connection.start_send_unpin(item) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.get_mut().connection.poll_flush_unpin(cx) + } + + fn poll_close( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.get_mut().connection.poll_flush_unpin(cx) + } +} + +impl Stream for JabberClient { + type Item = Result; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.get_mut().connection.poll_next_unpin(cx) } } pub enum ConnectionState { Disconnected, Connecting(Connecting), - Connected(JabberStream), + Connected(BoundJabberStream), +} + +impl Sink for ConnectionState { + type Error = Error; + + fn poll_ready( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + match self.get_mut() { + ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)), + ConnectionState::Connecting(_connecting) => Poll::Pending, + ConnectionState::Connected(bound_jabber_stream) => { + bound_jabber_stream.poll_ready_unpin(cx) + } + } + } + + fn start_send( + self: std::pin::Pin<&mut Self>, + item: Stanza, + ) -> std::result::Result<(), Self::Error> { + match self.get_mut() { + ConnectionState::Disconnected => Err(Error::Disconnected), + ConnectionState::Connecting(_connecting) => Err(Error::Connecting), + ConnectionState::Connected(bound_jabber_stream) => { + bound_jabber_stream.start_send_unpin(item) + } + } + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + match self.get_mut() { + ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)), + ConnectionState::Connecting(_connecting) => Poll::Pending, + ConnectionState::Connected(bound_jabber_stream) => { + bound_jabber_stream.poll_flush_unpin(cx) + } + } + } + + fn poll_close( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + match self.get_mut() { + ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)), + ConnectionState::Connecting(_connecting) => Poll::Pending, + ConnectionState::Connected(bound_jabber_stream) => { + bound_jabber_stream.poll_close_unpin(cx) + } + } + } +} + +impl Stream for ConnectionState { + type Item = Result; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + match self.get_mut() { + ConnectionState::Disconnected => Poll::Ready(Some(Err(Error::Disconnected))), + ConnectionState::Connecting(_connecting) => Poll::Pending, + ConnectionState::Connected(bound_jabber_stream) => { + bound_jabber_stream.poll_next_unpin(cx) + } + } + } } impl ConnectionState { @@ -150,7 +274,9 @@ impl ConnectionState { )) } Connecting::Bind(jabber_stream) => { - self = ConnectionState::Connected(jabber_stream.bind(jid).await?) + self = ConnectionState::Connected( + jabber_stream.bind(jid).await?.to_bound_jabber(), + ) } }, connected => return Ok(connected), @@ -194,11 +320,20 @@ pub enum InsecureConnecting { #[cfg(test)] mod tests { - use std::time::Duration; + use std::{sync::Arc, time::Duration}; use super::JabberClient; + use futures::{SinkExt, StreamExt}; + use stanza::{ + client::{ + iq::{Iq, IqType, Query}, + Stanza, + }, + xep_0199::Ping, + }; use test_log::test; - use tokio::time::sleep; + use tokio::{sync::Mutex, time::sleep}; + use tracing::info; #[test(tokio::test)] async fn login() { @@ -206,4 +341,50 @@ mod tests { client.connect().await.unwrap(); sleep(Duration::from_secs(5)).await } + + #[test(tokio::test)] + async fn ping_parallel() { + let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap(); + client.connect().await.unwrap(); + sleep(Duration::from_secs(5)).await; + let jid = client.jid.clone(); + let server = client.server.clone(); + let mut client = Arc::new(Mutex::new(client)); + + tokio::join!( + async { + let mut client = client.lock().await; + client + .send(Stanza::Iq(Iq { + from: Some(jid.clone()), + id: "c2s1".to_string(), + to: Some(server.clone().try_into().unwrap()), + r#type: IqType::Get, + lang: None, + query: Some(Query::Ping(Ping)), + errors: Vec::new(), + })) + .await; + }, + async { + let mut client = client.lock().await; + client + .send(Stanza::Iq(Iq { + from: Some(jid.clone()), + id: "c2s2".to_string(), + to: Some(server.clone().try_into().unwrap()), + r#type: IqType::Get, + lang: None, + query: Some(Query::Ping(Ping)), + errors: Vec::new(), + })) + .await; + }, + async { + while let Some(stanza) = client.lock().await.next().await { + info!("{:#?}", stanza); + } + } + ); + } } diff --git a/jabber/src/error.rs b/jabber/src/error.rs index 6671fe6..902061e 100644 --- a/jabber/src/error.rs +++ b/jabber/src/error.rs @@ -5,6 +5,7 @@ use rsasl::mechname::MechanismNameError; use stanza::client::error::Error as ClientError; use stanza::sasl::Failure; use stanza::stream::Error as StreamError; +use tokio::task::JoinError; #[derive(Debug)] pub enum Error { @@ -28,6 +29,7 @@ pub enum Error { MissingError, Disconnected, Connecting, + JoinError(JoinError), } #[derive(Debug)] @@ -42,6 +44,12 @@ impl From for Error { } } +impl From for Error { + fn from(e: JoinError) -> Self { + Self::JoinError(e) + } +} + impl From for Error { fn from(e: peanuts::DeserializeError) -> Self { Error::Deserialization(e) diff --git a/jabber/src/jabber_stream.rs b/jabber/src/jabber_stream.rs index d981f8f..89890a8 100644 --- a/jabber/src/jabber_stream.rs +++ b/jabber/src/jabber_stream.rs @@ -27,7 +27,7 @@ pub mod bound_stream; // open stream (streams started) pub struct JabberStream { reader: Reader>, - writer: Writer>, + pub(crate) writer: Writer>, } impl JabberStream @@ -368,12 +368,12 @@ mod tests { async fn sink() { let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap(); client.connect().await.unwrap(); - let stream = client.inner().unwrap(); - let sink = sink::unfold(stream, |mut stream, stanza: Stanza| async move { - stream.writer.write(&stanza).await?; - Ok::, Error>(stream) - }); - todo!() + // let stream = client.inner().unwrap(); + // let sink = sink::unfold(stream, |mut stream, stanza: Stanza| async move { + // stream.writer.write(&stanza).await?; + // Ok::, Error>(stream) + // }); + // todo!() // let _jabber = Connection::connect_user("test@blos.sm", "slayed".to_string()) // .await // .unwrap() diff --git a/jabber/src/jabber_stream/bound_stream.rs b/jabber/src/jabber_stream/bound_stream.rs index ca93421..c0d67b0 100644 --- a/jabber/src/jabber_stream/bound_stream.rs +++ b/jabber/src/jabber_stream/bound_stream.rs @@ -1,70 +1,71 @@ +use std::future::ready; use std::pin::pin; use std::pin::Pin; use std::sync::Arc; +use std::task::Poll; +use futures::ready; +use futures::FutureExt; use futures::{sink, stream, Sink, Stream}; use peanuts::{Reader, Writer}; use pin_project::pin_project; use stanza::client::Stanza; use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf}; use tokio::sync::Mutex; +use tokio::task::JoinHandle; use crate::Error; use super::JabberStream; #[pin_project] -pub struct BoundJabberStream +pub struct BoundJabberStream where - R: Stream, - W: Sink, S: AsyncWrite + AsyncRead + Unpin + Send, { - reader: Arc>>>>, - writer: Arc>>>>, - stream: R, - sink: W, + reader: Arc>>>, + writer: Arc>>>, + write_handle: Option>>, + read_handle: Option>>, } -impl BoundJabberStream +impl BoundJabberStream where - R: Stream, - W: Sink, S: AsyncWrite + AsyncRead + Unpin + Send, { // TODO: look into biased mutex, to close stream ASAP - pub async fn close_stream(self) -> Result, Error> { - if let Some(reader) = self.reader.lock().await.take() { - if let Some(writer) = self.writer.lock().await.take() { - // TODO: writer - return Ok(JabberStream { reader, writer }); - } - } - return Err(Error::StreamClosed); - } + // TODO: put into connection + // pub async fn close_stream(self) -> Result, Error> { + // let reader = self.reader.lock().await.into_self(); + // let writer = self.writer.lock().await.into_self(); + // // TODO: writer + // return Ok(JabberStream { reader, writer }); + // } } pub trait JabberStreamTrait: AsyncWrite + AsyncRead + Unpin + Send {} -impl Sink for BoundJabberStream +impl Sink for BoundJabberStream where - R: Stream, - W: Sink + Unpin, - S: AsyncWrite + AsyncRead + Unpin + Send, + S: AsyncWrite + AsyncRead + Unpin + Send + 'static, { - type Error = >::Error; + type Error = Error; fn poll_ready( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - let this = self.project(); - pin!(this.sink).poll_ready(cx) + self.poll_flush(cx) } fn start_send(self: std::pin::Pin<&mut Self>, item: Stanza) -> Result<(), Self::Error> { let this = self.project(); - pin!(this.sink).start_send(item) + if let Some(_write_handle) = this.write_handle { + panic!("start_send called without poll_ready") + } else { + *this.write_handle = Some(tokio::spawn(write(this.writer.clone(), item))); + Ok(()) + } } fn poll_flush( @@ -72,32 +73,55 @@ where cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { let this = self.project(); - pin!(this.sink).poll_flush(cx) + Poll::Ready(if let Some(join_handle) = this.write_handle.as_mut() { + match ready!(join_handle.poll_unpin(cx)) { + Ok(state) => { + *this.write_handle = None; + state + } + Err(err) => { + *this.write_handle = None; + Err(err.into()) + } + } + } else { + Ok(()) + }) } fn poll_close( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - let this = self.project(); - pin!(this.sink).poll_close(cx) + self.poll_flush(cx) } } -impl Stream for BoundJabberStream +impl Stream for BoundJabberStream where - R: Stream + Unpin, - W: Sink, - S: AsyncWrite + AsyncRead + Unpin + Send, + S: AsyncWrite + AsyncRead + Unpin + Send + 'static, { - type Item = ::Item; + type Item = Result; fn poll_next( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { let this = self.project(); - pin!(this.stream).poll_next(cx) + + loop { + if let Some(join_handle) = this.read_handle.as_mut() { + let stanza = ready!(join_handle.poll_unpin(cx)); + if let Ok(item) = stanza { + *this.read_handle = None; + return Poll::Ready(Some(item)); + } else if let Err(err) = stanza { + return Poll::Ready(Some(Err(err.into()))); + } + } else { + *this.read_handle = Some(tokio::spawn(read(this.reader.clone()))) + } + } } } @@ -105,49 +129,36 @@ impl JabberStream where S: AsyncWrite + AsyncRead + Unpin + Send, { - pub fn to_bound_jabber(self) -> BoundJabberStream, S> { - let reader = Arc::new(Mutex::new(Some(self.reader))); - let writer = Arc::new(Mutex::new(Some(self.writer))); - let sink = sink::unfold(writer.clone(), |writer, s: Stanza| async move { - write(writer, s).await - }); - let stream = stream::unfold(reader.clone(), |reader| async { read(reader).await }); + pub fn to_bound_jabber(self) -> BoundJabberStream { + let reader = Arc::new(Mutex::new(self.reader)); + let writer = Arc::new(Mutex::new(self.writer)); BoundJabberStream { - sink, - stream, writer, reader, + write_handle: None, + read_handle: None, } } } pub async fn write( - writer: Arc>>>>, + writer: Arc>>>, stanza: Stanza, -) -> Result>>>>, Error> { +) -> Result<(), Error> { { - if let Some(writer) = writer.lock().await.as_mut() { - writer.write(&stanza).await?; - } else { - return Err(Error::StreamClosed); - } + let mut writer = writer.lock().await; + writer.write(&stanza).await?; } - Ok(writer) + Ok(()) } pub async fn read( - reader: Arc>>>>, -) -> Option<( - Result, - Arc>>>>, -)> { + reader: Arc>>>, +) -> Result { let stanza: Result; { - if let Some(reader) = reader.lock().await.as_mut() { - stanza = reader.read().await.map_err(|e| e.into()); - } else { - stanza = Err(Error::StreamClosed) - }; + let mut reader = reader.lock().await; + stanza = reader.read().await.map_err(|e| e.into()); } - Some((stanza, reader)) + stanza } diff --git a/stanza/src/bind.rs b/stanza/src/bind.rs index 155fd1b..63644b1 100644 --- a/stanza/src/bind.rs +++ b/stanza/src/bind.rs @@ -6,7 +6,7 @@ use peanuts::{ pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-bind"; -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Bind { pub r#type: Option, } @@ -28,7 +28,7 @@ impl IntoElement for Bind { } } -#[derive(Clone)] +#[derive(Clone, Debug)] pub enum BindType { Resource(ResourceType), Jid(FullJidType), @@ -56,7 +56,7 @@ impl IntoElement for BindType { } // minLength 8 maxLength 3071 -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct FullJidType(pub JID); impl FromElement for FullJidType { @@ -77,7 +77,7 @@ impl IntoElement for FullJidType { } // minLength 1 maxLength 1023 -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct ResourceType(pub String); impl FromElement for ResourceType { diff --git a/stanza/src/client/iq.rs b/stanza/src/client/iq.rs index 388979e..6ee80ea 100644 --- a/stanza/src/client/iq.rs +++ b/stanza/src/client/iq.rs @@ -9,10 +9,12 @@ use peanuts::{ use crate::{ bind::{self, Bind}, client::error::Error, + xep_0199::{self, Ping}, }; use super::XMLNS; +#[derive(Debug)] pub struct Iq { pub from: Option, pub id: String, @@ -25,9 +27,10 @@ pub struct Iq { pub errors: Vec, } -#[derive(Clone)] +#[derive(Clone, Debug)] pub enum Query { Bind(Bind), + Ping(Ping), Unsupported, } @@ -35,6 +38,7 @@ impl FromElement for Query { fn from_element(element: peanuts::Element) -> peanuts::element::DeserializeResult { match element.identify() { (Some(bind::XMLNS), "bind") => Ok(Query::Bind(Bind::from_element(element)?)), + (Some(xep_0199::XMLNS), "ping") => Ok(Query::Ping(Ping::from_element(element)?)), _ => Ok(Query::Unsupported), } } @@ -44,6 +48,7 @@ impl IntoElement for Query { fn builder(&self) -> peanuts::element::ElementBuilder { match self { Query::Bind(bind) => bind.builder(), + Query::Ping(ping) => ping.builder(), // TODO: consider what to do if attempt to serialize unsupported Query::Unsupported => todo!(), } @@ -88,7 +93,7 @@ impl IntoElement for Iq { } } -#[derive(Copy, Clone, PartialEq, Eq)] +#[derive(Copy, Clone, PartialEq, Eq, Debug)] pub enum IqType { Error, Get, diff --git a/stanza/src/client/message.rs b/stanza/src/client/message.rs index b9d995f..2337d7b 100644 --- a/stanza/src/client/message.rs +++ b/stanza/src/client/message.rs @@ -8,6 +8,7 @@ use peanuts::{ use super::XMLNS; +#[derive(Debug)] pub struct Message { from: Option, id: Option, @@ -69,7 +70,7 @@ impl IntoElement for Message { } } -#[derive(Default, PartialEq, Eq, Copy, Clone)] +#[derive(Default, PartialEq, Eq, Copy, Clone, Debug)] pub enum MessageType { Chat, Error, @@ -106,7 +107,7 @@ impl ToString for MessageType { } } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Body { lang: Option, body: Option, @@ -132,7 +133,7 @@ impl IntoElement for Body { } } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Subject { lang: Option, subject: Option, @@ -158,7 +159,7 @@ impl IntoElement for Subject { } } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Thread { parent: Option, thread: Option, diff --git a/stanza/src/client/mod.rs b/stanza/src/client/mod.rs index 2b063d6..e9c336e 100644 --- a/stanza/src/client/mod.rs +++ b/stanza/src/client/mod.rs @@ -15,6 +15,7 @@ pub mod presence; pub const XMLNS: &str = "jabber:client"; +#[derive(Debug)] pub enum Stanza { Message(Message), Presence(Presence), diff --git a/stanza/src/client/presence.rs b/stanza/src/client/presence.rs index dd14bff..5354966 100644 --- a/stanza/src/client/presence.rs +++ b/stanza/src/client/presence.rs @@ -8,6 +8,7 @@ use peanuts::{ use super::{error::Error, XMLNS}; +#[derive(Debug)] pub struct Presence { from: Option, id: Option, @@ -70,7 +71,7 @@ impl IntoElement for Presence { pub enum Other {} -#[derive(Copy, Clone)] +#[derive(Copy, Clone, Debug)] pub enum PresenceType { Error, Probe, @@ -112,7 +113,7 @@ impl ToString for PresenceType { } } -#[derive(Copy, Clone)] +#[derive(Copy, Clone, Debug)] pub enum Show { Away, Chat, @@ -160,7 +161,7 @@ impl ToString for Show { } } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Status { lang: Option, status: String1024, @@ -188,7 +189,7 @@ impl IntoElement for Status { // TODO: enforce? /// minLength 1 maxLength 1024 -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct String1024(pub String); impl FromStr for String1024 { @@ -206,7 +207,7 @@ impl ToString for String1024 { } // xs:byte -#[derive(Clone, Copy)] +#[derive(Clone, Copy, Debug)] pub struct Priority(pub i8); impl FromElement for Priority { diff --git a/stanza/src/lib.rs b/stanza/src/lib.rs index 32716d3..f3b0dca 100644 --- a/stanza/src/lib.rs +++ b/stanza/src/lib.rs @@ -7,5 +7,6 @@ pub mod stanza_error; pub mod starttls; pub mod stream; pub mod stream_error; +pub mod xep_0199; pub static XML_VERSION: VersionInfo = VersionInfo::One; diff --git a/stanza/src/xep_0199.rs b/stanza/src/xep_0199.rs new file mode 100644 index 0000000..9605721 --- /dev/null +++ b/stanza/src/xep_0199.rs @@ -0,0 +1,26 @@ +use peanuts::{ + element::{FromElement, IntoElement}, + Element, +}; + +pub const XMLNS: &str = "urn:xmpp:ping"; + +#[derive(Clone, Copy, Debug)] +pub struct Ping; + +impl FromElement for Ping { + fn from_element(element: peanuts::Element) -> peanuts::element::DeserializeResult { + element.check_name("ping")?; + element.check_namespace(XMLNS)?; + + element.no_more_content()?; + + Ok(Ping) + } +} + +impl IntoElement for Ping { + fn builder(&self) -> peanuts::element::ElementBuilder { + Element::builder("ping", Some(XMLNS)) + } +}