From 322b2a3b46348ec1c5acbc538de93310c9030b96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?cel=20=F0=9F=8C=B8?= Date: Wed, 12 Jul 2023 21:11:20 +0100 Subject: [PATCH] reimplement sasl (with SCRAM!) --- Cargo.toml | 2 +- TODO.md | 2 + src/client/encrypted.rs | 130 +++++++++++++++++++++++++----- src/client/mod.rs | 11 +-- src/client/unencrypted.rs | 8 +- src/error.rs | 13 ++- src/jabber.rs | 2 +- src/stanza/mod.rs | 74 ++++++++++++----- src/stanza/sasl.rs | 165 ++++++++++++++++++++++++++++++++++++-- src/stanza/stream.rs | 20 ++--- 10 files changed, 357 insertions(+), 70 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 153f648..eb89659 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ async-recursion = "1.0.4" async-trait = "0.1.68" quick-xml = { git = "https://github.com/tafia/quick-xml.git", features = ["async-tokio"] } # TODO: remove unneeded features -rsasl = { version = "2", default_features = false, features = ["provider_base64", "plain", "config_builder"] } +rsasl = { version = "2", default_features = true, features = ["provider_base64", "plain", "config_builder"] } tokio = { version = "1.28", features = ["full"] } tokio-native-tls = "0.3.1" trust-dns-resolver = "0.22.0" diff --git a/TODO.md b/TODO.md index 068be75..22d656a 100644 --- a/TODO.md +++ b/TODO.md @@ -7,3 +7,5 @@ [ ] remove unwraps [ ] proper error types [ ] stream error type +[ ] change stanzas from owned to borrowed types with lifetimes +[ ] Into trait with event() and content() functions diff --git a/src/client/encrypted.rs b/src/client/encrypted.rs index 898dc23..e8b7271 100644 --- a/src/client/encrypted.rs +++ b/src/client/encrypted.rs @@ -1,13 +1,23 @@ +use std::str; + use quick_xml::{ events::{BytesDecl, Event}, + name::QName, Reader, Writer, }; +use rsasl::prelude::{Mechname, SASLClient}; use tokio::io::{BufReader, ReadHalf, WriteHalf}; use tokio::net::TcpStream; use tokio_native_tls::TlsStream; -use crate::stanza::stream::{Stream, StreamFeature}; -use crate::stanza::Element; +use crate::stanza::{ + sasl::{Auth, Response}, + stream::{Stream, StreamFeature}, +}; +use crate::stanza::{ + sasl::{Challenge, Success}, + Element, +}; use crate::Jabber; use crate::Result; @@ -48,27 +58,111 @@ impl<'j> JabberClient<'j> { Ok(()) } - pub async fn get_features(&mut self) -> Result>> { - if let Some(features) = Element::read(&mut self.reader).await? { - Ok(Some(features.try_into()?)) - } else { - Ok(None) - } + pub async fn get_features(&mut self) -> Result> { + Element::read(&mut self.reader).await?.try_into() } pub async fn negotiate(&mut self) -> Result<()> { loop { println!("loop"); - let features = &self.get_features().await?; - println!("{:?}", features); - // match &features[0] { - // StreamFeature::Sasl(sasl) => { - // println!("{:?}", sasl); - // todo!() - // } - // StreamFeature::Bind => todo!(), - // x => println!("{:?}", x), - // } + let features = self.get_features().await?; + println!("features: {:?}", features); + match &features[0] { + StreamFeature::Sasl(sasl) => { + println!("sasl?"); + self.sasl(&sasl).await?; + } + StreamFeature::Bind => todo!(), + x => println!("{:?}", x), + } } } + + pub async fn sasl(&mut self, mechanisms: &Vec) -> Result<()> { + println!("{:?}", mechanisms); + let sasl = SASLClient::new(self.jabber.auth.clone()); + let mut offered_mechs: Vec<&Mechname> = Vec::new(); + for mechanism in mechanisms { + offered_mechs.push(Mechname::parse(mechanism.as_bytes())?) + } + println!("{:?}", offered_mechs); + let mut session = sasl.start_suggested(&offered_mechs)?; + let selected_mechanism = session.get_mechname().as_str().to_owned(); + println!("selected mech: {:?}", selected_mechanism); + let mut data: Option> = None; + if !session.are_we_first() { + // if not first mention the mechanism then get challenge data + // mention mechanism + let auth = Auth { + mechanism: selected_mechanism.as_str(), + sasl_data: "=", + }; + Into::::into(auth).write(&mut self.writer).await?; + // get challenge data + let challenge = &Element::read(&mut self.reader).await?; + let challenge: Challenge = challenge.try_into()?; + println!("challenge: {:?}", challenge); + data = Some(challenge.sasl_data.to_owned()); + println!("we didn't go first"); + } else { + // if first, mention mechanism and send data + let mut sasl_data = Vec::new(); + session.step64(None, &mut sasl_data).unwrap(); + let auth = Auth { + mechanism: selected_mechanism.as_str(), + sasl_data: str::from_utf8(&sasl_data)?, + }; + println!("{:?}", auth); + Into::::into(auth).write(&mut self.writer).await?; + + let server_response = Element::read(&mut self.reader).await?; + println!("server_response: {:#?}", server_response); + match TryInto::::try_into(&server_response) { + Ok(challenge) => data = Some(challenge.sasl_data.to_owned()), + Err(_) => { + let success = TryInto::::try_into(&server_response)?; + if let Some(sasl_data) = success.sasl_data { + data = Some(sasl_data.to_owned()) + } + } + } + println!("we went first"); + } + + // stepping the authentication exchange to completion + if data != None { + println!("data: {:?}", data); + let mut sasl_data = Vec::new(); + while { + // decide if need to send more data over + let state = session + .step64(data.as_deref(), &mut sasl_data) + .expect("step errored!"); + state.is_running() + } { + // While we aren't finished, receive more data from the other party + let response = Response { + sasl_data: str::from_utf8(&sasl_data)?, + }; + println!("response: {:?}", response); + Into::::into(response) + .write(&mut self.writer) + .await?; + + let server_response = Element::read(&mut self.reader).await?; + println!("server_response: {:?}", server_response); + match TryInto::::try_into(&server_response) { + Ok(challenge) => data = Some(challenge.sasl_data.to_owned()), + Err(_) => { + let success = TryInto::::try_into(&server_response)?; + if let Some(sasl_data) = success.sasl_data { + data = Some(sasl_data.to_owned()) + } + } + } + } + } + self.start_stream().await?; + Ok(()) + } } diff --git a/src/client/mod.rs b/src/client/mod.rs index d545923..280e0a1 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -17,14 +17,11 @@ impl<'j> JabberClientType<'j> { match self { Self::Encrypted(c) => Ok(c), Self::Unencrypted(mut c) => { - if let Some(features) = c.get_features().await? { - if features.contains(&StreamFeature::StartTls) { - Ok(c.starttls().await?) - } else { - Err(JabberError::StartTlsUnavailable) - } + let features = c.get_features().await?; + if features.contains(&StreamFeature::StartTls) { + Ok(c.starttls().await?) } else { - Err(JabberError::NoFeatures) + Err(JabberError::StartTlsUnavailable) } } } diff --git a/src/client/unencrypted.rs b/src/client/unencrypted.rs index dcd10c6..27b0a5f 100644 --- a/src/client/unencrypted.rs +++ b/src/client/unencrypted.rs @@ -50,12 +50,8 @@ impl<'j> JabberClient<'j> { Ok(()) } - pub async fn get_features(&mut self) -> Result>> { - if let Some(features) = Element::read(&mut self.reader).await? { - Ok(Some(features.try_into()?)) - } else { - Ok(None) - } + pub async fn get_features(&mut self) -> Result> { + Element::read(&mut self.reader).await?.try_into() } pub async fn starttls(mut self) -> Result> { diff --git a/src/error.rs b/src/error.rs index 7f704e5..17bfbef 100644 --- a/src/error.rs +++ b/src/error.rs @@ -18,6 +18,7 @@ pub enum JabberError { NoFeatures, UnknownNamespace, ParseError, + UnexpectedEnd, XML(quick_xml::Error), SASL(SASLError), Element(ElementError<'static>), @@ -28,6 +29,8 @@ pub enum JabberError { pub enum SASLError { SASL(rsasl::prelude::SASLError), MechanismName(MechanismNameError), + NoChallenge, + NoSuccess, } impl From for JabberError { @@ -37,8 +40,14 @@ impl From for JabberError { } impl From for JabberError { - fn from(value: MechanismNameError) -> Self { - Self::SASL(SASLError::MechanismName(value)) + fn from(e: MechanismNameError) -> Self { + Self::SASL(SASLError::MechanismName(e)) + } +} + +impl From for JabberError { + fn from(e: SASLError) -> Self { + Self::SASL(e) } } diff --git a/src/jabber.rs b/src/jabber.rs index a48751c..1a7eddb 100644 --- a/src/jabber.rs +++ b/src/jabber.rs @@ -24,7 +24,7 @@ pub struct Jabber<'j> { impl<'j> Jabber<'j> { pub fn new(jid: JID, password: String) -> Result { let server = jid.domainpart.clone(); - let auth = SASLConfig::with_credentials(None, jid.as_bare().to_string(), password)?; + let auth = SASLConfig::with_credentials(None, jid.localpart.clone().unwrap(), password)?; println!("auth: {:?}", auth); Ok(Self { jid, diff --git a/src/stanza/mod.rs b/src/stanza/mod.rs index 16f3bdd..c29b1a2 100644 --- a/src/stanza/mod.rs +++ b/src/stanza/mod.rs @@ -9,12 +9,12 @@ use quick_xml::events::Event; use quick_xml::{Reader, Writer}; use tokio::io::{AsyncBufRead, AsyncWrite}; -use crate::Result; +use crate::JabberError; -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Element<'e> { pub event: Event<'e>, - pub content: Option>>, + pub children: Option>>, } impl<'e: 'async_recursion, 'async_recursion> Element<'e> { @@ -23,7 +23,7 @@ impl<'e: 'async_recursion, 'async_recursion> Element<'e> { writer: &'life0 mut Writer, ) -> ::core::pin::Pin< Box< - dyn ::core::future::Future> + dyn ::core::future::Future> + 'async_recursion + ::core::marker::Send, >, @@ -36,9 +36,9 @@ impl<'e: 'async_recursion, 'async_recursion> Element<'e> { match &self.event { Event::Start(e) => { writer.write_event_async(Event::Start(e.clone())).await?; - if let Some(content) = &self.content { - for _e in content { - self.write(writer).await?; + if let Some(children) = &self.children { + for e in children { + e.write(writer).await?; } } writer.write_event_async(Event::End(e.to_end())).await?; @@ -54,7 +54,7 @@ impl<'e> Element<'e> { pub async fn write_start( &self, writer: &mut Writer, - ) -> Result<()> { + ) -> Result<(), JabberError> { match self.event.as_ref() { Event::Start(e) => Ok(writer.write_event_async(Event::Start(e.clone())).await?), e => Err(ElementError::NotAStart(e.clone().into_owned()).into()), @@ -64,7 +64,7 @@ impl<'e> Element<'e> { pub async fn write_end( &self, writer: &mut Writer, - ) -> Result<()> { + ) -> Result<(), JabberError> { match self.event.as_ref() { Event::Start(e) => Ok(writer .write_event_async(Event::End(e.clone().to_end())) @@ -76,28 +76,38 @@ impl<'e> Element<'e> { #[async_recursion] pub async fn read( reader: &mut Reader, - ) -> Result> { + ) -> Result { + let element = Self::read_recursive(reader) + .await? + .ok_or(JabberError::UnexpectedEnd); + element + } + + #[async_recursion] + async fn read_recursive( + reader: &mut Reader, + ) -> Result, JabberError> { let mut buf = Vec::new(); let event = reader.read_event_into_async(&mut buf).await?; match event { Event::Start(e) => { - let mut content_vec = Vec::new(); - while let Some(sub_element) = Element::read(reader).await? { - content_vec.push(sub_element) + let mut children_vec = Vec::new(); + while let Some(sub_element) = Element::read_recursive(reader).await? { + children_vec.push(sub_element) } - let mut content = None; - if !content_vec.is_empty() { - content = Some(content_vec) + let mut children = None; + if !children_vec.is_empty() { + children = Some(children_vec) } Ok(Some(Self { event: Event::Start(e.into_owned()), - content, + children, })) } Event::End(_) => Ok(None), e => Ok(Some(Self { event: e.into_owned(), - content: None, + children: None, })), } } @@ -105,14 +115,14 @@ impl<'e> Element<'e> { #[async_recursion] pub async fn read_start( reader: &mut Reader, - ) -> Result { + ) -> Result { let mut buf = Vec::new(); let event = reader.read_event_into_async(&mut buf).await?; match event { Event::Start(e) => { return Ok(Self { event: Event::Start(e.into_owned()), - content: None, + children: None, }) } e => Err(ElementError::NotAStart(e.into_owned()).into()), @@ -120,7 +130,31 @@ impl<'e> Element<'e> { } } +/// if there is only one child in the vec of children, will return that element +pub fn child<'p, 'e>(element: &'p Element<'e>) -> Result<&'p Element<'e>, ElementError<'static>> { + if let Some(children) = &element.children { + if children.len() == 1 { + return Ok(&children[0]); + } else { + return Err(ElementError::MultipleChildren); + } + } + Err(ElementError::NoChildren) +} + +/// returns reference to children +pub fn children<'p, 'e>( + element: &'p Element<'e>, +) -> Result<&'p Vec>, ElementError<'e>> { + if let Some(children) = &element.children { + return Ok(children); + } + Err(ElementError::NoChildren) +} + #[derive(Debug)] pub enum ElementError<'e> { NotAStart(Event<'e>), + NoChildren, + MultipleChildren, } diff --git a/src/stanza/sasl.rs b/src/stanza/sasl.rs index 1f77ffa..bbf3f41 100644 --- a/src/stanza/sasl.rs +++ b/src/stanza/sasl.rs @@ -1,8 +1,163 @@ -pub struct Auth { - pub mechanism: String, - pub sasl_data: Option, +use quick_xml::{ + events::{BytesStart, BytesText, Event}, + name::QName, +}; + +use crate::error::SASLError; +use crate::JabberError; + +use super::Element; + +const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-sasl"; + +#[derive(Debug)] +pub struct Auth<'e> { + pub mechanism: &'e str, + pub sasl_data: &'e str, } -pub struct Challenge { - pub sasl_data: String, +impl<'e> Auth<'e> { + fn event(&self) -> Event<'e> { + let mut start = BytesStart::new("auth"); + start.push_attribute(("xmlns", XMLNS)); + start.push_attribute(("mechanism", self.mechanism)); + Event::Start(start) + } + + fn children(&self) -> Option>> { + let sasl = BytesText::from_escaped(self.sasl_data); + let sasl = Element { + event: Event::Text(sasl), + children: None, + }; + Some(vec![sasl]) + } +} + +impl<'e> Into> for Auth<'e> { + fn into(self) -> Element<'e> { + Element { + event: self.event(), + children: self.children(), + } + } +} + +#[derive(Debug)] +pub struct Challenge { + pub sasl_data: Vec, +} + +impl<'e> TryFrom<&Element<'e>> for Challenge { + type Error = JabberError; + + fn try_from(element: &Element<'e>) -> Result { + if let Event::Start(start) = &element.event { + if start.name() == QName(b"challenge") { + let sasl_data: &Element<'_> = super::child(element)?; + if let Event::Text(sasl_data) = &sasl_data.event { + let s = sasl_data.clone(); + let s = s.into_inner(); + let s = s.to_vec(); + return Ok(Challenge { sasl_data: s }); + } + } + } + Err(SASLError::NoChallenge.into()) + } +} + +// impl<'e> TryFrom> for Challenge { +// type Error = JabberError; + +// fn try_from(element: Element<'e>) -> Result { +// if let Event::Start(start) = &element.event { +// if start.name() == QName(b"challenge") { +// println!("one"); +// if let Some(children) = element.children.as_deref() { +// if children.len() == 1 { +// let sasl_data = children.first().unwrap(); +// if let Event::Text(sasl_data) = &sasl_data.event { +// return Ok(Challenge { +// sasl_data: sasl_data.clone().into_inner().to_vec(), +// }); +// } else { +// return Err(SASLError::NoChallenge.into()); +// } +// } else { +// return Err(SASLError::NoChallenge.into()); +// } +// } else { +// return Err(SASLError::NoChallenge.into()); +// } +// } +// } +// Err(SASLError::NoChallenge.into()) +// } +// } + +#[derive(Debug)] +pub struct Response<'e> { + pub sasl_data: &'e str, +} + +impl<'e> Response<'e> { + fn event(&self) -> Event<'e> { + let mut start = BytesStart::new("response"); + start.push_attribute(("xmlns", XMLNS)); + Event::Start(start) + } + + fn children(&self) -> Option>> { + let sasl = BytesText::from_escaped(self.sasl_data); + let sasl = Element { + event: Event::Text(sasl), + children: None, + }; + Some(vec![sasl]) + } +} + +impl<'e> Into> for Response<'e> { + fn into(self) -> Element<'e> { + Element { + event: self.event(), + children: self.children(), + } + } +} + +#[derive(Debug)] +pub struct Success { + pub sasl_data: Option>, +} + +impl<'e> TryFrom<&Element<'e>> for Success { + type Error = JabberError; + + fn try_from(element: &Element<'e>) -> Result { + match &element.event { + Event::Start(start) => { + if start.name() == QName(b"success") { + match super::child(element) { + Ok(sasl_data) => { + if let Event::Text(sasl_data) = &sasl_data.event { + return Ok(Success { + sasl_data: Some(sasl_data.clone().into_inner().to_vec()), + }); + } + } + Err(_) => return Ok(Success { sasl_data: None }), + }; + } + } + Event::Empty(empty) => { + if empty.name() == QName(b"success") { + return Ok(Success { sasl_data: None }); + } + } + _ => {} + } + Err(SASLError::NoSuccess.into()) + } } diff --git a/src/stanza/stream.rs b/src/stanza/stream.rs index 32f449d..66741b8 100644 --- a/src/stanza/stream.rs +++ b/src/stanza/stream.rs @@ -58,7 +58,7 @@ impl Stream { } } - fn build(&self) -> BytesStart { + fn event(&self) -> Event<'static> { let mut start = BytesStart::new("stream:stream"); if let Some(from) = &self.from { start.push_attribute(("from", from.to_string().as_str())); @@ -80,15 +80,15 @@ impl Stream { XMLNS::Server => start.push_attribute(("xmlns", XMLNS::Server.into())), } start.push_attribute(("xmlns:stream", XMLNS_STREAM)); - start + Event::Start(start) } } impl<'e> Into> for Stream { fn into(self) -> Element<'e> { Element { - event: Event::Start(self.build().to_owned()), - content: None, + event: self.event(), + children: None, } } } @@ -153,17 +153,17 @@ impl<'e> TryFrom> for Vec { fn try_from(features_element: Element) -> Result { let mut features = Vec::new(); - if let Some(content) = features_element.content { - for feature_element in content { + if let Some(children) = features_element.children { + for feature_element in children { match feature_element.event { Event::Start(e) => match e.name() { QName(b"starttls") => features.push(StreamFeature::StartTls), QName(b"mechanisms") => { let mut mechanisms = Vec::new(); - if let Some(content) = feature_element.content { - for mechanism_element in content { - if let Some(content) = mechanism_element.content { - for mechanism_text in content { + if let Some(children) = feature_element.children { + for mechanism_element in children { + if let Some(children) = mechanism_element.children { + for mechanism_text in children { match mechanism_text.event { Event::Text(e) => mechanisms .push(str::from_utf8(e.as_ref())?.to_owned()),