From 859a19820d69eca5fca87fc01acad72a6355f97e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?cel=20=F0=9F=8C=B8?= Date: Fri, 29 Nov 2024 17:07:16 +0000 Subject: [PATCH] add sasl failure type --- Cargo.toml | 2 +- src/error.rs | 3 +- src/jabber.rs | 16 +++++++-- src/stanza/sasl.rs | 84 +++++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 96 insertions(+), 9 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 326e45e..e9c12b5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ async-trait = "0.1.68" lazy_static = "1.4.0" nanoid = "0.4.0" # TODO: remove unneeded features -rsasl = { version = "2.0.1", default_features = false, features = ["provider_base64", "plain", "config_builder", "scram-sha-1"] } +rsasl = { version = "2.0.1", path = "../rsasl", default_features = false, features = ["provider_base64", "plain", "config_builder", "scram-sha-1"] } tokio = { version = "1.28", features = ["full"] } tokio-native-tls = "0.3.1" tracing = "0.1.40" diff --git a/src/error.rs b/src/error.rs index 8ee9077..a1f853b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,7 +2,7 @@ use std::str::Utf8Error; use rsasl::mechname::MechanismNameError; -use crate::jid::ParseError; +use crate::{jid::ParseError, stanza::sasl::Failure}; #[derive(Debug)] pub enum Error { @@ -27,6 +27,7 @@ pub enum Error { XML(peanuts::Error), SASL(SASLError), JID(ParseError), + Authentication(Failure), } #[derive(Debug)] diff --git a/src/jabber.rs b/src/jabber.rs index 9e7f9d8..599879d 100644 --- a/src/jabber.rs +++ b/src/jabber.rs @@ -5,7 +5,7 @@ use async_recursion::async_recursion; use peanuts::element::{FromElement, IntoElement}; use peanuts::{Reader, Writer}; use rsasl::prelude::{Mechname, SASLClient, SASLConfig}; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter, ReadHalf, WriteHalf}; use tokio::time::timeout; use tokio_native_tls::native_tls::TlsConnector; use tracing::{debug, info, instrument, trace}; @@ -102,7 +102,10 @@ where ServerResponse::Challenge(challenge) => { data = Some((*challenge).as_bytes().to_vec()) } - ServerResponse::Success(success) => data = Some((*success).as_bytes().to_vec()), + ServerResponse::Success(success) => { + data = success.clone().map(|success| success.as_bytes().to_vec()) + } + ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)), } debug!("we went first"); } @@ -121,7 +124,11 @@ where // While we aren't finished, receive more data from the other party let response = Response::new(str::from_utf8(&sasl_data)?.to_string()); debug!("response: {:?}", response); + let stdout = tokio::io::stdout(); + let mut writer = Writer::new(stdout); + writer.write_full(&response).await?; self.writer.write_full(&response).await?; + debug!("response written"); let server_response: ServerResponse = self.reader.read().await?; debug!("server_response: {:#?}", server_response); @@ -129,7 +136,10 @@ where ServerResponse::Challenge(challenge) => { data = Some((*challenge).as_bytes().to_vec()) } - ServerResponse::Success(success) => data = Some((*success).as_bytes().to_vec()), + ServerResponse::Success(success) => { + data = success.clone().map(|success| success.as_bytes().to_vec()) + } + ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)), } } } diff --git a/src/stanza/sasl.rs b/src/stanza/sasl.rs index 6ac4fc9..ec6f63c 100644 --- a/src/stanza/sasl.rs +++ b/src/stanza/sasl.rs @@ -105,10 +105,10 @@ impl FromElement for Challenge { } #[derive(Debug)] -pub struct Success(String); +pub struct Success(Option); impl Deref for Success { - type Target = str; + type Target = Option; fn deref(&self) -> &Self::Target { &self.0 @@ -120,7 +120,7 @@ impl FromElement for Success { element.check_name("success")?; element.check_namespace(XMLNS)?; - let sasl_data = element.value()?; + let sasl_data = element.value_opt()?; Ok(Success(sasl_data)) } @@ -130,10 +130,12 @@ impl FromElement for Success { pub enum ServerResponse { Challenge(Challenge), Success(Success), + Failure(Failure), } impl FromElement for ServerResponse { fn from_element(element: Element) -> peanuts::element::DeserializeResult { + debug!("identification: {:?}", element.identify()); match element.identify() { (Some(XMLNS), "challenge") => { Ok(ServerResponse::Challenge(Challenge::from_element(element)?)) @@ -141,6 +143,9 @@ impl FromElement for ServerResponse { (Some(XMLNS), "success") => { Ok(ServerResponse::Success(Success::from_element(element)?)) } + (Some(XMLNS), "failure") => { + Ok(ServerResponse::Failure(Failure::from_element(element)?)) + } _ => Err(DeserializeError::UnexpectedElement(element)), } } @@ -165,6 +170,77 @@ impl Deref for Response { impl IntoElement for Response { fn builder(&self) -> peanuts::element::ElementBuilder { - Element::builder("reponse", Some(XMLNS)).push_text(self.0.clone()) + Element::builder("response", Some(XMLNS)).push_text(self.0.clone()) + } +} + +#[derive(Debug)] +pub struct Failure { + r#type: Option, + text: Option, +} + +impl FromElement for Failure { + fn from_element(mut element: Element) -> peanuts::element::DeserializeResult { + element.check_name("failure")?; + element.check_namespace(XMLNS)?; + + let r#type = element.pop_child_opt()?; + let text = element.pop_child_opt()?; + + Ok(Failure { r#type, text }) + } +} + +#[derive(Debug)] +pub enum FailureType { + Aborted, + AccountDisabled, + CredentialsExpired, + EncryptionRequired, + IncorrectEncoding, + InvalidAuthzid, + InvalidMechanism, + MalformedRequest, + MechanismTooWeak, + NotAuthorized, + TemporaryAuthFailure, +} + +impl FromElement for FailureType { + fn from_element(element: Element) -> peanuts::element::DeserializeResult { + match element.identify() { + (Some(XMLNS), "aborted") => Ok(FailureType::Aborted), + (Some(XMLNS), "account-disabled") => Ok(FailureType::AccountDisabled), + (Some(XMLNS), "credentials-expired") => Ok(FailureType::CredentialsExpired), + (Some(XMLNS), "encryption-required") => Ok(FailureType::EncryptionRequired), + (Some(XMLNS), "incorrect-encoding") => Ok(FailureType::IncorrectEncoding), + (Some(XMLNS), "invalid-authzid") => Ok(FailureType::InvalidAuthzid), + (Some(XMLNS), "invalid-mechanism") => Ok(FailureType::InvalidMechanism), + (Some(XMLNS), "malformed-request") => Ok(FailureType::MalformedRequest), + (Some(XMLNS), "mechanism-too-weak") => Ok(FailureType::MechanismTooWeak), + (Some(XMLNS), "not-authorized") => Ok(FailureType::NotAuthorized), + (Some(XMLNS), "temporary-auth-failure") => Ok(FailureType::TemporaryAuthFailure), + _ => Err(DeserializeError::UnexpectedElement(element)), + } + } +} + +#[derive(Debug)] +pub struct Text { + lang: Option, + text: Option, +} + +impl FromElement for Text { + fn from_element(mut element: Element) -> peanuts::element::DeserializeResult { + element.check_name("text")?; + element.check_namespace(XMLNS)?; + + let lang = element.attribute_opt_namespaced("lang", peanuts::XML_NS)?; + + let text = element.pop_value_opt()?; + + Ok(Text { lang, text }) } }