add sasl failure type

This commit is contained in:
cel 🌸 2024-11-29 17:07:16 +00:00
parent b659338906
commit 859a19820d
4 changed files with 96 additions and 9 deletions

View File

@ -12,7 +12,7 @@ async-trait = "0.1.68"
lazy_static = "1.4.0" lazy_static = "1.4.0"
nanoid = "0.4.0" nanoid = "0.4.0"
# TODO: remove unneeded features # TODO: remove unneeded features
rsasl = { version = "2.0.1", default_features = false, features = ["provider_base64", "plain", "config_builder", "scram-sha-1"] } rsasl = { version = "2.0.1", path = "../rsasl", default_features = false, features = ["provider_base64", "plain", "config_builder", "scram-sha-1"] }
tokio = { version = "1.28", features = ["full"] } tokio = { version = "1.28", features = ["full"] }
tokio-native-tls = "0.3.1" tokio-native-tls = "0.3.1"
tracing = "0.1.40" tracing = "0.1.40"

View File

@ -2,7 +2,7 @@ use std::str::Utf8Error;
use rsasl::mechname::MechanismNameError; use rsasl::mechname::MechanismNameError;
use crate::jid::ParseError; use crate::{jid::ParseError, stanza::sasl::Failure};
#[derive(Debug)] #[derive(Debug)]
pub enum Error { pub enum Error {
@ -27,6 +27,7 @@ pub enum Error {
XML(peanuts::Error), XML(peanuts::Error),
SASL(SASLError), SASL(SASLError),
JID(ParseError), JID(ParseError),
Authentication(Failure),
} }
#[derive(Debug)] #[derive(Debug)]

View File

@ -5,7 +5,7 @@ use async_recursion::async_recursion;
use peanuts::element::{FromElement, IntoElement}; use peanuts::element::{FromElement, IntoElement};
use peanuts::{Reader, Writer}; use peanuts::{Reader, Writer};
use rsasl::prelude::{Mechname, SASLClient, SASLConfig}; use rsasl::prelude::{Mechname, SASLClient, SASLConfig};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter, ReadHalf, WriteHalf};
use tokio::time::timeout; use tokio::time::timeout;
use tokio_native_tls::native_tls::TlsConnector; use tokio_native_tls::native_tls::TlsConnector;
use tracing::{debug, info, instrument, trace}; use tracing::{debug, info, instrument, trace};
@ -102,7 +102,10 @@ where
ServerResponse::Challenge(challenge) => { ServerResponse::Challenge(challenge) => {
data = Some((*challenge).as_bytes().to_vec()) data = Some((*challenge).as_bytes().to_vec())
} }
ServerResponse::Success(success) => data = Some((*success).as_bytes().to_vec()), ServerResponse::Success(success) => {
data = success.clone().map(|success| success.as_bytes().to_vec())
}
ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)),
} }
debug!("we went first"); debug!("we went first");
} }
@ -121,7 +124,11 @@ where
// While we aren't finished, receive more data from the other party // While we aren't finished, receive more data from the other party
let response = Response::new(str::from_utf8(&sasl_data)?.to_string()); let response = Response::new(str::from_utf8(&sasl_data)?.to_string());
debug!("response: {:?}", response); debug!("response: {:?}", response);
let stdout = tokio::io::stdout();
let mut writer = Writer::new(stdout);
writer.write_full(&response).await?;
self.writer.write_full(&response).await?; self.writer.write_full(&response).await?;
debug!("response written");
let server_response: ServerResponse = self.reader.read().await?; let server_response: ServerResponse = self.reader.read().await?;
debug!("server_response: {:#?}", server_response); debug!("server_response: {:#?}", server_response);
@ -129,7 +136,10 @@ where
ServerResponse::Challenge(challenge) => { ServerResponse::Challenge(challenge) => {
data = Some((*challenge).as_bytes().to_vec()) data = Some((*challenge).as_bytes().to_vec())
} }
ServerResponse::Success(success) => data = Some((*success).as_bytes().to_vec()), ServerResponse::Success(success) => {
data = success.clone().map(|success| success.as_bytes().to_vec())
}
ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)),
} }
} }
} }

View File

