From 21f10a0b43c4ab1429b274b386065c023c661ab0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?cel=20=F0=9F=8C=B8?= Date: Wed, 4 Dec 2024 17:38:36 +0000 Subject: [PATCH] implement send_stanza --- README.md | 10 ++++ src/client.rs | 44 +++++------------- src/error.rs | 4 ++ src/jabber.rs | 94 ++++---------------------------------- src/stanza/stream.rs | 1 + src/stanza/stream_error.rs | 4 +- 6 files changed, 39 insertions(+), 118 deletions(-) create mode 100644 README.md diff --git a/README.md b/README.md new file mode 100644 index 0000000..195860c --- /dev/null +++ b/README.md @@ -0,0 +1,10 @@ +# jabber client library + +## TODO: + +- [ ] error states for all negotiation parts +- [ ] better errors +- [ ] rename structs +- [ ] remove commented code +- [ ] asynchronous connect (with take_mut?) +- [ ] split into separate crates: stanza, jabber, and luz diff --git a/src/client.rs b/src/client.rs index 5351b34..e94008d 100644 --- a/src/client.rs +++ b/src/client.rs @@ -44,6 +44,8 @@ impl JabberClient { pub async fn connect(&mut self) -> Result<()> { match &self.connection { ConnectionState::Disconnected => { + // TODO: actually set the self.connection as it is connecting, make more asynchronous (mutex while connecting?) + // perhaps use take_mut? self.connection = ConnectionState::Disconnected .connect(&mut self.jid, self.password.clone(), &mut self.server) .await?; @@ -53,6 +55,16 @@ impl JabberClient { ConnectionState::Connected(_jabber_stream) => Ok(()), } } + + pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> { + match &mut self.connection { + ConnectionState::Disconnected => return Err(Error::Disconnected), + ConnectionState::Connecting(_connecting) => return Err(Error::Connecting), + ConnectionState::Connected(jabber_stream) => { + Ok(jabber_stream.send_stanza(stanza).await?) + } + } + } } impl Stream for JabberClient { @@ -217,38 +229,6 @@ pub enum InsecureConnecting { Bound(JabberStream), } -impl Sink for JabberClient { - type Error = Error; - - fn poll_ready( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - todo!() - } - - fn start_send( - self: std::pin::Pin<&mut Self>, - item: Stanza, - ) -> std::result::Result<(), Self::Error> { - todo!() - } - - fn poll_flush( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - todo!() - } - - fn poll_close( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - todo!() - } -} - #[cfg(test)] mod tests { use std::time::Duration; diff --git a/src/error.rs b/src/error.rs index f117e82..8875ebb 100644 --- a/src/error.rs +++ b/src/error.rs @@ -3,6 +3,7 @@ use std::str::Utf8Error; use rsasl::mechname::MechanismNameError; use crate::stanza::client::error::Error as ClientError; +use crate::stanza::stream::Error as StreamError; use crate::{jid::ParseError, stanza::sasl::Failure}; #[derive(Debug)] @@ -22,7 +23,10 @@ pub enum Error { JID(ParseError), Authentication(Failure), ClientError(ClientError), + StreamError(StreamError), MissingError, + Disconnected, + Connecting, } #[derive(Debug)] diff --git a/src/jabber.rs b/src/jabber.rs index 30dc15d..8ee45b5 100644 --- a/src/jabber.rs +++ b/src/jabber.rs @@ -50,6 +50,7 @@ where S: AsyncRead + AsyncWrite + Unpin + Send + std::fmt::Debug, JabberStream: std::fmt::Debug, { + #[instrument] pub async fn sasl(mut self, mechanisms: Mechanisms, sasl_config: Arc) -> Result { let sasl = SASLClient::new(sasl_config); let mut offered_mechs: Vec<&Mechname> = Vec::new(); @@ -139,6 +140,7 @@ where Ok(stream) } + #[instrument] pub async fn bind(mut self, jid: &mut JID) -> Result { let iq_id = nanoid::nanoid!(); if let Some(resource) = &jid.resourcepart { @@ -266,6 +268,7 @@ where Ok(Self { reader, writer }) } + #[instrument] pub async fn get_features(mut self) -> Result<(Features, Self)> { debug!("getting features"); let features: Features = self.reader.read().await?; @@ -276,91 +279,16 @@ where pub fn into_inner(self) -> S { self.reader.into_inner().unsplit(self.writer.into_inner()) } + + pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> { + self.writer.write(stanza).await?; + Ok(()) + } } impl JabberStream { - // pub async fn negotiate( - // mut self, - // features: Features, - // ) -> Result { - // // TODO: timeout - // if let Some(Feature::StartTls(_)) = features - // .features - // .iter() - // .find(|feature| matches!(feature, Feature::StartTls(_s))) - // { - // return Ok(self); - // } else { - // // TODO: better error - // return Err(Error::TlsRequired); - // } - // } - - // #[async_recursion] - // pub async fn negotiate_tls_optional(mut self) -> Result { - // self.start_stream().await?; - // // TODO: timeout - // let features = self.get_features().await?.features; - // if let Some(Feature::StartTls(_)) = features - // .iter() - // .find(|feature| matches!(feature, Feature::StartTls(_s))) - // { - // let jabber = self.starttls().await?; - // let jabber = jabber.negotiate().await?; - // return Ok(Connection::Encrypted(jabber)); - // } else if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = ( - // self.auth.clone(), - // features - // .iter() - // .find(|feature| matches!(feature, Feature::Sasl(_))), - // ) { - // self.sasl(mechanisms.clone(), sasl_config).await?; - // let jabber = self.negotiate_tls_optional().await?; - // Ok(jabber) - // } else if let Some(Feature::Bind) = features - // .iter() - // .find(|feature| matches!(feature, Feature::Bind)) - // { - // self.bind().await?; - // Ok(Connection::Unencrypted(self)) - // } else { - // // TODO: better error - // return Err(Error::Negotiation); - // } - // } -} - -impl JabberStream { - // #[async_recursion] - // pub async fn negotiate(mut self) -> Result> { - // self.start_stream().await?; - // let features = self.get_features().await?.features; - - // if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = ( - // self.auth.clone(), - // features - // .iter() - // .find(|feature| matches!(feature, Feature::Sasl(_))), - // ) { - // // TODO: avoid clone - // self.sasl(mechanisms.clone(), sasl_config).await?; - // let jabber = self.negotiate().await?; - // Ok(jabber) - // } else if let Some(Feature::Bind) = features - // .iter() - // .find(|feature| matches!(feature, Feature::Bind)) - // { - // self.bind().await?; - // Ok(self) - // } else { - // // TODO: better error - // return Err(Error::Negotiation); - // } - // } -} - -impl JabberStream { - pub async fn starttls(mut self, domain: impl AsRef) -> Result { + #[instrument] + pub async fn starttls(mut self, domain: impl AsRef + std::fmt::Debug) -> Result { self.writer .write_full(&StartTls { required: false }) .await?; @@ -372,8 +300,6 @@ impl JabberStream { .connect(domain.as_ref(), stream) .await { - // let (read, write) = tokio::io::split(tlsstream); - // let client = JabberStream::new(read, write); return Ok(tls_stream); } else { return Err(Error::Connection); diff --git a/src/stanza/stream.rs b/src/stanza/stream.rs index 4f3c435..84d62d9 100644 --- a/src/stanza/stream.rs +++ b/src/stanza/stream.rs @@ -164,6 +164,7 @@ impl FromElement for Feature { } } +#[derive(Debug)] pub struct Error { error: StreamError, text: Option, diff --git a/src/stanza/stream_error.rs b/src/stanza/stream_error.rs index 37db8a1..5ae04a6 100644 --- a/src/stanza/stream_error.rs +++ b/src/stanza/stream_error.rs @@ -5,7 +5,7 @@ use peanuts::{ pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-streams"; -#[derive(Clone)] +#[derive(Clone, Debug)] pub enum Error { BadFormat, BadNamespacePrefix, @@ -110,7 +110,7 @@ impl IntoElement for Error { } } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Text { text: Option, lang: Option,