implement Error for jabber crate error types

This commit is contained in:
cel 🌸 2025-02-25 20:31:10 +00:00
parent 53ea2951ae
commit d30131e0fc
4 changed files with 45 additions and 74 deletions

View File

@ -30,6 +30,7 @@ futures = "0.3.31"
take_mut = "0.2.2" take_mut = "0.2.2"
pin-project-lite = "0.2.15" pin-project-lite = "0.2.15"
pin-project = "1.1.7" pin-project = "1.1.7"
thiserror = "2.0.11"
[dev-dependencies] [dev-dependencies]
test-log = { version = "0.2", features = ["trace"] } test-log = { version = "0.2", features = ["trace"] }

View File

@ -19,7 +19,8 @@ pub async fn connect_and_login(
None, None,
jid.localpart.clone().ok_or(Error::NoLocalpart)?, jid.localpart.clone().ok_or(Error::NoLocalpart)?,
password.as_ref().to_string(), password.as_ref().to_string(),
)?; )
.map_err(|e| Error::SASL(e.into()))?;
let mut conn_state = Connecting::start(&server).await?; let mut conn_state = Connecting::start(&server).await?;
loop { loop {
match conn_state { match conn_state {
@ -108,9 +109,8 @@ pub enum InsecureConnecting {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::{sync::Arc, time::Duration}; use std::time::Duration;
use futures::{SinkExt, StreamExt};
use jid::JID; use jid::JID;
use stanza::{ use stanza::{
client::{ client::{
@ -120,7 +120,7 @@ mod tests {
xep_0199::Ping, xep_0199::Ping,
}; };
use test_log::test; use test_log::test;
use tokio::{sync::Mutex, time::sleep}; use tokio::time::sleep;
use tracing::info; use tracing::info;
use super::connect_and_login; use super::connect_and_login;
@ -128,7 +128,7 @@ mod tests {
#[test(tokio::test)] #[test(tokio::test)]
async fn login() { async fn login() {
let mut jid: JID = "test@blos.sm".try_into().unwrap(); let mut jid: JID = "test@blos.sm".try_into().unwrap();
let client = connect_and_login(&mut jid, "slayed", &mut "blos.sm".to_string()) let _client = connect_and_login(&mut jid, "slayed", &mut "blos.sm".to_string())
.await .await
.unwrap(); .unwrap();
sleep(Duration::from_secs(5)).await sleep(Duration::from_secs(5)).await

View File

@ -5,83 +5,50 @@ use rsasl::mechname::MechanismNameError;
use stanza::client::error::Error as ClientError; use stanza::client::error::Error as ClientError;
use stanza::sasl::Failure; use stanza::sasl::Failure;
use stanza::stream::Error as StreamError; use stanza::stream::Error as StreamError;
use thiserror::Error;
use tokio::task::JoinError; use tokio::task::JoinError;
#[derive(Debug)] #[derive(Error, Debug)]
pub enum Error { pub enum Error {
#[error("connection")]
Connection, Connection,
Utf8Decode, #[error("utf8 decode: {0}")]
Utf8Decode(#[from] Utf8Error),
#[error("negotiation")]
Negotiation, Negotiation,
#[error("tls required")]
TlsRequired, TlsRequired,
#[error("already connected with tls")]
AlreadyTls, AlreadyTls,
// TODO: specify unsupported feature
#[error("unsupported feature")]
Unsupported, Unsupported,
#[error("jid missing localpart")]
NoLocalpart, NoLocalpart,
AlreadyConnecting, #[error("received unexpected element: {0:?}")]
StreamClosed,
UnexpectedElement(peanuts::Element), UnexpectedElement(peanuts::Element),
XML(peanuts::Error), #[error("xml error: {0}")]
Deserialization(peanuts::DeserializeError), XML(#[from] peanuts::Error),
SASL(SASLError), #[error("sasl error: {0}")]
JID(ParseError), SASL(#[from] SASLError),
Authentication(Failure), #[error("jid error: {0}")]
ClientError(ClientError), JID(#[from] ParseError),
StreamError(StreamError), #[error("client stanza error: {0}")]
ClientError(#[from] ClientError),
#[error("stream error: {0}")]
StreamError(#[from] StreamError),
#[error("error missing")]
MissingError, MissingError,
Disconnected, #[error("task join error")]
Connecting, JoinError(#[from] JoinError),
JoinError(JoinError),
} }
#[derive(Debug)] #[derive(Error, Debug)]
pub enum SASLError { pub enum SASLError {
SASL(rsasl::prelude::SASLError), #[error("sasl error: {0}")]
MechanismName(MechanismNameError), SASL(#[from] rsasl::prelude::SASLError),
} #[error("mechanism error: {0}")]
MechanismName(#[from] MechanismNameError),
impl From<rsasl::prelude::SASLError> for Error { #[error("authentication failure: {0}")]
fn from(e: rsasl::prelude::SASLError) -> Self { Authentication(#[from] Failure),
Self::SASL(SASLError::SASL(e))
}
}
impl From<JoinError> for Error {
fn from(e: JoinError) -> Self {
Self::JoinError(e)
}
}
impl From<peanuts::DeserializeError> for Error {
fn from(e: peanuts::DeserializeError) -> Self {
Error::Deserialization(e)
}
}
impl From<MechanismNameError> for Error {
fn from(e: MechanismNameError) -> Self {
Self::SASL(SASLError::MechanismName(e))
}
}
impl From<SASLError> for Error {
fn from(e: SASLError) -> Self {
Self::SASL(e)
}
}
impl From<Utf8Error> for Error {
fn from(_e: Utf8Error) -> Self {
Self::Utf8Decode
}
}
impl From<peanuts::Error> for Error {
fn from(e: peanuts::Error) -> Self {
Self::XML(e)
}
}
impl From<ParseError> for Error {
fn from(e: ParseError) -> Self {
Self::JID(e)
}
} }

View File

@ -133,10 +133,13 @@ where
let sasl = SASLClient::new(sasl_config); let sasl = SASLClient::new(sasl_config);
let mut offered_mechs: Vec<&Mechname> = Vec::new(); let mut offered_mechs: Vec<&Mechname> = Vec::new();
for mechanism in &mechanisms.mechanisms { for mechanism in &mechanisms.mechanisms {
offered_mechs.push(Mechname::parse(mechanism.as_bytes())?) offered_mechs
.push(Mechname::parse(mechanism.as_bytes()).map_err(|e| Error::SASL(e.into()))?)
} }
debug!("{:?}", offered_mechs); debug!("{:?}", offered_mechs);
let mut session = sasl.start_suggested(&offered_mechs)?; let mut session = sasl
.start_suggested(&offered_mechs)
.map_err(|e| Error::SASL(e.into()))?;
let selected_mechanism = session.get_mechname().as_str().to_owned(); let selected_mechanism = session.get_mechname().as_str().to_owned();
debug!("selected mech: {:?}", selected_mechanism); debug!("selected mech: {:?}", selected_mechanism);
let mut data: Option<Vec<u8>>; let mut data: Option<Vec<u8>>;
@ -174,7 +177,7 @@ where
ServerResponse::Success(success) => { ServerResponse::Success(success) => {
data = success.clone().map(|success| success.as_bytes().to_vec()) data = success.clone().map(|success| success.as_bytes().to_vec())
} }
ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)), ServerResponse::Failure(failure) => return Err(Error::SASL(failure.into())),
} }
debug!("we went first"); debug!("we went first");
} }
@ -205,7 +208,7 @@ where
ServerResponse::Success(success) => { ServerResponse::Success(success) => {
data = success.clone().map(|success| success.as_bytes().to_vec()) data = success.clone().map(|success| success.as_bytes().to_vec())
} }
ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)), ServerResponse::Failure(failure) => return Err(Error::SASL(failure.into())),
} }
} }
} }