// luz/jabber/src/client.rs
//
// 218 lines
// 7.9 KiB
// Rust

use std::{pin::pin, sync::Arc, task::Poll};
use futures::{Sink, Stream, StreamExt};
use jid::ParseError;
use rsasl::config::SASLConfig;
use stanza::{
client::Stanza,
sasl::Mechanisms,
stream::{Feature, Features},
};
use crate::{
connection::{Tls, Unencrypted},
Connection, Error, JabberStream, Result, JID,
};
/// An XMPP client: feed it client stanzas, receive client stanzas.
///
/// Stanzas are sent with `send_stanza` and received through the `Stream`
/// implementation; both only do useful work once `connect` has succeeded.
pub struct JabberClient {
/// Current connection state; starts out as `ConnectionState::Disconnected`.
connection: ConnectionState,
/// JID this client was created with. Handed mutably to the bind step while connecting.
jid: JID,
/// SASL configuration built from the JID's localpart and the password given to `new`
/// (despite the field name, this is not the raw password string).
password: Arc<SASLConfig>,
/// Server name, initialised from the JID's domain part. Handed mutably to the
/// connection routine while connecting.
server: String,
}
impl JabberClient {
pub fn new(
jid: impl TryInto<JID, Error = ParseError>,
password: impl ToString,
) -> Result<JabberClient> {
let jid = jid.try_into()?;
let sasl_config = SASLConfig::with_credentials(
None,
jid.localpart.clone().ok_or(Error::NoLocalpart)?,
password.to_string(),
)?;
Ok(JabberClient {
connection: ConnectionState::Disconnected,
jid: jid.clone(),
password: sasl_config,
server: jid.domainpart,
})
}
pub async fn connect(&mut self) -> Result<()> {
match &self.connection {
ConnectionState::Disconnected => {
// TODO: actually set the self.connection as it is connecting, make more asynchronous (mutex while connecting?)
// perhaps use take_mut?
self.connection = ConnectionState::Disconnected
.connect(&mut self.jid, self.password.clone(), &mut self.server)
.await?;
Ok(())
}
ConnectionState::Connecting(_connecting) => Err(Error::AlreadyConnecting),
ConnectionState::Connected(_jabber_stream) => Ok(()),
}
}
pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> {
match &mut self.connection {
ConnectionState::Disconnected => return Err(Error::Disconnected),
ConnectionState::Connecting(_connecting) => return Err(Error::Connecting),
ConnectionState::Connected(jabber_stream) => {
Ok(jabber_stream.send_stanza(stanza).await?)
}
}
}
}
impl Stream for JabberClient {
type Item = Result<Stanza>;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let mut client = pin!(self);
match &mut client.connection {
ConnectionState::Disconnected => Poll::Pending,
ConnectionState::Connecting(_connecting) => Poll::Pending,
ConnectionState::Connected(jabber_stream) => jabber_stream.poll_next_unpin(cx),
}
}
}
/// Top-level connection state of a `JabberClient`.
pub enum ConnectionState {
/// No connection; the starting point for `ConnectionState::connect`.
Disconnected,
/// Part-way through connection setup; the payload records which step comes next.
Connecting(Connecting),
/// Setup finished: holds the stream produced by the bind step.
Connected(JabberStream<Tls>),
}
impl ConnectionState {
pub async fn connect(
mut self,
jid: &mut JID,
auth: Arc<SASLConfig>,
server: &mut String,
) -> Result<Self> {
loop {
match self {
ConnectionState::Disconnected => {
self = ConnectionState::Connecting(Connecting::start(&server).await?);
}
ConnectionState::Connecting(connecting) => match connecting {
Connecting::InsecureConnectionEstablised(tcp_stream) => {
self = ConnectionState::Connecting(Connecting::InsecureStreamStarted(
JabberStream::start_stream(tcp_stream, server).await?,
))
}
Connecting::InsecureStreamStarted(jabber_stream) => {
self = ConnectionState::Connecting(Connecting::InsecureGotFeatures(
jabber_stream.get_features().await?,
))
}
Connecting::InsecureGotFeatures((features, jabber_stream)) => {
match features.negotiate().ok_or(Error::Negotiation)? {
Feature::StartTls(_start_tls) => {
self =
ConnectionState::Connecting(Connecting::StartTls(jabber_stream))
}
// TODO: better error
_ => return Err(Error::TlsRequired),
}
}
Connecting::StartTls(jabber_stream) => {
self = ConnectionState::Connecting(Connecting::ConnectionEstablished(
jabber_stream.starttls(&server).await?,
))
}
Connecting::ConnectionEstablished(tls_stream) => {
self = ConnectionState::Connecting(Connecting::StreamStarted(
JabberStream::start_stream(tls_stream, server).await?,
))
}
Connecting::StreamStarted(jabber_stream) => {
self = ConnectionState::Connecting(Connecting::GotFeatures(
jabber_stream.get_features().await?,
))
}
Connecting::GotFeatures((features, jabber_stream)) => {
match features.negotiate().ok_or(Error::Negotiation)? {
Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls),
Feature::Sasl(mechanisms) => {
self = ConnectionState::Connecting(Connecting::Sasl(
mechanisms,
jabber_stream,
))
}
Feature::Bind => {
self = ConnectionState::Connecting(Connecting::Bind(jabber_stream))
}
Feature::Unknown => return Err(Error::Unsupported),
}
}
Connecting::Sasl(mechanisms, jabber_stream) => {
self = ConnectionState::Connecting(Connecting::ConnectionEstablished(
jabber_stream.sasl(mechanisms, auth.clone()).await?,
))
}
Connecting::Bind(jabber_stream) => {
self = ConnectionState::Connected(jabber_stream.bind(jid).await?)
}
},
connected => return Ok(connected),
}
}
}
}
/// Intermediate steps of connection setup, as walked by `ConnectionState::connect`.
///
/// Each variant carries what the next step consumes.
pub enum Connecting {
/// An unencrypted connection was opened; a stream still has to be started on it.
InsecureConnectionEstablised(Unencrypted),
/// A stream was started on the unencrypted connection; features not yet read.
InsecureStreamStarted(JabberStream<Unencrypted>),
/// Features were read from the unencrypted stream.
InsecureGotFeatures((Features, JabberStream<Unencrypted>)),
/// StartTLS was selected; the stream is about to be upgraded via `starttls`.
StartTls(JabberStream<Unencrypted>),
/// A TLS connection is available (opened directly, after StartTLS, or after SASL);
/// a stream still has to be started on it.
ConnectionEstablished(Tls),
/// A stream was started on the TLS connection; features not yet read.
StreamStarted(JabberStream<Tls>),
/// Features were read from the TLS stream.
GotFeatures((Features, JabberStream<Tls>)),
/// SASL was selected; carries the offered mechanisms.
Sasl(Mechanisms, JabberStream<Tls>),
/// Bind was selected; completing it yields `ConnectionState::Connected`.
Bind(JabberStream<Tls>),
}
impl Connecting {
    /// Opens a connection to `server` and returns the first connecting step.
    ///
    /// An encrypted connection maps to `ConnectionEstablished`, an unencrypted
    /// one to `InsecureConnectionEstablised`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Connection::connect`.
    pub async fn start(server: &str) -> Result<Self> {
        let connection = Connection::connect(server).await?;
        Ok(match connection {
            Connection::Unencrypted(tcp_stream) => {
                Connecting::InsecureConnectionEstablised(tcp_stream)
            }
            Connection::Encrypted(tls_stream) => Connecting::ConnectionEstablished(tls_stream),
        })
    }
}
/// A second enumeration of connection stages.
///
/// NOTE(review): nothing in the visible part of this file constructs or matches on
/// this type; it looks like an alternative sketch of the `Connecting` steps —
/// confirm whether it is used elsewhere or can be removed.
pub enum InsecureConnecting {
/// No connection.
Disconnected,
/// A connection (encrypted or not) has been opened.
ConnectionEstablished(Connection),
/// Stream over an unencrypted connection — presumably before StartTLS; verify.
PreStarttls(JabberStream<Unencrypted>),
/// Stream over TLS — presumably before authentication; verify.
PreAuthenticated(JabberStream<Tls>),
/// Bare TLS connection — presumably after authentication; verify.
Authenticated(Tls),
/// Stream over TLS — presumably before resource binding; verify.
PreBound(JabberStream<Tls>),
/// Stream over TLS — presumably after resource binding; verify.
Bound(JabberStream<Tls>),
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use test_log::test;
    use tokio::time::sleep;

    use super::JabberClient;

    /// Logs in with hard-coded credentials, then keeps the client alive for five seconds.
    ///
    /// NOTE(review): this calls `connect()` against the JID's domain, so it
    /// presumably needs network access and a valid account — confirm before
    /// relying on it in CI.
    #[test(tokio::test)]
    async fn login() {
        let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap();
        let outcome = client.connect().await;
        outcome.unwrap();

        let linger = Duration::from_secs(5);
        sleep(linger).await;
    }
}