use std::{ borrow::Borrow, future::Future, pin::pin, sync::Arc, task::{ready, Poll}, }; use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt}; use jid::ParseError; use rsasl::config::SASLConfig; use stanza::{ client::Stanza, sasl::Mechanisms, stream::{Feature, Features}, }; use tokio::sync::Mutex; use crate::{ connection::{Tls, Unencrypted}, jabber_stream::bound_stream::BoundJabberStream, Connection, Error, JabberStream, Result, JID, }; // feed it client stanzas, receive client stanzas pub struct JabberClient { connection: ConnectionState, jid: JID, password: Arc, server: String, } impl JabberClient { pub fn new( jid: impl TryInto, password: impl ToString, ) -> Result { let jid = jid.try_into()?; let sasl_config = SASLConfig::with_credentials( None, jid.localpart.clone().ok_or(Error::NoLocalpart)?, password.to_string(), )?; Ok(JabberClient { connection: ConnectionState::Disconnected, jid: jid.clone(), password: sasl_config, server: jid.domainpart, }) } pub async fn connect(&mut self) -> Result<()> { match &self.connection { ConnectionState::Disconnected => { // TODO: actually set the self.connection as it is connecting, make more asynchronous (mutex while connecting?) // perhaps use take_mut? self.connection = ConnectionState::Disconnected .connect(&mut self.jid, self.password.clone(), &mut self.server) .await?; Ok(()) } ConnectionState::Connecting(_connecting) => Err(Error::AlreadyConnecting), ConnectionState::Connected(_jabber_stream) => Ok(()), } } pub(crate) fn inner(self) -> Result> { match self.connection { ConnectionState::Disconnected => return Err(Error::Disconnected), ConnectionState::Connecting(_connecting) => return Err(Error::Connecting), ConnectionState::Connected(jabber_stream) => return Ok(jabber_stream), } } // pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> { // match &mut self.connection { // ConnectionState::Disconnected => return Err(Error::Disconnected), // ConnectionState::Connecting(_connecting) => return Err(Error::Connecting), // ConnectionState::Connected(jabber_stream) => { // Ok(jabber_stream.send_stanza(stanza).await?) // } // } // } } impl Sink for JabberClient { type Error = Error; fn poll_ready( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { self.get_mut().connection.poll_ready_unpin(cx) } fn start_send( self: std::pin::Pin<&mut Self>, item: Stanza, ) -> std::result::Result<(), Self::Error> { self.get_mut().connection.start_send_unpin(item) } fn poll_flush( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { self.get_mut().connection.poll_flush_unpin(cx) } fn poll_close( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { self.get_mut().connection.poll_flush_unpin(cx) } } impl Stream for JabberClient { type Item = Result; fn poll_next( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { self.get_mut().connection.poll_next_unpin(cx) } } pub enum ConnectionState { Disconnected, Connecting(Connecting), Connected(BoundJabberStream), } impl Sink for ConnectionState { type Error = Error; fn poll_ready( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { match self.get_mut() { ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)), ConnectionState::Connecting(_connecting) => Poll::Pending, ConnectionState::Connected(bound_jabber_stream) => { bound_jabber_stream.poll_ready_unpin(cx) } } } fn start_send( self: std::pin::Pin<&mut Self>, item: Stanza, ) -> std::result::Result<(), Self::Error> { match self.get_mut() { ConnectionState::Disconnected => Err(Error::Disconnected), ConnectionState::Connecting(_connecting) => Err(Error::Connecting), ConnectionState::Connected(bound_jabber_stream) => { bound_jabber_stream.start_send_unpin(item) } } } fn poll_flush( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { match self.get_mut() { ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)), ConnectionState::Connecting(_connecting) => Poll::Pending, ConnectionState::Connected(bound_jabber_stream) => { bound_jabber_stream.poll_flush_unpin(cx) } } } fn poll_close( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { match self.get_mut() { ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)), ConnectionState::Connecting(_connecting) => Poll::Pending, ConnectionState::Connected(bound_jabber_stream) => { bound_jabber_stream.poll_close_unpin(cx) } } } } impl Stream for ConnectionState { type Item = Result; fn poll_next( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { match self.get_mut() { ConnectionState::Disconnected => Poll::Ready(Some(Err(Error::Disconnected))), ConnectionState::Connecting(_connecting) => Poll::Pending, ConnectionState::Connected(bound_jabber_stream) => { bound_jabber_stream.poll_next_unpin(cx) } } } } impl ConnectionState { pub async fn connect( mut self, jid: &mut JID, auth: Arc, server: &mut String, ) -> Result { loop { match self { ConnectionState::Disconnected => { self = ConnectionState::Connecting(Connecting::start(&server).await?); } ConnectionState::Connecting(connecting) => match connecting { Connecting::InsecureConnectionEstablised(tcp_stream) => { self = ConnectionState::Connecting(Connecting::InsecureStreamStarted( JabberStream::start_stream(tcp_stream, server).await?, )) } Connecting::InsecureStreamStarted(jabber_stream) => { self = ConnectionState::Connecting(Connecting::InsecureGotFeatures( jabber_stream.get_features().await?, )) } Connecting::InsecureGotFeatures((features, jabber_stream)) => { match features.negotiate().ok_or(Error::Negotiation)? { Feature::StartTls(_start_tls) => { self = ConnectionState::Connecting(Connecting::StartTls(jabber_stream)) } // TODO: better error _ => return Err(Error::TlsRequired), } } Connecting::StartTls(jabber_stream) => { self = ConnectionState::Connecting(Connecting::ConnectionEstablished( jabber_stream.starttls(&server).await?, )) } Connecting::ConnectionEstablished(tls_stream) => { self = ConnectionState::Connecting(Connecting::StreamStarted( JabberStream::start_stream(tls_stream, server).await?, )) } Connecting::StreamStarted(jabber_stream) => { self = ConnectionState::Connecting(Connecting::GotFeatures( jabber_stream.get_features().await?, )) } Connecting::GotFeatures((features, jabber_stream)) => { match features.negotiate().ok_or(Error::Negotiation)? { Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls), Feature::Sasl(mechanisms) => { self = ConnectionState::Connecting(Connecting::Sasl( mechanisms, jabber_stream, )) } Feature::Bind => { self = ConnectionState::Connecting(Connecting::Bind(jabber_stream)) } Feature::Unknown => return Err(Error::Unsupported), } } Connecting::Sasl(mechanisms, jabber_stream) => { self = ConnectionState::Connecting(Connecting::ConnectionEstablished( jabber_stream.sasl(mechanisms, auth.clone()).await?, )) } Connecting::Bind(jabber_stream) => { self = ConnectionState::Connected( jabber_stream.bind(jid).await?.to_bound_jabber(), ) } }, connected => return Ok(connected), } } } } pub enum Connecting { InsecureConnectionEstablised(Unencrypted), InsecureStreamStarted(JabberStream), InsecureGotFeatures((Features, JabberStream)), StartTls(JabberStream), ConnectionEstablished(Tls), StreamStarted(JabberStream), GotFeatures((Features, JabberStream)), Sasl(Mechanisms, JabberStream), Bind(JabberStream), } impl Connecting { pub async fn start(server: &str) -> Result { match Connection::connect(server).await? { Connection::Encrypted(tls_stream) => Ok(Connecting::ConnectionEstablished(tls_stream)), Connection::Unencrypted(tcp_stream) => { Ok(Connecting::InsecureConnectionEstablised(tcp_stream)) } } } } pub enum InsecureConnecting { Disconnected, ConnectionEstablished(Connection), PreStarttls(JabberStream), PreAuthenticated(JabberStream), Authenticated(Tls), PreBound(JabberStream), Bound(JabberStream), } #[cfg(test)] mod tests { use std::{sync::Arc, time::Duration}; use super::JabberClient; use futures::{SinkExt, StreamExt}; use stanza::{ client::{ iq::{Iq, IqType, Query}, Stanza, }, xep_0199::Ping, }; use test_log::test; use tokio::{sync::Mutex, time::sleep}; use tracing::info; #[test(tokio::test)] async fn login() { let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap(); client.connect().await.unwrap(); sleep(Duration::from_secs(5)).await } #[test(tokio::test)] async fn ping_parallel() { let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap(); client.connect().await.unwrap(); sleep(Duration::from_secs(5)).await; let jid = client.jid.clone(); let server = client.server.clone(); let (mut write, mut read) = client.split(); tokio::join!( async { write .send(Stanza::Iq(Iq { from: Some(jid.clone()), id: "c2s1".to_string(), to: Some(server.clone().try_into().unwrap()), r#type: IqType::Get, lang: None, query: Some(Query::Ping(Ping)), errors: Vec::new(), })) .await; write .send(Stanza::Iq(Iq { from: Some(jid.clone()), id: "c2s2".to_string(), to: Some(server.clone().try_into().unwrap()), r#type: IqType::Get, lang: None, query: Some(Query::Ping(Ping)), errors: Vec::new(), })) .await; }, async { while let Some(stanza) = read.next().await { info!("{:#?}", stanza); } } ); } }