@ -105,10 +105,10 @@ impl FromElement for Challenge {
} }
#[derive(Debug)] #[derive(Debug)]
pub struct Success(String); pub struct Success(Option<String>);
impl Deref for Success { impl Deref for Success {
type Target = str; type Target = Option<String>;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
&self.0 &self.0
@ -120,7 +120,7 @@ impl FromElement for Success {
element.check_name("success")?; element.check_name("success")?;
element.check_namespace(XMLNS)?; element.check_namespace(XMLNS)?;
let sasl_data = element.value()?; let sasl_data = element.value_opt()?;
Ok(Success(sasl_data)) Ok(Success(sasl_data))
} }
@ -130,10 +130,12 @@ impl FromElement for Success {
pub enum ServerResponse { pub enum ServerResponse {
Challenge(Challenge), Challenge(Challenge),
Success(Success), Success(Success),
Failure(Failure),
} }
impl FromElement for ServerResponse { impl FromElement for ServerResponse {
fn from_element(element: Element) -> peanuts::element::DeserializeResult<Self> { fn from_element(element: Element) -> peanuts::element::DeserializeResult<Self> {
debug!("identification: {:?}", element.identify());
match element.identify() { match element.identify() {
(Some(XMLNS), "challenge") => { (Some(XMLNS), "challenge") => {
Ok(ServerResponse::Challenge(Challenge::from_element(element)?)) Ok(ServerResponse::Challenge(Challenge::from_element(element)?))
@ -141,6 +143,9 @@ impl FromElement for ServerResponse {
(Some(XMLNS), "success") => { (Some(XMLNS), "success") => {
Ok(ServerResponse::Success(Success::from_element(element)?)) Ok(ServerResponse::Success(Success::from_element(element)?))
} }
(Some(XMLNS), "failure") => {
Ok(ServerResponse::Failure(Failure::from_element(element)?))
}
_ => Err(DeserializeError::UnexpectedElement(element)), _ => Err(DeserializeError::UnexpectedElement(element)),
} }
} }
@ -165,6 +170,77 @@ impl Deref for Response {
impl IntoElement for Response { impl IntoElement for Response {
fn builder(&self) -> peanuts::element::ElementBuilder { fn builder(&self) -> peanuts::element::ElementBuilder {
Element::builder("reponse", Some(XMLNS)).push_text(self.0.clone()) Element::builder("response", Some(XMLNS)).push_text(self.0.clone())
}
}
#[derive(Debug)]
pub struct Failure {
r#type: Option<FailureType>,
text: Option<Text>,
}
impl FromElement for Failure {
fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> {
element.check_name("failure")?;
element.check_namespace(XMLNS)?;
let r#type = element.pop_child_opt()?;
let text = element.pop_child_opt()?;
Ok(Failure { r#type, text })
}
}
#[derive(Debug)]
pub enum FailureType {
Aborted,
AccountDisabled,
CredentialsExpired,
EncryptionRequired,
IncorrectEncoding,
InvalidAuthzid,
InvalidMechanism,
MalformedRequest,
MechanismTooWeak,
NotAuthorized,
TemporaryAuthFailure,
}
impl FromElement for FailureType {
fn from_element(element: Element) -> peanuts::element::DeserializeResult<Self> {
match element.identify() {
(Some(XMLNS), "aborted") => Ok(FailureType::Aborted),
(Some(XMLNS), "account-disabled") => Ok(FailureType::AccountDisabled),
(Some(XMLNS), "credentials-expired") => Ok(FailureType::CredentialsExpired),
(Some(XMLNS), "encryption-required") => Ok(FailureType::EncryptionRequired),
(Some(XMLNS), "incorrect-encoding") => Ok(FailureType::IncorrectEncoding),
(Some(XMLNS), "invalid-authzid") => Ok(FailureType::InvalidAuthzid),
(Some(XMLNS), "invalid-mechanism") => Ok(FailureType::InvalidMechanism),
(Some(XMLNS), "malformed-request") => Ok(FailureType::MalformedRequest),
(Some(XMLNS), "mechanism-too-weak") => Ok(FailureType::MechanismTooWeak),
(Some(XMLNS), "not-authorized") => Ok(FailureType::NotAuthorized),
(Some(XMLNS), "temporary-auth-failure") => Ok(FailureType::TemporaryAuthFailure),
_ => Err(DeserializeError::UnexpectedElement(element)),
}
}
}
#[derive(Debug)]
pub struct Text {
lang: Option<String>,
text: Option<String>,
}
impl FromElement for Text {
fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> {
element.check_name("text")?;
element.check_namespace(XMLNS)?;
let lang = element.attribute_opt_namespaced("lang", peanuts::XML_NS)?;
let text = element.pop_value_opt()?;
Ok(Text { lang, text })
} }
} }