From e6c97ab82880ad4cd12b05bc1c8f2a0a3413735c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?cel=20=F0=9F=8C=B8?= Date: Sun, 12 Jan 2025 21:19:07 +0000 Subject: [PATCH] implement stream splitting and closing --- jabber/Cargo.toml | 12 +- jabber/src/client.rs | 287 ++++++----------------- jabber/src/jabber_stream.rs | 118 +++++++++- jabber/src/jabber_stream/bound_stream.rs | 195 +++++---------- stanza/src/client/mod.rs | 1 + 5 files changed, 251 insertions(+), 362 deletions(-) diff --git a/jabber/Cargo.toml b/jabber/Cargo.toml index 68dddd9..d070838 100644 --- a/jabber/Cargo.toml +++ b/jabber/Cargo.toml @@ -12,7 +12,12 @@ async-trait = "0.1.68" lazy_static = "1.4.0" nanoid = "0.4.0" # TODO: remove unneeded features -rsasl = { version = "2.0.1", path = "../../rsasl", default_features = false, features = ["provider_base64", "plain", "config_builder", "scram-sha-1"] } +rsasl = { version = "2.0.1", default_features = false, features = [ + "provider_base64", + "plain", + "config_builder", + "scram-sha-1", +] } tokio = { version = "1.28", features = ["full"] } tokio-native-tls = "0.3.1" tracing = "0.1.40" @@ -29,4 +34,7 @@ pin-project = "1.1.7" [dev-dependencies] test-log = { version = "0.2", features = ["trace"] } env_logger = "*" -tracing-subscriber = {version = "0.3", default-features = false, features = ["env-filter", "fmt"]} +tracing-subscriber = { version = "0.3", default-features = false, features = [ + "env-filter", + "fmt", +] } diff --git a/jabber/src/client.rs b/jabber/src/client.rs index 2e59d98..9d32682 100644 --- a/jabber/src/client.rs +++ b/jabber/src/client.rs @@ -18,13 +18,13 @@ use tokio::sync::Mutex; use crate::{ connection::{Tls, Unencrypted}, - jabber_stream::bound_stream::BoundJabberStream, + jabber_stream::bound_stream::{BoundJabberReader, BoundJabberStream}, Connection, Error, JabberStream, Result, JID, }; // feed it client stanzas, receive client stanzas pub struct JabberClient { - connection: ConnectionState, + connection: Option>, jid: JID, // TODO: have reconnection be handled by another part, so creds don't need to be stored in object password: Arc, @@ -43,7 +43,7 @@ impl JabberClient { password.to_string(), )?; Ok(JabberClient { - connection: ConnectionState::Disconnected, + connection: None, jid: jid.clone(), password: sasl_config, server: jid.domainpart, @@ -56,25 +56,19 @@ impl JabberClient { pub async fn connect(&mut self) -> Result<()> { match &self.connection { - ConnectionState::Disconnected => { - // TODO: actually set the self.connection as it is connecting, make more asynchronous (mutex while connecting?) - // perhaps use take_mut? - self.connection = ConnectionState::Disconnected - .connect(&mut self.jid, self.password.clone(), &mut self.server) - .await?; + Some(_) => Ok(()), + None => { + self.connection = Some( + connect_and_login(&mut self.jid, self.password.clone(), &mut self.server) + .await?, + ); Ok(()) } - ConnectionState::Connecting(_connecting) => Err(Error::AlreadyConnecting), - ConnectionState::Connected(_jabber_stream) => Ok(()), } } - pub(crate) fn inner(self) -> Result> { - match self.connection { - ConnectionState::Disconnected => return Err(Error::Disconnected), - ConnectionState::Connecting(_connecting) => return Err(Error::Connecting), - ConnectionState::Connected(jabber_stream) => return Ok(jabber_stream), - } + pub(crate) fn into_inner(self) -> Result> { + self.connection.ok_or(Error::Disconnected) } // pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> { @@ -88,203 +82,59 @@ impl JabberClient { // } } -impl Sink for JabberClient { - type Error = Error; - - fn poll_ready( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - self.get_mut().connection.poll_ready_unpin(cx) - } - - fn start_send( - self: std::pin::Pin<&mut Self>, - item: Stanza, - ) -> std::result::Result<(), Self::Error> { - self.get_mut().connection.start_send_unpin(item) - } - - fn poll_flush( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - self.get_mut().connection.poll_flush_unpin(cx) - } - - fn poll_close( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - self.get_mut().connection.poll_flush_unpin(cx) - } -} - -impl Stream for JabberClient { - type Item = Result; - - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - self.get_mut().connection.poll_next_unpin(cx) - } -} - -pub enum ConnectionState { - Disconnected, - Connecting(Connecting), - Connected(BoundJabberStream), -} - -impl Sink for ConnectionState { - type Error = Error; - - fn poll_ready( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - match self.get_mut() { - ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)), - ConnectionState::Connecting(_connecting) => Poll::Pending, - ConnectionState::Connected(bound_jabber_stream) => { - bound_jabber_stream.poll_ready_unpin(cx) +pub async fn connect_and_login( + jid: &mut JID, + auth: Arc, + server: &mut String, +) -> Result> { + let mut conn_state = Connecting::start(&server).await?; + loop { + match conn_state { + Connecting::InsecureConnectionEstablised(tcp_stream) => { + conn_state = Connecting::InsecureStreamStarted( + JabberStream::start_stream(tcp_stream, server).await?, + ) } - } - } - - fn start_send( - self: std::pin::Pin<&mut Self>, - item: Stanza, - ) -> std::result::Result<(), Self::Error> { - match self.get_mut() { - ConnectionState::Disconnected => Err(Error::Disconnected), - ConnectionState::Connecting(_connecting) => Err(Error::Connecting), - ConnectionState::Connected(bound_jabber_stream) => { - bound_jabber_stream.start_send_unpin(item) + Connecting::InsecureStreamStarted(jabber_stream) => { + conn_state = Connecting::InsecureGotFeatures(jabber_stream.get_features().await?) } - } - } - - fn poll_flush( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - match self.get_mut() { - ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)), - ConnectionState::Connecting(_connecting) => Poll::Pending, - ConnectionState::Connected(bound_jabber_stream) => { - bound_jabber_stream.poll_flush_unpin(cx) - } - } - } - - fn poll_close( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - match self.get_mut() { - ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)), - ConnectionState::Connecting(_connecting) => Poll::Pending, - ConnectionState::Connected(bound_jabber_stream) => { - bound_jabber_stream.poll_close_unpin(cx) - } - } - } -} - -impl Stream for ConnectionState { - type Item = Result; - - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - match self.get_mut() { - ConnectionState::Disconnected => Poll::Ready(Some(Err(Error::Disconnected))), - ConnectionState::Connecting(_connecting) => Poll::Pending, - ConnectionState::Connected(bound_jabber_stream) => { - bound_jabber_stream.poll_next_unpin(cx) - } - } - } -} - -impl ConnectionState { - pub async fn connect( - mut self, - jid: &mut JID, - auth: Arc, - server: &mut String, - ) -> Result { - loop { - match self { - ConnectionState::Disconnected => { - self = ConnectionState::Connecting(Connecting::start(&server).await?); + Connecting::InsecureGotFeatures((features, jabber_stream)) => { + match features.negotiate().ok_or(Error::Negotiation)? { + Feature::StartTls(_start_tls) => { + conn_state = Connecting::StartTls(jabber_stream) + } + // TODO: better error + _ => return Err(Error::TlsRequired), } - ConnectionState::Connecting(connecting) => match connecting { - Connecting::InsecureConnectionEstablised(tcp_stream) => { - self = ConnectionState::Connecting(Connecting::InsecureStreamStarted( - JabberStream::start_stream(tcp_stream, server).await?, - )) + } + Connecting::StartTls(jabber_stream) => { + conn_state = + Connecting::ConnectionEstablished(jabber_stream.starttls(&server).await?) + } + Connecting::ConnectionEstablished(tls_stream) => { + conn_state = + Connecting::StreamStarted(JabberStream::start_stream(tls_stream, server).await?) + } + Connecting::StreamStarted(jabber_stream) => { + conn_state = Connecting::GotFeatures(jabber_stream.get_features().await?) + } + Connecting::GotFeatures((features, jabber_stream)) => { + match features.negotiate().ok_or(Error::Negotiation)? { + Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls), + Feature::Sasl(mechanisms) => { + conn_state = Connecting::Sasl(mechanisms, jabber_stream) } - Connecting::InsecureStreamStarted(jabber_stream) => { - self = ConnectionState::Connecting(Connecting::InsecureGotFeatures( - jabber_stream.get_features().await?, - )) - } - Connecting::InsecureGotFeatures((features, jabber_stream)) => { - match features.negotiate().ok_or(Error::Negotiation)? { - Feature::StartTls(_start_tls) => { - self = - ConnectionState::Connecting(Connecting::StartTls(jabber_stream)) - } - // TODO: better error - _ => return Err(Error::TlsRequired), - } - } - Connecting::StartTls(jabber_stream) => { - self = ConnectionState::Connecting(Connecting::ConnectionEstablished( - jabber_stream.starttls(&server).await?, - )) - } - Connecting::ConnectionEstablished(tls_stream) => { - self = ConnectionState::Connecting(Connecting::StreamStarted( - JabberStream::start_stream(tls_stream, server).await?, - )) - } - Connecting::StreamStarted(jabber_stream) => { - self = ConnectionState::Connecting(Connecting::GotFeatures( - jabber_stream.get_features().await?, - )) - } - Connecting::GotFeatures((features, jabber_stream)) => { - match features.negotiate().ok_or(Error::Negotiation)? { - Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls), - Feature::Sasl(mechanisms) => { - self = ConnectionState::Connecting(Connecting::Sasl( - mechanisms, - jabber_stream, - )) - } - Feature::Bind => { - self = ConnectionState::Connecting(Connecting::Bind(jabber_stream)) - } - Feature::Unknown => return Err(Error::Unsupported), - } - } - Connecting::Sasl(mechanisms, jabber_stream) => { - self = ConnectionState::Connecting(Connecting::ConnectionEstablished( - jabber_stream.sasl(mechanisms, auth.clone()).await?, - )) - } - Connecting::Bind(jabber_stream) => { - self = ConnectionState::Connected( - jabber_stream.bind(jid).await?.to_bound_jabber(), - ) - } - }, - connected => return Ok(connected), + Feature::Bind => conn_state = Connecting::Bind(jabber_stream), + Feature::Unknown => return Err(Error::Unsupported), + } + } + Connecting::Sasl(mechanisms, jabber_stream) => { + conn_state = Connecting::ConnectionEstablished( + jabber_stream.sasl(mechanisms, auth.clone()).await?, + ) + } + Connecting::Bind(jabber_stream) => { + return Ok(jabber_stream.bind(jid).await?.to_bound_jabber()); } } } @@ -354,12 +204,12 @@ mod tests { sleep(Duration::from_secs(5)).await; let jid = client.jid.clone(); let server = client.server.clone(); - let (mut write, mut read) = client.split(); + let (mut read, mut write) = client.into_inner().unwrap().split(); tokio::join!( async { write - .send(Stanza::Iq(Iq { + .write(&Stanza::Iq(Iq { from: Some(jid.clone()), id: "c2s1".to_string(), to: Some(server.clone().try_into().unwrap()), @@ -368,9 +218,10 @@ mod tests { query: Some(Query::Ping(Ping)), errors: Vec::new(), })) - .await; + .await + .unwrap(); write - .send(Stanza::Iq(Iq { + .write(&Stanza::Iq(Iq { from: Some(jid.clone()), id: "c2s2".to_string(), to: Some(server.clone().try_into().unwrap()), @@ -379,11 +230,13 @@ mod tests { query: Some(Query::Ping(Ping)), errors: Vec::new(), })) - .await; + .await + .unwrap(); }, async { - while let Some(stanza) = read.next().await { - info!("{:#?}", stanza); + for _ in 0..2 { + let stanza = read.read::().await.unwrap(); + info!("ping reply: {:#?}", stanza); } } ); diff --git a/jabber/src/jabber_stream.rs b/jabber/src/jabber_stream.rs index 89890a8..384e6e4 100644 --- a/jabber/src/jabber_stream.rs +++ b/jabber/src/jabber_stream.rs @@ -26,8 +26,103 @@ pub mod bound_stream; // open stream (streams started) pub struct JabberStream { - reader: Reader>, - pub(crate) writer: Writer>, + reader: JabberReader, + writer: JabberWriter, +} + +impl JabberStream { + fn split(self) -> (JabberReader, JabberWriter) { + let reader = self.reader; + let writer = self.writer; + (reader, writer) + } +} + +pub struct JabberReader(Reader>); + +impl JabberReader { + // TODO: consider taking a readhalf and creating peanuts::Reader here, only one inner + fn new(reader: Reader>) -> Self { + Self(reader) + } + + fn unsplit(self, writer: JabberWriter) -> JabberStream { + JabberStream { + reader: self, + writer, + } + } + + fn into_inner(self) -> Reader> { + self.0 + } +} + +impl JabberReader +where + S: AsyncRead + Unpin, +{ + pub async fn try_close(&mut self) -> Result<()> { + self.read_end_tag().await?; + Ok(()) + } +} + +impl std::ops::Deref for JabberReader { + type Target = Reader>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for JabberReader { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +pub struct JabberWriter(Writer>); + +impl JabberWriter { + fn new(writer: Writer>) -> Self { + Self(writer) + } + + fn unsplit(self, reader: JabberReader) -> JabberStream { + JabberStream { + reader, + writer: self, + } + } + + fn into_inner(self) -> Writer> { + self.0 + } +} + +impl JabberWriter +where + S: AsyncWrite + Unpin + Send, +{ + pub async fn try_close(&mut self) -> Result<()> { + self.write_end().await?; + Ok(()) + } +} + +impl std::ops::Deref for JabberWriter { + type Target = Writer>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for JabberWriter { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } } impl JabberStream @@ -119,8 +214,8 @@ where } } } - let writer = self.writer.into_inner(); - let reader = self.reader.into_inner(); + let writer = self.writer.into_inner().into_inner(); + let reader = self.reader.into_inner().into_inner(); let stream = reader.unsplit(writer); Ok(stream) } @@ -223,8 +318,8 @@ where pub async fn start_stream(connection: S, server: &mut String) -> Result { // client to server let (reader, writer) = tokio::io::split(connection); - let mut reader = Reader::new(reader); - let mut writer = Writer::new(writer); + let mut reader = JabberReader::new(Reader::new(reader)); + let mut writer = JabberWriter::new(Writer::new(writer)); // declaration writer.write_declaration(XML_VERSION).await?; @@ -262,7 +357,10 @@ where } pub fn into_inner(self) -> S { - self.reader.into_inner().unsplit(self.writer.into_inner()) + self.reader + .into_inner() + .into_inner() + .unsplit(self.writer.into_inner().into_inner()) } pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> { @@ -280,7 +378,11 @@ impl JabberStream { let proceed: Proceed = self.reader.read().await?; debug!("got proceed: {:?}", proceed); let connector = TlsConnector::new().unwrap(); - let stream = self.reader.into_inner().unsplit(self.writer.into_inner()); + let stream = self + .reader + .into_inner() + .into_inner() + .unsplit(self.writer.into_inner().into_inner()); if let Ok(tls_stream) = tokio_native_tls::TlsConnector::from(connector) .connect(domain.as_ref(), stream) .await diff --git a/jabber/src/jabber_stream/bound_stream.rs b/jabber/src/jabber_stream/bound_stream.rs index 627158a..51a1763 100644 --- a/jabber/src/jabber_stream/bound_stream.rs +++ b/jabber/src/jabber_stream/bound_stream.rs @@ -1,128 +1,82 @@ -use std::future::ready; -use std::pin::pin; -use std::pin::Pin; -use std::sync::Arc; -use std::task::Poll; +use std::ops::{Deref, DerefMut}; -use futures::ready; -use futures::FutureExt; -use futures::{sink, stream, Sink, Stream}; use peanuts::{Reader, Writer}; -use pin_project::pin_project; -use stanza::client::Stanza; use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf}; -use tokio::sync::Mutex; -use tokio::task::JoinHandle; use crate::Error; -use super::JabberStream; +use super::{JabberReader, JabberStream, JabberWriter}; -#[pin_project] -pub struct BoundJabberStream +pub struct BoundJabberStream(JabberStream); + +impl Deref for BoundJabberStream where S: AsyncWrite + AsyncRead + Unpin + Send, { - reader: Arc>>>, - writer: Arc>>>, - write_handle: Option>>, - read_handle: Option>>, + type Target = JabberStream; + + fn deref(&self) -> &Self::Target { + &self.0 + } } -impl BoundJabberStream +impl DerefMut for BoundJabberStream where S: AsyncWrite + AsyncRead + Unpin + Send, { - // TODO: look into biased mutex, to close stream ASAP - // TODO: put into connection - // pub async fn close_stream(self) -> Result, Error> { - // let reader = self.reader.lock().await.into_self(); - // let writer = self.writer.lock().await.into_self(); - // // TODO: writer - // return Ok(JabberStream { reader, writer }); - // } -} - -pub trait JabberStreamTrait: AsyncWrite + AsyncRead + Unpin + Send {} - -impl Sink for BoundJabberStream -where - S: AsyncWrite + AsyncRead + Unpin + Send + 'static, -{ - type Error = Error; - - fn poll_ready( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.poll_flush(cx) - } - - fn start_send(self: std::pin::Pin<&mut Self>, item: Stanza) -> Result<(), Self::Error> { - let this = self.project(); - if let Some(_write_handle) = this.write_handle { - panic!("start_send called without poll_ready") - } else { - // TODO: switch to buffer of one rather than thread spawning and joining - *this.write_handle = Some(tokio::spawn(write(this.writer.clone(), item))); - Ok(()) - } - } - - fn poll_flush( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let this = self.project(); - Poll::Ready(if let Some(join_handle) = this.write_handle.as_mut() { - match ready!(join_handle.poll_unpin(cx)) { - Ok(state) => { - *this.write_handle = None; - state - } - Err(err) => { - *this.write_handle = None; - Err(err.into()) - } - } - } else { - Ok(()) - }) - } - - fn poll_close( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.poll_flush(cx) + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 } } -impl Stream for BoundJabberStream -where - S: AsyncWrite + AsyncRead + Unpin + Send + 'static, -{ - type Item = Result; +impl BoundJabberStream { + pub fn split(self) -> (BoundJabberReader, BoundJabberWriter) { + let (reader, writer) = self.0.split(); + (BoundJabberReader(reader), BoundJabberWriter(writer)) + } +} - fn poll_next( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let this = self.project(); +pub struct BoundJabberReader(JabberReader); - loop { - if let Some(join_handle) = this.read_handle.as_mut() { - let stanza = ready!(join_handle.poll_unpin(cx)); - if let Ok(item) = stanza { - *this.read_handle = None; - return Poll::Ready(Some(item)); - } else if let Err(err) = stanza { - return Poll::Ready(Some(Err(err.into()))); - } - } else { - *this.read_handle = Some(tokio::spawn(read(this.reader.clone()))) - } - } +impl BoundJabberReader { + pub fn unsplit(self, writer: BoundJabberWriter) -> BoundJabberStream { + BoundJabberStream(self.0.unsplit(writer.0)) + } +} + +impl std::ops::Deref for BoundJabberReader { + type Target = JabberReader; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for BoundJabberReader { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +pub struct BoundJabberWriter(JabberWriter); + +impl BoundJabberWriter { + pub fn unsplit(self, reader: BoundJabberReader) -> BoundJabberStream { + BoundJabberStream(self.0.unsplit(reader.0)) + } +} + +impl std::ops::Deref for BoundJabberWriter { + type Target = JabberWriter; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for BoundJabberWriter { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 } } @@ -131,35 +85,6 @@ where S: AsyncWrite + AsyncRead + Unpin + Send, { pub fn to_bound_jabber(self) -> BoundJabberStream { - let reader = Arc::new(Mutex::new(self.reader)); - let writer = Arc::new(Mutex::new(self.writer)); - BoundJabberStream { - writer, - reader, - write_handle: None, - read_handle: None, - } + BoundJabberStream(self) } } - -pub async fn write( - writer: Arc>>>, - stanza: Stanza, -) -> Result<(), Error> { - { - let mut writer = writer.lock().await; - writer.write(&stanza).await?; - } - Ok(()) -} - -pub async fn read( - reader: Arc>>>, -) -> Result { - let stanza: Result; - { - let mut reader = reader.lock().await; - stanza = reader.read().await.map_err(|e| e.into()); - } - stanza -} diff --git a/stanza/src/client/mod.rs b/stanza/src/client/mod.rs index e9c336e..11ba616 100644 --- a/stanza/src/client/mod.rs +++ b/stanza/src/client/mod.rs @@ -15,6 +15,7 @@ pub mod presence; pub const XMLNS: &str = "jabber:client"; +/// TODO: End tag #[derive(Debug)] pub enum Stanza { Message(Message),