use std::{pin::pin, sync::Arc, task::Poll}; use futures::{Sink, Stream, StreamExt}; use jid::ParseError; use rsasl::config::SASLConfig; use stanza::{ client::Stanza, sasl::Mechanisms, stream::{Feature, Features}, }; use crate::{ connection::{Tls, Unencrypted}, Connection, Error, JabberStream, Result, JID, }; // feed it client stanzas, receive client stanzas pub struct JabberClient { connection: ConnectionState, jid: JID, password: Arc, server: String, } impl JabberClient { pub fn new( jid: impl TryInto, password: impl ToString, ) -> Result { let jid = jid.try_into()?; let sasl_config = SASLConfig::with_credentials( None, jid.localpart.clone().ok_or(Error::NoLocalpart)?, password.to_string(), )?; Ok(JabberClient { connection: ConnectionState::Disconnected, jid: jid.clone(), password: sasl_config, server: jid.domainpart, }) } pub async fn connect(&mut self) -> Result<()> { match &self.connection { ConnectionState::Disconnected => { // TODO: actually set the self.connection as it is connecting, make more asynchronous (mutex while connecting?) // perhaps use take_mut? self.connection = ConnectionState::Disconnected .connect(&mut self.jid, self.password.clone(), &mut self.server) .await?; Ok(()) } ConnectionState::Connecting(_connecting) => Err(Error::AlreadyConnecting), ConnectionState::Connected(_jabber_stream) => Ok(()), } } pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> { match &mut self.connection { ConnectionState::Disconnected => return Err(Error::Disconnected), ConnectionState::Connecting(_connecting) => return Err(Error::Connecting), ConnectionState::Connected(jabber_stream) => { Ok(jabber_stream.send_stanza(stanza).await?) } } } } impl Stream for JabberClient { type Item = Result; fn poll_next( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { let mut client = pin!(self); match &mut client.connection { ConnectionState::Disconnected => Poll::Pending, ConnectionState::Connecting(_connecting) => Poll::Pending, ConnectionState::Connected(jabber_stream) => jabber_stream.poll_next_unpin(cx), } } } pub enum ConnectionState { Disconnected, Connecting(Connecting), Connected(JabberStream), } impl ConnectionState { pub async fn connect( mut self, jid: &mut JID, auth: Arc, server: &mut String, ) -> Result { loop { match self { ConnectionState::Disconnected => { self = ConnectionState::Connecting(Connecting::start(&server).await?); } ConnectionState::Connecting(connecting) => match connecting { Connecting::InsecureConnectionEstablised(tcp_stream) => { self = ConnectionState::Connecting(Connecting::InsecureStreamStarted( JabberStream::start_stream(tcp_stream, server).await?, )) } Connecting::InsecureStreamStarted(jabber_stream) => { self = ConnectionState::Connecting(Connecting::InsecureGotFeatures( jabber_stream.get_features().await?, )) } Connecting::InsecureGotFeatures((features, jabber_stream)) => { match features.negotiate().ok_or(Error::Negotiation)? { Feature::StartTls(_start_tls) => { self = ConnectionState::Connecting(Connecting::StartTls(jabber_stream)) } // TODO: better error _ => return Err(Error::TlsRequired), } } Connecting::StartTls(jabber_stream) => { self = ConnectionState::Connecting(Connecting::ConnectionEstablished( jabber_stream.starttls(&server).await?, )) } Connecting::ConnectionEstablished(tls_stream) => { self = ConnectionState::Connecting(Connecting::StreamStarted( JabberStream::start_stream(tls_stream, server).await?, )) } Connecting::StreamStarted(jabber_stream) => { self = ConnectionState::Connecting(Connecting::GotFeatures( jabber_stream.get_features().await?, )) } Connecting::GotFeatures((features, jabber_stream)) => { match features.negotiate().ok_or(Error::Negotiation)? { Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls), Feature::Sasl(mechanisms) => { self = ConnectionState::Connecting(Connecting::Sasl( mechanisms, jabber_stream, )) } Feature::Bind => { self = ConnectionState::Connecting(Connecting::Bind(jabber_stream)) } Feature::Unknown => return Err(Error::Unsupported), } } Connecting::Sasl(mechanisms, jabber_stream) => { self = ConnectionState::Connecting(Connecting::ConnectionEstablished( jabber_stream.sasl(mechanisms, auth.clone()).await?, )) } Connecting::Bind(jabber_stream) => { self = ConnectionState::Connected(jabber_stream.bind(jid).await?) } }, connected => return Ok(connected), } } } } pub enum Connecting { InsecureConnectionEstablised(Unencrypted), InsecureStreamStarted(JabberStream), InsecureGotFeatures((Features, JabberStream)), StartTls(JabberStream), ConnectionEstablished(Tls), StreamStarted(JabberStream), GotFeatures((Features, JabberStream)), Sasl(Mechanisms, JabberStream), Bind(JabberStream), } impl Connecting { pub async fn start(server: &str) -> Result { match Connection::connect(server).await? { Connection::Encrypted(tls_stream) => Ok(Connecting::ConnectionEstablished(tls_stream)), Connection::Unencrypted(tcp_stream) => { Ok(Connecting::InsecureConnectionEstablised(tcp_stream)) } } } } pub enum InsecureConnecting { Disconnected, ConnectionEstablished(Connection), PreStarttls(JabberStream), PreAuthenticated(JabberStream), Authenticated(Tls), PreBound(JabberStream), Bound(JabberStream), } #[cfg(test)] mod tests { use std::time::Duration; use super::JabberClient; use test_log::test; use tokio::time::sleep; #[test(tokio::test)] async fn login() { let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap(); client.connect().await.unwrap(); sleep(Duration::from_secs(5)).await } }