implement send_stanza

This commit is contained in:
cel 🌸 2024-12-04 17:38:36 +00:00
parent 4886396044
commit 21f10a0b43
6 changed files with 39 additions and 118 deletions

10
README.md Normal file
View File

@ -0,0 +1,10 @@
# jabber client library
## TODO:
- [ ] error states for all negotiation parts
- [ ] better errors
- [ ] rename structs
- [ ] remove commented code
- [ ] asynchronous connect (with take_mut?)
- [ ] split into separate crates: stanza, jabber, and luz

View File

@ -44,6 +44,8 @@ impl JabberClient {
pub async fn connect(&mut self) -> Result<()> { pub async fn connect(&mut self) -> Result<()> {
match &self.connection { match &self.connection {
ConnectionState::Disconnected => { ConnectionState::Disconnected => {
// TODO: actually set the self.connection as it is connecting, make more asynchronous (mutex while connecting?)
// perhaps use take_mut?
self.connection = ConnectionState::Disconnected self.connection = ConnectionState::Disconnected
.connect(&mut self.jid, self.password.clone(), &mut self.server) .connect(&mut self.jid, self.password.clone(), &mut self.server)
.await?; .await?;
@ -53,6 +55,16 @@ impl JabberClient {
ConnectionState::Connected(_jabber_stream) => Ok(()), ConnectionState::Connected(_jabber_stream) => Ok(()),
} }
} }
pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> {
match &mut self.connection {
ConnectionState::Disconnected => return Err(Error::Disconnected),
ConnectionState::Connecting(_connecting) => return Err(Error::Connecting),
ConnectionState::Connected(jabber_stream) => {
Ok(jabber_stream.send_stanza(stanza).await?)
}
}
}
} }
impl Stream for JabberClient { impl Stream for JabberClient {
@ -217,38 +229,6 @@ pub enum InsecureConnecting {
Bound(JabberStream<Tls>), Bound(JabberStream<Tls>),
} }
impl Sink<Stanza> for JabberClient {
type Error = Error;
fn poll_ready(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::result::Result<(), Self::Error>> {
todo!()
}
fn start_send(
self: std::pin::Pin<&mut Self>,
item: Stanza,
) -> std::result::Result<(), Self::Error> {
todo!()
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::result::Result<(), Self::Error>> {
todo!()
}
fn poll_close(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::result::Result<(), Self::Error>> {
todo!()
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::time::Duration; use std::time::Duration;

View File

@ -3,6 +3,7 @@ use std::str::Utf8Error;
use rsasl::mechname::MechanismNameError; use rsasl::mechname::MechanismNameError;
use crate::stanza::client::error::Error as ClientError; use crate::stanza::client::error::Error as ClientError;
use crate::stanza::stream::Error as StreamError;
use crate::{jid::ParseError, stanza::sasl::Failure}; use crate::{jid::ParseError, stanza::sasl::Failure};
#[derive(Debug)] #[derive(Debug)]
@ -22,7 +23,10 @@ pub enum Error {
JID(ParseError), JID(ParseError),
Authentication(Failure), Authentication(Failure),
ClientError(ClientError), ClientError(ClientError),
StreamError(StreamError),
MissingError, MissingError,
Disconnected,
Connecting,
} }
#[derive(Debug)] #[derive(Debug)]

View File

@ -50,6 +50,7 @@ where
S: AsyncRead + AsyncWrite + Unpin + Send + std::fmt::Debug, S: AsyncRead + AsyncWrite + Unpin + Send + std::fmt::Debug,
JabberStream<S>: std::fmt::Debug, JabberStream<S>: std::fmt::Debug,
{ {
#[instrument]
pub async fn sasl(mut self, mechanisms: Mechanisms, sasl_config: Arc<SASLConfig>) -> Result<S> { pub async fn sasl(mut self, mechanisms: Mechanisms, sasl_config: Arc<SASLConfig>) -> Result<S> {
let sasl = SASLClient::new(sasl_config); let sasl = SASLClient::new(sasl_config);
let mut offered_mechs: Vec<&Mechname> = Vec::new(); let mut offered_mechs: Vec<&Mechname> = Vec::new();
@ -139,6 +140,7 @@ where
Ok(stream) Ok(stream)
} }
#[instrument]
pub async fn bind(mut self, jid: &mut JID) -> Result<Self> { pub async fn bind(mut self, jid: &mut JID) -> Result<Self> {
let iq_id = nanoid::nanoid!(); let iq_id = nanoid::nanoid!();
if let Some(resource) = &jid.resourcepart { if let Some(resource) = &jid.resourcepart {
@ -266,6 +268,7 @@ where
Ok(Self { reader, writer }) Ok(Self { reader, writer })
} }
#[instrument]
pub async fn get_features(mut self) -> Result<(Features, Self)> { pub async fn get_features(mut self) -> Result<(Features, Self)> {
debug!("getting features"); debug!("getting features");
let features: Features = self.reader.read().await?; let features: Features = self.reader.read().await?;
@ -276,91 +279,16 @@ where
pub fn into_inner(self) -> S { pub fn into_inner(self) -> S {
self.reader.into_inner().unsplit(self.writer.into_inner()) self.reader.into_inner().unsplit(self.writer.into_inner())
} }
pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> {
self.writer.write(stanza).await?;
Ok(())
}
} }
impl JabberStream<Unencrypted> { impl JabberStream<Unencrypted> {
// pub async fn negotiate<S: AsyncRead + AsyncWrite + Unpin>( #[instrument]
// mut self, pub async fn starttls(mut self, domain: impl AsRef<str> + std::fmt::Debug) -> Result<Tls> {
// features: Features,
// ) -> Result<Feature> {
// // TODO: timeout
// if let Some(Feature::StartTls(_)) = features
// .features
// .iter()
// .find(|feature| matches!(feature, Feature::StartTls(_s)))
// {
// return Ok(self);
// } else {
// // TODO: better error
// return Err(Error::TlsRequired);
// }
// }
// #[async_recursion]
// pub async fn negotiate_tls_optional(mut self) -> Result<Connection> {
// self.start_stream().await?;
// // TODO: timeout
// let features = self.get_features().await?.features;
// if let Some(Feature::StartTls(_)) = features
// .iter()
// .find(|feature| matches!(feature, Feature::StartTls(_s)))
// {
// let jabber = self.starttls().await?;
// let jabber = jabber.negotiate().await?;
// return Ok(Connection::Encrypted(jabber));
// } else if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = (
// self.auth.clone(),
// features
// .iter()
// .find(|feature| matches!(feature, Feature::Sasl(_))),
// ) {
// self.sasl(mechanisms.clone(), sasl_config).await?;
// let jabber = self.negotiate_tls_optional().await?;
// Ok(jabber)
// } else if let Some(Feature::Bind) = features
// .iter()
// .find(|feature| matches!(feature, Feature::Bind))
// {
// self.bind().await?;
// Ok(Connection::Unencrypted(self))
// } else {
// // TODO: better error
// return Err(Error::Negotiation);
// }
// }
}
impl JabberStream<Tls> {
// #[async_recursion]
// pub async fn negotiate(mut self) -> Result<JabberStream<Tls>> {
// self.start_stream().await?;
// let features = self.get_features().await?.features;
// if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = (
// self.auth.clone(),
// features
// .iter()
// .find(|feature| matches!(feature, Feature::Sasl(_))),
// ) {
// // TODO: avoid clone
// self.sasl(mechanisms.clone(), sasl_config).await?;
// let jabber = self.negotiate().await?;
// Ok(jabber)
// } else if let Some(Feature::Bind) = features
// .iter()
// .find(|feature| matches!(feature, Feature::Bind))
// {
// self.bind().await?;
// Ok(self)
// } else {
// // TODO: better error
// return Err(Error::Negotiation);
// }
// }
}
impl JabberStream<Unencrypted> {
pub async fn starttls(mut self, domain: impl AsRef<str>) -> Result<Tls> {
self.writer self.writer
.write_full(&StartTls { required: false }) .write_full(&StartTls { required: false })
.await?; .await?;
@ -372,8 +300,6 @@ impl JabberStream<Unencrypted> {
.connect(domain.as_ref(), stream) .connect(domain.as_ref(), stream)
.await .await
{ {
// let (read, write) = tokio::io::split(tlsstream);
// let client = JabberStream::new(read, write);
return Ok(tls_stream); return Ok(tls_stream);
} else { } else {
return Err(Error::Connection); return Err(Error::Connection);

View File

@ -164,6 +164,7 @@ impl FromElement for Feature {
} }
} }
#[derive(Debug)]
pub struct Error { pub struct Error {
error: StreamError, error: StreamError,
text: Option<Text>, text: Option<Text>,

View File

@ -5,7 +5,7 @@ use peanuts::{
pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-streams"; pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-streams";
#[derive(Clone)] #[derive(Clone, Debug)]
pub enum Error { pub enum Error {
BadFormat, BadFormat,
BadNamespacePrefix, BadNamespacePrefix,
@ -110,7 +110,7 @@ impl IntoElement for Error {
} }
} }
#[derive(Clone)] #[derive(Clone, Debug)]
pub struct Text { pub struct Text {
text: Option<String>, text: Option<String>,
lang: Option<String>, lang: Option<String>,