// luz/src/client.rs — 181 lines, 5.8 KiB (Rust)
// Snapshot as of 2024-12-03 23:57:04 +00:00
use std::sync::Arc;
use futures::{Sink, Stream};
use rsasl::config::SASLConfig;
use crate::{
connection::{Tls, Unencrypted},
stanza::{
client::Stanza,
sasl::Mechanisms,
stream::{Feature, Features},
},
Connection, Error, JabberStream, Result, JID,
};
/// An XMPP client connection: feed it client stanzas, receive client stanzas.
///
/// Stanzas are meant to be sent through the [`Sink`] impl and received through the
/// [`Stream`] impl; in this version both impls are still `todo!()` stubs.
pub struct JabberClient {
    /// Current step of the connection / negotiation state machine.
    connection: JabberState,
    /// JID for this session (presumably the account's JID — TODO confirm; a JID is
    /// what `JabberState::advance_state` hands to resource binding).
    jid: JID,
    // NOTE(review): despite the name, this holds the full SASL configuration
    // (`Arc<SASLConfig>`), not a bare password string.
    password: Arc<SASLConfig>,
    /// Server address (presumably the `server` value threaded through
    /// `JabberState::advance_state` — TODO confirm).
    server: String,
}
/// States of the client connection state machine, stepped one transition at a
/// time by `JabberState::advance_state`.
///
/// The `Insecure*` states (and `StartTls`) run over an unencrypted transport and
/// exist only to reach STARTTLS; every later state runs over TLS.
pub enum JabberState {
    /// No transport connection yet.
    Disconnected,
    /// Unencrypted transport connected; no XML stream started yet.
    // NOTE(review): variant name is misspelled ("Establised"); it is public, so
    // renaming it would break callers.
    InsecureConnectionEstablised(Unencrypted),
    /// XML stream started (via `JabberStream::start_stream`) on the unencrypted transport.
    InsecureStreamStarted(JabberStream<Unencrypted>),
    /// Stream features read from the unencrypted stream. The single tuple field
    /// mirrors the `(Features, JabberStream)` pair returned by `get_features`.
    InsecureGotFeatures((Features, JabberStream<Unencrypted>)),
    /// STARTTLS was selected; the next step upgrades this stream to TLS.
    StartTls(JabberStream<Unencrypted>),
    /// TLS transport ready and a new stream must be started on it. Reached from a
    /// direct encrypted connect, after STARTTLS, and again after SASL succeeds.
    ConnectionEstablished(Tls),
    /// XML stream started on the TLS transport.
    StreamStarted(JabberStream<Tls>),
    /// Stream features read from the TLS stream (same tuple shape as above).
    GotFeatures((Features, JabberStream<Tls>)),
    /// SASL selected; holds the advertised mechanisms.
    Sasl(Mechanisms, JabberStream<Tls>),
    /// Resource binding selected.
    Bind(JabberStream<Tls>),
    // when it's bound, can stream stanzas and sink stanzas
    Bound(JabberStream<Tls>),
}
impl JabberState {
pub async fn advance_state(
self,
jid: &mut JID,
auth: Arc<SASLConfig>,
server: &mut String,
) -> Result<JabberState> {
match self {
JabberState::Disconnected => match Connection::connect(server).await? {
Connection::Encrypted(tls_stream) => {
Ok(JabberState::ConnectionEstablished(tls_stream))
}
Connection::Unencrypted(tcp_stream) => {
Ok(JabberState::InsecureConnectionEstablised(tcp_stream))
}
},
JabberState::InsecureConnectionEstablised(tcp_stream) => Ok({
JabberState::InsecureStreamStarted(
JabberStream::start_stream(tcp_stream, server).await?,
)
}),
JabberState::InsecureStreamStarted(jabber_stream) => Ok(
JabberState::InsecureGotFeatures(jabber_stream.get_features().await?),
),
JabberState::InsecureGotFeatures((features, jabber_stream)) => {
match features.negotiate()? {
Feature::StartTls(_start_tls) => Ok(JabberState::StartTls(jabber_stream)),
// TODO: better error
_ => return Err(Error::TlsRequired),
}
}
JabberState::StartTls(jabber_stream) => Ok(JabberState::ConnectionEstablished(
jabber_stream.starttls(server).await?,
)),
JabberState::ConnectionEstablished(tls_stream) => Ok(JabberState::StreamStarted(
JabberStream::start_stream(tls_stream, server).await?,
)),
JabberState::StreamStarted(jabber_stream) => Ok(JabberState::GotFeatures(
jabber_stream.get_features().await?,
)),
JabberState::GotFeatures((features, jabber_stream)) => match features.negotiate()? {
Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls),
Feature::Sasl(mechanisms) => {
return Ok(JabberState::Sasl(mechanisms, jabber_stream))
}
Feature::Bind => return Ok(JabberState::Bind(jabber_stream)),
Feature::Unknown => return Err(Error::Unsupported),
},
JabberState::Sasl(mechanisms, jabber_stream) => {
return Ok(JabberState::ConnectionEstablished(
jabber_stream.sasl(mechanisms, auth).await?,
))
}
JabberState::Bind(jabber_stream) => {
Ok(JabberState::Bound(jabber_stream.bind(jid).await?))
}
JabberState::Bound(jabber_stream) => Ok(JabberState::Bound(jabber_stream)),
}
}
}
impl Features {
pub fn negotiate(self) -> Result<Feature> {
if let Some(Feature::StartTls(s)) = self
.features
.iter()
.find(|feature| matches!(feature, Feature::StartTls(_s)))
{
// TODO: avoid clone
return Ok(Feature::StartTls(s.clone()));
} else if let Some(Feature::Sasl(mechanisms)) = self
.features
.iter()
.find(|feature| matches!(feature, Feature::Sasl(_)))
{
// TODO: avoid clone
return Ok(Feature::Sasl(mechanisms.clone()));
} else if let Some(Feature::Bind) = self
.features
.into_iter()
.find(|feature| matches!(feature, Feature::Bind))
{
Ok(Feature::Bind)
} else {
// TODO: better error
return Err(Error::Negotiation);
}
}
}
/// Alternative, coarser-grained connection states.
///
/// NOTE(review): this enum is not referenced anywhere in the visible source;
/// `JabberState` is what drives the connection. It may be a leftover of an earlier
/// design — confirm it has no users elsewhere before removing it.
pub enum InsecureJabberConnection {
    /// No transport connection yet.
    Disconnected,
    /// Transport connected, either encrypted or unencrypted.
    ConnectionEstablished(Connection),
    /// Unencrypted stream, before STARTTLS.
    PreStarttls(JabberStream<Unencrypted>),
    /// TLS stream, before authentication.
    PreAuthenticated(JabberStream<Tls>),
    /// Authenticated TLS transport; a new stream has not been started on it yet.
    Authenticated(Tls),
    /// TLS stream, before resource binding.
    PreBound(JabberStream<Tls>),
    /// Resource bound.
    Bound(JabberStream<Tls>),
}
/// Receiving side of the client: yields client stanzas from the server.
///
/// Not implemented yet — polling this stream panics via `todo!()`.
impl Stream for JabberClient {
    type Item = Stanza;

    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Stanza>> {
        // Presumably this will read from the stream once `connection` is
        // `JabberState::Bound` — TODO confirm the intended design.
        todo!()
    }
}
impl Sink<Stanza> for JabberClient {
type Error = Error;
fn poll_ready(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::result::Result<(), Self::Error>> {
todo!()
}
fn start_send(
self: std::pin::Pin<&mut Self>,
item: Stanza,
) -> std::result::Result<(), Self::Error> {
todo!()
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::result::Result<(), Self::Error>> {
todo!()
}
fn poll_close(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::result::Result<(), Self::Error>> {
todo!()
}
}