implement Error for jabber crate error types
This commit is contained in:
parent
53ea2951ae
commit
d30131e0fc
|
@ -30,6 +30,7 @@ futures = "0.3.31"
|
|||
take_mut = "0.2.2"
|
||||
pin-project-lite = "0.2.15"
|
||||
pin-project = "1.1.7"
|
||||
thiserror = "2.0.11"
|
||||
|
||||
[dev-dependencies]
|
||||
test-log = { version = "0.2", features = ["trace"] }
|
||||
|
|
|
@ -19,7 +19,8 @@ pub async fn connect_and_login(
|
|||
None,
|
||||
jid.localpart.clone().ok_or(Error::NoLocalpart)?,
|
||||
password.as_ref().to_string(),
|
||||
)?;
|
||||
)
|
||||
.map_err(|e| Error::SASL(e.into()))?;
|
||||
let mut conn_state = Connecting::start(&server).await?;
|
||||
loop {
|
||||
match conn_state {
|
||||
|
@ -108,9 +109,8 @@ pub enum InsecureConnecting {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use std::time::Duration;
|
||||
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use jid::JID;
|
||||
use stanza::{
|
||||
client::{
|
||||
|
@ -120,7 +120,7 @@ mod tests {
|
|||
xep_0199::Ping,
|
||||
};
|
||||
use test_log::test;
|
||||
use tokio::{sync::Mutex, time::sleep};
|
||||
use tokio::time::sleep;
|
||||
use tracing::info;
|
||||
|
||||
use super::connect_and_login;
|
||||
|
@ -128,7 +128,7 @@ mod tests {
|
|||
#[test(tokio::test)]
|
||||
async fn login() {
|
||||
let mut jid: JID = "test@blos.sm".try_into().unwrap();
|
||||
let client = connect_and_login(&mut jid, "slayed", &mut "blos.sm".to_string())
|
||||
let _client = connect_and_login(&mut jid, "slayed", &mut "blos.sm".to_string())
|
||||
.await
|
||||
.unwrap();
|
||||
sleep(Duration::from_secs(5)).await
|
||||
|
|
|
@ -5,83 +5,50 @@ use rsasl::mechname::MechanismNameError;
|
|||
use stanza::client::error::Error as ClientError;
|
||||
use stanza::sasl::Failure;
|
||||
use stanza::stream::Error as StreamError;
|
||||
use thiserror::Error;
|
||||
use tokio::task::JoinError;
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("connection")]
|
||||
Connection,
|
||||
Utf8Decode,
|
||||
#[error("utf8 decode: {0}")]
|
||||
Utf8Decode(#[from] Utf8Error),
|
||||
#[error("negotiation")]
|
||||
Negotiation,
|
||||
#[error("tls required")]
|
||||
TlsRequired,
|
||||
#[error("already connected with tls")]
|
||||
AlreadyTls,
|
||||
// TODO: specify unsupported feature
|
||||
#[error("unsupported feature")]
|
||||
Unsupported,
|
||||
#[error("jid missing localpart")]
|
||||
NoLocalpart,
|
||||
AlreadyConnecting,
|
||||
StreamClosed,
|
||||
#[error("received unexpected element: {0:?}")]
|
||||
UnexpectedElement(peanuts::Element),
|
||||
XML(peanuts::Error),
|
||||
Deserialization(peanuts::DeserializeError),
|
||||
SASL(SASLError),
|
||||
JID(ParseError),
|
||||
Authentication(Failure),
|
||||
ClientError(ClientError),
|
||||
StreamError(StreamError),
|
||||
#[error("xml error: {0}")]
|
||||
XML(#[from] peanuts::Error),
|
||||
#[error("sasl error: {0}")]
|
||||
SASL(#[from] SASLError),
|
||||
#[error("jid error: {0}")]
|
||||
JID(#[from] ParseError),
|
||||
#[error("client stanza error: {0}")]
|
||||
ClientError(#[from] ClientError),
|
||||
#[error("stream error: {0}")]
|
||||
StreamError(#[from] StreamError),
|
||||
#[error("error missing")]
|
||||
MissingError,
|
||||
Disconnected,
|
||||
Connecting,
|
||||
JoinError(JoinError),
|
||||
#[error("task join error")]
|
||||
JoinError(#[from] JoinError),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Error, Debug)]
|
||||
pub enum SASLError {
|
||||
SASL(rsasl::prelude::SASLError),
|
||||
MechanismName(MechanismNameError),
|
||||
}
|
||||
|
||||
impl From<rsasl::prelude::SASLError> for Error {
|
||||
fn from(e: rsasl::prelude::SASLError) -> Self {
|
||||
Self::SASL(SASLError::SASL(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<JoinError> for Error {
|
||||
fn from(e: JoinError) -> Self {
|
||||
Self::JoinError(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<peanuts::DeserializeError> for Error {
|
||||
fn from(e: peanuts::DeserializeError) -> Self {
|
||||
Error::Deserialization(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<MechanismNameError> for Error {
|
||||
fn from(e: MechanismNameError) -> Self {
|
||||
Self::SASL(SASLError::MechanismName(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SASLError> for Error {
|
||||
fn from(e: SASLError) -> Self {
|
||||
Self::SASL(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Utf8Error> for Error {
|
||||
fn from(_e: Utf8Error) -> Self {
|
||||
Self::Utf8Decode
|
||||
}
|
||||
}
|
||||
|
||||
impl From<peanuts::Error> for Error {
|
||||
fn from(e: peanuts::Error) -> Self {
|
||||
Self::XML(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ParseError> for Error {
|
||||
fn from(e: ParseError) -> Self {
|
||||
Self::JID(e)
|
||||
}
|
||||
#[error("sasl error: {0}")]
|
||||
SASL(#[from] rsasl::prelude::SASLError),
|
||||
#[error("mechanism error: {0}")]
|
||||
MechanismName(#[from] MechanismNameError),
|
||||
#[error("authentication failure: {0}")]
|
||||
Authentication(#[from] Failure),
|
||||
}
|
||||
|
|
|
@ -133,10 +133,13 @@ where
|
|||
let sasl = SASLClient::new(sasl_config);
|
||||
let mut offered_mechs: Vec<&Mechname> = Vec::new();
|
||||
for mechanism in &mechanisms.mechanisms {
|
||||
offered_mechs.push(Mechname::parse(mechanism.as_bytes())?)
|
||||
offered_mechs
|
||||
.push(Mechname::parse(mechanism.as_bytes()).map_err(|e| Error::SASL(e.into()))?)
|
||||
}
|
||||
debug!("{:?}", offered_mechs);
|
||||
let mut session = sasl.start_suggested(&offered_mechs)?;
|
||||
let mut session = sasl
|
||||
.start_suggested(&offered_mechs)
|
||||
.map_err(|e| Error::SASL(e.into()))?;
|
||||
let selected_mechanism = session.get_mechname().as_str().to_owned();
|
||||
debug!("selected mech: {:?}", selected_mechanism);
|
||||
let mut data: Option<Vec<u8>>;
|
||||
|
@ -174,7 +177,7 @@ where
|
|||
ServerResponse::Success(success) => {
|
||||
data = success.clone().map(|success| success.as_bytes().to_vec())
|
||||
}
|
||||
ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)),
|
||||
ServerResponse::Failure(failure) => return Err(Error::SASL(failure.into())),
|
||||
}
|
||||
debug!("we went first");
|
||||
}
|
||||
|
@ -205,7 +208,7 @@ where
|
|||
ServerResponse::Success(success) => {
|
||||
data = success.clone().map(|success| success.as_bytes().to_vec())
|
||||
}
|
||||
ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)),
|
||||
ServerResponse::Failure(failure) => return Err(Error::SASL(failure.into())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue