2024-12-04 02:09:07 +00:00
|
|
|
use std::{pin::pin, sync::Arc, task::Poll};
|
2024-12-03 23:57:04 +00:00
|
|
|
|
2024-12-04 02:09:07 +00:00
|
|
|
use futures::{Sink, Stream, StreamExt};
|
2024-12-04 18:18:37 +00:00
|
|
|
use jid::ParseError;
|
2024-12-03 23:57:04 +00:00
|
|
|
use rsasl::config::SASLConfig;
|
2024-12-04 18:18:37 +00:00
|
|
|
use stanza::{
|
|
|
|
client::Stanza,
|
|
|
|
sasl::Mechanisms,
|
|
|
|
stream::{Feature, Features},
|
|
|
|
};
|
2024-12-03 23:57:04 +00:00
|
|
|
|
|
|
|
use crate::{
|
|
|
|
connection::{Tls, Unencrypted},
|
|
|
|
Connection, Error, JabberStream, Result, JID,
|
|
|
|
};
|
|
|
|
|
|
|
|
// feed it client stanzas, receive client stanzas
|
|
|
|
pub struct JabberClient {
|
2024-12-04 02:09:07 +00:00
|
|
|
connection: ConnectionState,
|
2024-12-03 23:57:04 +00:00
|
|
|
jid: JID,
|
|
|
|
password: Arc<SASLConfig>,
|
|
|
|
server: String,
|
|
|
|
}
|
|
|
|
|
2024-12-04 02:09:07 +00:00
|
|
|
impl JabberClient {
|
|
|
|
pub fn new(
|
|
|
|
jid: impl TryInto<JID, Error = ParseError>,
|
|
|
|
password: impl ToString,
|
|
|
|
) -> Result<JabberClient> {
|
|
|
|
let jid = jid.try_into()?;
|
|
|
|
let sasl_config = SASLConfig::with_credentials(
|
|
|
|
None,
|
|
|
|
jid.localpart.clone().ok_or(Error::NoLocalpart)?,
|
|
|
|
password.to_string(),
|
|
|
|
)?;
|
|
|
|
Ok(JabberClient {
|
|
|
|
connection: ConnectionState::Disconnected,
|
|
|
|
jid: jid.clone(),
|
|
|
|
password: sasl_config,
|
|
|
|
server: jid.domainpart,
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
pub async fn connect(&mut self) -> Result<()> {
|
|
|
|
match &self.connection {
|
|
|
|
ConnectionState::Disconnected => {
|
2024-12-04 17:38:36 +00:00
|
|
|
// TODO: actually set the self.connection as it is connecting, make more asynchronous (mutex while connecting?)
|
|
|
|
// perhaps use take_mut?
|
2024-12-04 02:09:07 +00:00
|
|
|
self.connection = ConnectionState::Disconnected
|
|
|
|
.connect(&mut self.jid, self.password.clone(), &mut self.server)
|
|
|
|
.await?;
|
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
ConnectionState::Connecting(_connecting) => Err(Error::AlreadyConnecting),
|
|
|
|
ConnectionState::Connected(_jabber_stream) => Ok(()),
|
|
|
|
}
|
|
|
|
}
|
2024-12-04 17:38:36 +00:00
|
|
|
|
|
|
|
pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> {
|
|
|
|
match &mut self.connection {
|
|
|
|
ConnectionState::Disconnected => return Err(Error::Disconnected),
|
|
|
|
ConnectionState::Connecting(_connecting) => return Err(Error::Connecting),
|
|
|
|
ConnectionState::Connected(jabber_stream) => {
|
|
|
|
Ok(jabber_stream.send_stanza(stanza).await?)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2024-12-04 02:09:07 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
impl Stream for JabberClient {
|
|
|
|
type Item = Result<Stanza>;
|
|
|
|
|
|
|
|
fn poll_next(
|
|
|
|
self: std::pin::Pin<&mut Self>,
|
|
|
|
cx: &mut std::task::Context<'_>,
|
|
|
|
) -> std::task::Poll<Option<Self::Item>> {
|
|
|
|
let mut client = pin!(self);
|
|
|
|
match &mut client.connection {
|
|
|
|
ConnectionState::Disconnected => Poll::Pending,
|
|
|
|
ConnectionState::Connecting(_connecting) => Poll::Pending,
|
|
|
|
ConnectionState::Connected(jabber_stream) => jabber_stream.poll_next_unpin(cx),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
pub enum ConnectionState {
|
2024-12-03 23:57:04 +00:00
|
|
|
Disconnected,
|
2024-12-04 02:09:07 +00:00
|
|
|
Connecting(Connecting),
|
|
|
|
Connected(JabberStream<Tls>),
|
|
|
|
}
|
|
|
|
|
|
|
|
impl ConnectionState {
|
|
|
|
pub async fn connect(
|
|
|
|
mut self,
|
|
|
|
jid: &mut JID,
|
|
|
|
auth: Arc<SASLConfig>,
|
|
|
|
server: &mut String,
|
|
|
|
) -> Result<Self> {
|
|
|
|
loop {
|
|
|
|
match self {
|
|
|
|
ConnectionState::Disconnected => {
|
|
|
|
self = ConnectionState::Connecting(Connecting::start(&server).await?);
|
|
|
|
}
|
|
|
|
ConnectionState::Connecting(connecting) => match connecting {
|
|
|
|
Connecting::InsecureConnectionEstablised(tcp_stream) => {
|
|
|
|
self = ConnectionState::Connecting(Connecting::InsecureStreamStarted(
|
|
|
|
JabberStream::start_stream(tcp_stream, server).await?,
|
|
|
|
))
|
|
|
|
}
|
|
|
|
Connecting::InsecureStreamStarted(jabber_stream) => {
|
|
|
|
self = ConnectionState::Connecting(Connecting::InsecureGotFeatures(
|
|
|
|
jabber_stream.get_features().await?,
|
|
|
|
))
|
|
|
|
}
|
|
|
|
Connecting::InsecureGotFeatures((features, jabber_stream)) => {
|
2024-12-04 18:18:37 +00:00
|
|
|
match features.negotiate().ok_or(Error::Negotiation)? {
|
2024-12-04 02:09:07 +00:00
|
|
|
Feature::StartTls(_start_tls) => {
|
|
|
|
self =
|
|
|
|
ConnectionState::Connecting(Connecting::StartTls(jabber_stream))
|
|
|
|
}
|
|
|
|
// TODO: better error
|
|
|
|
_ => return Err(Error::TlsRequired),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
Connecting::StartTls(jabber_stream) => {
|
|
|
|
self = ConnectionState::Connecting(Connecting::ConnectionEstablished(
|
|
|
|
jabber_stream.starttls(&server).await?,
|
|
|
|
))
|
|
|
|
}
|
|
|
|
Connecting::ConnectionEstablished(tls_stream) => {
|
|
|
|
self = ConnectionState::Connecting(Connecting::StreamStarted(
|
|
|
|
JabberStream::start_stream(tls_stream, server).await?,
|
|
|
|
))
|
|
|
|
}
|
|
|
|
Connecting::StreamStarted(jabber_stream) => {
|
|
|
|
self = ConnectionState::Connecting(Connecting::GotFeatures(
|
|
|
|
jabber_stream.get_features().await?,
|
|
|
|
))
|
|
|
|
}
|
|
|
|
Connecting::GotFeatures((features, jabber_stream)) => {
|
2024-12-04 18:18:37 +00:00
|
|
|
match features.negotiate().ok_or(Error::Negotiation)? {
|
2024-12-04 02:09:07 +00:00
|
|
|
Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls),
|
|
|
|
Feature::Sasl(mechanisms) => {
|
|
|
|
self = ConnectionState::Connecting(Connecting::Sasl(
|
|
|
|
mechanisms,
|
|
|
|
jabber_stream,
|
|
|
|
))
|
|
|
|
}
|
|
|
|
Feature::Bind => {
|
|
|
|
self = ConnectionState::Connecting(Connecting::Bind(jabber_stream))
|
|
|
|
}
|
|
|
|
Feature::Unknown => return Err(Error::Unsupported),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
Connecting::Sasl(mechanisms, jabber_stream) => {
|
|
|
|
self = ConnectionState::Connecting(Connecting::ConnectionEstablished(
|
|
|
|
jabber_stream.sasl(mechanisms, auth.clone()).await?,
|
|
|
|
))
|
|
|
|
}
|
|
|
|
Connecting::Bind(jabber_stream) => {
|
|
|
|
self = ConnectionState::Connected(jabber_stream.bind(jid).await?)
|
|
|
|
}
|
|
|
|
},
|
|
|
|
connected => return Ok(connected),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
pub enum Connecting {
|
2024-12-03 23:57:04 +00:00
|
|
|
InsecureConnectionEstablised(Unencrypted),
|
|
|
|
InsecureStreamStarted(JabberStream<Unencrypted>),
|
|
|
|
InsecureGotFeatures((Features, JabberStream<Unencrypted>)),
|
|
|
|
StartTls(JabberStream<Unencrypted>),
|
|
|
|
ConnectionEstablished(Tls),
|
|
|
|
StreamStarted(JabberStream<Tls>),
|
|
|
|
GotFeatures((Features, JabberStream<Tls>)),
|
|
|
|
Sasl(Mechanisms, JabberStream<Tls>),
|
|
|
|
Bind(JabberStream<Tls>),
|
|
|
|
}
|
|
|
|
|
2024-12-04 02:09:07 +00:00
|
|
|
impl Connecting {
|
|
|
|
pub async fn start(server: &str) -> Result<Self> {
|
|
|
|
match Connection::connect(server).await? {
|
|
|
|
Connection::Encrypted(tls_stream) => Ok(Connecting::ConnectionEstablished(tls_stream)),
|
|
|
|
Connection::Unencrypted(tcp_stream) => {
|
|
|
|
Ok(Connecting::InsecureConnectionEstablised(tcp_stream))
|
2024-12-03 23:57:04 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-12-04 02:09:07 +00:00
|
|
|
pub enum InsecureConnecting {
|
2024-12-03 23:57:04 +00:00
|
|
|
Disconnected,
|
|
|
|
ConnectionEstablished(Connection),
|
|
|
|
PreStarttls(JabberStream<Unencrypted>),
|
|
|
|
PreAuthenticated(JabberStream<Tls>),
|
|
|
|
Authenticated(Tls),
|
|
|
|
PreBound(JabberStream<Tls>),
|
|
|
|
Bound(JabberStream<Tls>),
|
|
|
|
}
|
|
|
|
|
2024-12-04 02:09:07 +00:00
|
|
|
#[cfg(test)]
mod tests {
    use super::JabberClient;
    use test_log::test;

    /// Logs in to the live XMPP server `blos.sm` with a test account.
    ///
    /// Requires network access and valid credentials for that account. The
    /// previous trailing 5-second sleep was removed: nothing runs in the
    /// background and nothing is checked afterwards, so it only slowed the
    /// suite down.
    #[test(tokio::test)]
    async fn login() {
        let mut client = JabberClient::new("test@blos.sm", "slayed")
            .expect("JID should parse and SASL config should build");
        client
            .connect()
            .await
            .expect("should connect, authenticate and bind");
    }
}
|