From 4886396044356d2676a77c3900af796fe7641f42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?cel=20=F0=9F=8C=B8?= Date: Wed, 4 Dec 2024 02:09:07 +0000 Subject: [PATCH] implement client --- src/client.rs | 234 ++++++++++++++++++++++++++------------- src/error.rs | 9 ++ src/jabber.rs | 21 +++- src/lib.rs | 8 +- src/stanza/client/mod.rs | 27 +++-- 5 files changed, 213 insertions(+), 86 deletions(-) diff --git a/src/client.rs b/src/client.rs index 2908346..5351b34 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,10 +1,11 @@ -use std::sync::Arc; +use std::{pin::pin, sync::Arc, task::Poll}; -use futures::{Sink, Stream}; +use futures::{Sink, Stream, StreamExt}; use rsasl::config::SASLConfig; use crate::{ connection::{Tls, Unencrypted}, + jid::ParseError, stanza::{ client::Stanza, sasl::Mechanisms, @@ -15,14 +16,146 @@ use crate::{ // feed it client stanzas, receive client stanzas pub struct JabberClient { - connection: JabberState, + connection: ConnectionState, jid: JID, password: Arc, server: String, } -pub enum JabberState { +impl JabberClient { + pub fn new( + jid: impl TryInto, + password: impl ToString, + ) -> Result { + let jid = jid.try_into()?; + let sasl_config = SASLConfig::with_credentials( + None, + jid.localpart.clone().ok_or(Error::NoLocalpart)?, + password.to_string(), + )?; + Ok(JabberClient { + connection: ConnectionState::Disconnected, + jid: jid.clone(), + password: sasl_config, + server: jid.domainpart, + }) + } + + pub async fn connect(&mut self) -> Result<()> { + match &self.connection { + ConnectionState::Disconnected => { + self.connection = ConnectionState::Disconnected + .connect(&mut self.jid, self.password.clone(), &mut self.server) + .await?; + Ok(()) + } + ConnectionState::Connecting(_connecting) => Err(Error::AlreadyConnecting), + ConnectionState::Connected(_jabber_stream) => Ok(()), + } + } +} + +impl Stream for JabberClient { + type Item = Result; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let mut client = pin!(self); + match &mut client.connection { + ConnectionState::Disconnected => Poll::Pending, + ConnectionState::Connecting(_connecting) => Poll::Pending, + ConnectionState::Connected(jabber_stream) => jabber_stream.poll_next_unpin(cx), + } + } +} + +pub enum ConnectionState { Disconnected, + Connecting(Connecting), + Connected(JabberStream), +} + +impl ConnectionState { + pub async fn connect( + mut self, + jid: &mut JID, + auth: Arc, + server: &mut String, + ) -> Result { + loop { + match self { + ConnectionState::Disconnected => { + self = ConnectionState::Connecting(Connecting::start(&server).await?); + } + ConnectionState::Connecting(connecting) => match connecting { + Connecting::InsecureConnectionEstablised(tcp_stream) => { + self = ConnectionState::Connecting(Connecting::InsecureStreamStarted( + JabberStream::start_stream(tcp_stream, server).await?, + )) + } + Connecting::InsecureStreamStarted(jabber_stream) => { + self = ConnectionState::Connecting(Connecting::InsecureGotFeatures( + jabber_stream.get_features().await?, + )) + } + Connecting::InsecureGotFeatures((features, jabber_stream)) => { + match features.negotiate()? { + Feature::StartTls(_start_tls) => { + self = + ConnectionState::Connecting(Connecting::StartTls(jabber_stream)) + } + // TODO: better error + _ => return Err(Error::TlsRequired), + } + } + Connecting::StartTls(jabber_stream) => { + self = ConnectionState::Connecting(Connecting::ConnectionEstablished( + jabber_stream.starttls(&server).await?, + )) + } + Connecting::ConnectionEstablished(tls_stream) => { + self = ConnectionState::Connecting(Connecting::StreamStarted( + JabberStream::start_stream(tls_stream, server).await?, + )) + } + Connecting::StreamStarted(jabber_stream) => { + self = ConnectionState::Connecting(Connecting::GotFeatures( + jabber_stream.get_features().await?, + )) + } + Connecting::GotFeatures((features, jabber_stream)) => { + match features.negotiate()? { + Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls), + Feature::Sasl(mechanisms) => { + self = ConnectionState::Connecting(Connecting::Sasl( + mechanisms, + jabber_stream, + )) + } + Feature::Bind => { + self = ConnectionState::Connecting(Connecting::Bind(jabber_stream)) + } + Feature::Unknown => return Err(Error::Unsupported), + } + } + Connecting::Sasl(mechanisms, jabber_stream) => { + self = ConnectionState::Connecting(Connecting::ConnectionEstablished( + jabber_stream.sasl(mechanisms, auth.clone()).await?, + )) + } + Connecting::Bind(jabber_stream) => { + self = ConnectionState::Connected(jabber_stream.bind(jid).await?) + } + }, + connected => return Ok(connected), + } + } + } +} + +pub enum Connecting { InsecureConnectionEstablised(Unencrypted), InsecureStreamStarted(JabberStream), InsecureGotFeatures((Features, JabberStream)), @@ -32,67 +165,15 @@ pub enum JabberState { GotFeatures((Features, JabberStream)), Sasl(Mechanisms, JabberStream), Bind(JabberStream), - // when it's bound, can stream stanzas and sink stanzas - Bound(JabberStream), } -impl JabberState { - pub async fn advance_state( - self, - jid: &mut JID, - auth: Arc, - server: &mut String, - ) -> Result { - match self { - JabberState::Disconnected => match Connection::connect(server).await? { - Connection::Encrypted(tls_stream) => { - Ok(JabberState::ConnectionEstablished(tls_stream)) - } - Connection::Unencrypted(tcp_stream) => { - Ok(JabberState::InsecureConnectionEstablised(tcp_stream)) - } - }, - JabberState::InsecureConnectionEstablised(tcp_stream) => Ok({ - JabberState::InsecureStreamStarted( - JabberStream::start_stream(tcp_stream, server).await?, - ) - }), - JabberState::InsecureStreamStarted(jabber_stream) => Ok( - JabberState::InsecureGotFeatures(jabber_stream.get_features().await?), - ), - JabberState::InsecureGotFeatures((features, jabber_stream)) => { - match features.negotiate()? { - Feature::StartTls(_start_tls) => Ok(JabberState::StartTls(jabber_stream)), - // TODO: better error - _ => return Err(Error::TlsRequired), - } +impl Connecting { + pub async fn start(server: &str) -> Result { + match Connection::connect(server).await? { + Connection::Encrypted(tls_stream) => Ok(Connecting::ConnectionEstablished(tls_stream)), + Connection::Unencrypted(tcp_stream) => { + Ok(Connecting::InsecureConnectionEstablised(tcp_stream)) } - JabberState::StartTls(jabber_stream) => Ok(JabberState::ConnectionEstablished( - jabber_stream.starttls(server).await?, - )), - JabberState::ConnectionEstablished(tls_stream) => Ok(JabberState::StreamStarted( - JabberStream::start_stream(tls_stream, server).await?, - )), - JabberState::StreamStarted(jabber_stream) => Ok(JabberState::GotFeatures( - jabber_stream.get_features().await?, - )), - JabberState::GotFeatures((features, jabber_stream)) => match features.negotiate()? { - Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls), - Feature::Sasl(mechanisms) => { - return Ok(JabberState::Sasl(mechanisms, jabber_stream)) - } - Feature::Bind => return Ok(JabberState::Bind(jabber_stream)), - Feature::Unknown => return Err(Error::Unsupported), - }, - JabberState::Sasl(mechanisms, jabber_stream) => { - return Ok(JabberState::ConnectionEstablished( - jabber_stream.sasl(mechanisms, auth).await?, - )) - } - JabberState::Bind(jabber_stream) => { - Ok(JabberState::Bound(jabber_stream.bind(jid).await?)) - } - JabberState::Bound(jabber_stream) => Ok(JabberState::Bound(jabber_stream)), } } } @@ -126,7 +207,7 @@ impl Features { } } -pub enum InsecureJabberConnection { +pub enum InsecureConnecting { Disconnected, ConnectionEstablished(Connection), PreStarttls(JabberStream), @@ -136,17 +217,6 @@ pub enum InsecureJabberConnection { Bound(JabberStream), } -impl Stream for JabberClient { - type Item = Stanza; - - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - todo!() - } -} - impl Sink for JabberClient { type Error = Error; @@ -178,3 +248,19 @@ impl Sink for JabberClient { todo!() } } + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use super::JabberClient; + use test_log::test; + use tokio::time::sleep; + + #[test(tokio::test)] + async fn login() { + let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap(); + client.connect().await.unwrap(); + sleep(Duration::from_secs(5)).await + } +} diff --git a/src/error.rs b/src/error.rs index 8cb6496..f117e82 100644 --- a/src/error.rs +++ b/src/error.rs @@ -13,8 +13,11 @@ pub enum Error { TlsRequired, AlreadyTls, Unsupported, + NoLocalpart, + AlreadyConnecting, UnexpectedElement(peanuts::Element), XML(peanuts::Error), + Deserialization(peanuts::DeserializeError), SASL(SASLError), JID(ParseError), Authentication(Failure), @@ -34,6 +37,12 @@ impl From for Error { } } +impl From for Error { + fn from(e: peanuts::DeserializeError) -> Self { + Error::Deserialization(e) + } +} + impl From for Error { fn from(e: MechanismNameError) -> Self { Self::SASL(SASLError::MechanismName(e)) diff --git a/src/jabber.rs b/src/jabber.rs index cf90f73..30dc15d 100644 --- a/src/jabber.rs +++ b/src/jabber.rs @@ -1,8 +1,10 @@ +use std::pin::pin; use std::str::{self, FromStr}; use std::sync::Arc; use async_recursion::async_recursion; -use peanuts::element::IntoElement; +use futures::StreamExt; +use peanuts::element::{FromContent, IntoElement}; use peanuts::{Reader, Writer}; use rsasl::prelude::{Mechname, SASLClient, SASLConfig}; use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf}; @@ -13,6 +15,7 @@ use crate::connection::{Tls, Unencrypted}; use crate::error::Error; use crate::stanza::bind::{Bind, BindType, FullJidType, ResourceType}; use crate::stanza::client::iq::{Iq, IqType, Query}; +use crate::stanza::client::Stanza; use crate::stanza::sasl::{Auth, Challenge, Mechanisms, Response, ServerResponse}; use crate::stanza::starttls::{Proceed, StartTls}; use crate::stanza::stream::{Feature, Features, Stream}; @@ -26,6 +29,22 @@ pub struct JabberStream { writer: Writer>, } +impl futures::Stream for JabberStream { + type Item = Result; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + pin!(self).reader.poll_next_unpin(cx).map(|content| { + content.map(|content| -> Result { + let stanza = content.map(|content| Stanza::from_content(content))?; + Ok(stanza?) + }) + }) + } +} + impl JabberStream where S: AsyncRead + AsyncWrite + Unpin + Send + std::fmt::Debug, diff --git a/src/lib.rs b/src/lib.rs index 9c8d968..e55d3f5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,8 +29,8 @@ pub async fn login, P: AsRef>(jid: J, password: P) -> Result< #[cfg(test)] mod tests { - #[tokio::test] - async fn test_login() { - crate::login("test@blos.sm/clown", "slayed").await.unwrap(); - } + // #[tokio::test] + // async fn test_login() { + // crate::login("test@blos.sm/clown", "slayed").await.unwrap(); + // } } diff --git a/src/stanza/client/mod.rs b/src/stanza/client/mod.rs index 25d7b56..2b063d6 100644 --- a/src/stanza/client/mod.rs +++ b/src/stanza/client/mod.rs @@ -1,7 +1,7 @@ use iq::Iq; use message::Message; use peanuts::{ - element::{FromElement, IntoElement}, + element::{Content, ContentBuilder, FromContent, FromElement, IntoContent, IntoElement}, DeserializeError, }; use presence::Presence; @@ -20,6 +20,18 @@ pub enum Stanza { Presence(Presence), Iq(Iq), Error(StreamError), + OtherContent(Content), +} + +impl FromContent for Stanza { + fn from_content(content: Content) -> peanuts::element::DeserializeResult { + match content { + Content::Element(element) => Ok(Stanza::from_element(element)?), + Content::Text(_) => Ok(Stanza::OtherContent(content)), + Content::PI => Ok(Stanza::OtherContent(content)), + Content::Comment(_) => Ok(Stanza::OtherContent(content)), + } + } } impl FromElement for Stanza { @@ -36,13 +48,14 @@ impl FromElement for Stanza { } } -impl IntoElement for Stanza { - fn builder(&self) -> peanuts::element::ElementBuilder { +impl IntoContent for Stanza { + fn builder(&self) -> peanuts::element::ContentBuilder { match self { - Stanza::Message(message) => message.builder(), - Stanza::Presence(presence) => presence.builder(), - Stanza::Iq(iq) => iq.builder(), - Stanza::Error(error) => error.builder(), + Stanza::Message(message) => ::builder(message), + Stanza::Presence(presence) => ::builder(presence), + Stanza::Iq(iq) => ::builder(iq), + Stanza::Error(error) => ::builder(error), + Stanza::OtherContent(_content) => ContentBuilder::Comment("other-content".to_string()), } } }