reimplement sasl (with SCRAM!)

This commit is contained in:
cel 🌸 2023-07-12 21:11:20 +01:00
parent c9683935f1
commit 322b2a3b46
Signed by: cel
GPG Key ID: 48E29AF13B5F1349
10 changed files with 357 additions and 70 deletions

View File

@ -11,7 +11,7 @@ async-recursion = "1.0.4"
async-trait = "0.1.68" async-trait = "0.1.68"
quick-xml = { git = "https://github.com/tafia/quick-xml.git", features = ["async-tokio"] } quick-xml = { git = "https://github.com/tafia/quick-xml.git", features = ["async-tokio"] }
# TODO: remove unneeded features # TODO: remove unneeded features
rsasl = { version = "2", default_features = false, features = ["provider_base64", "plain", "config_builder"] } rsasl = { version = "2", default_features = true, features = ["provider_base64", "plain", "config_builder"] }
tokio = { version = "1.28", features = ["full"] } tokio = { version = "1.28", features = ["full"] }
tokio-native-tls = "0.3.1" tokio-native-tls = "0.3.1"
trust-dns-resolver = "0.22.0" trust-dns-resolver = "0.22.0"

View File

@ -7,3 +7,5 @@
[ ] remove unwraps [ ] remove unwraps
[ ] proper error types [ ] proper error types
[ ] stream error type [ ] stream error type
[ ] change stanzas from owned to borrowed types with lifetimes
[ ] Into<Element> trait with event() and content() functions

View File

@ -1,13 +1,23 @@
use std::str;
use quick_xml::{ use quick_xml::{
events::{BytesDecl, Event}, events::{BytesDecl, Event},
name::QName,
Reader, Writer, Reader, Writer,
}; };
use rsasl::prelude::{Mechname, SASLClient};
use tokio::io::{BufReader, ReadHalf, WriteHalf}; use tokio::io::{BufReader, ReadHalf, WriteHalf};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_native_tls::TlsStream; use tokio_native_tls::TlsStream;
use crate::stanza::stream::{Stream, StreamFeature}; use crate::stanza::{
use crate::stanza::Element; sasl::{Auth, Response},
stream::{Stream, StreamFeature},
};
use crate::stanza::{
sasl::{Challenge, Success},
Element,
};
use crate::Jabber; use crate::Jabber;
use crate::Result; use crate::Result;
@ -48,27 +58,111 @@ impl<'j> JabberClient<'j> {
Ok(()) Ok(())
} }
pub async fn get_features(&mut self) -> Result<Option<Vec<StreamFeature>>> { pub async fn get_features(&mut self) -> Result<Vec<StreamFeature>> {
if let Some(features) = Element::read(&mut self.reader).await? { Element::read(&mut self.reader).await?.try_into()
Ok(Some(features.try_into()?))
} else {
Ok(None)
}
} }
pub async fn negotiate(&mut self) -> Result<()> { pub async fn negotiate(&mut self) -> Result<()> {
loop { loop {
println!("loop"); println!("loop");
let features = &self.get_features().await?; let features = self.get_features().await?;
println!("{:?}", features); println!("features: {:?}", features);
// match &features[0] { match &features[0] {
// StreamFeature::Sasl(sasl) => { StreamFeature::Sasl(sasl) => {
// println!("{:?}", sasl); println!("sasl?");
// todo!() self.sasl(&sasl).await?;
// } }
// StreamFeature::Bind => todo!(), StreamFeature::Bind => todo!(),
// x => println!("{:?}", x), x => println!("{:?}", x),
// } }
} }
} }
pub async fn sasl(&mut self, mechanisms: &Vec<String>) -> Result<()> {
println!("{:?}", mechanisms);
let sasl = SASLClient::new(self.jabber.auth.clone());
let mut offered_mechs: Vec<&Mechname> = Vec::new();
for mechanism in mechanisms {
offered_mechs.push(Mechname::parse(mechanism.as_bytes())?)
}
println!("{:?}", offered_mechs);
let mut session = sasl.start_suggested(&offered_mechs)?;
let selected_mechanism = session.get_mechname().as_str().to_owned();
println!("selected mech: {:?}", selected_mechanism);
let mut data: Option<Vec<u8>> = None;
if !session.are_we_first() {
// if not first mention the mechanism then get challenge data
// mention mechanism
let auth = Auth {
mechanism: selected_mechanism.as_str(),
sasl_data: "=",
};
Into::<Element>::into(auth).write(&mut self.writer).await?;
// get challenge data
let challenge = &Element::read(&mut self.reader).await?;
let challenge: Challenge = challenge.try_into()?;
println!("challenge: {:?}", challenge);
data = Some(challenge.sasl_data.to_owned());
println!("we didn't go first");
} else {
// if first, mention mechanism and send data
let mut sasl_data = Vec::new();
session.step64(None, &mut sasl_data).unwrap();
let auth = Auth {
mechanism: selected_mechanism.as_str(),
sasl_data: str::from_utf8(&sasl_data)?,
};
println!("{:?}", auth);
Into::<Element>::into(auth).write(&mut self.writer).await?;
let server_response = Element::read(&mut self.reader).await?;
println!("server_response: {:#?}", server_response);
match TryInto::<Challenge>::try_into(&server_response) {
Ok(challenge) => data = Some(challenge.sasl_data.to_owned()),
Err(_) => {
let success = TryInto::<Success>::try_into(&server_response)?;
if let Some(sasl_data) = success.sasl_data {
data = Some(sasl_data.to_owned())
}
}
}
println!("we went first");
}
// stepping the authentication exchange to completion
if data != None {
println!("data: {:?}", data);
let mut sasl_data = Vec::new();
while {
// decide if need to send more data over
let state = session
.step64(data.as_deref(), &mut sasl_data)
.expect("step errored!");
state.is_running()
} {
// While we aren't finished, receive more data from the other party
let response = Response {
sasl_data: str::from_utf8(&sasl_data)?,
};
println!("response: {:?}", response);
Into::<Element>::into(response)
.write(&mut self.writer)
.await?;
let server_response = Element::read(&mut self.reader).await?;
println!("server_response: {:?}", server_response);
match TryInto::<Challenge>::try_into(&server_response) {
Ok(challenge) => data = Some(challenge.sasl_data.to_owned()),
Err(_) => {
let success = TryInto::<Success>::try_into(&server_response)?;
if let Some(sasl_data) = success.sasl_data {
data = Some(sasl_data.to_owned())
}
}
}
}
}
self.start_stream().await?;
Ok(())
}
} }

View File

@ -17,14 +17,11 @@ impl<'j> JabberClientType<'j> {
match self { match self {
Self::Encrypted(c) => Ok(c), Self::Encrypted(c) => Ok(c),
Self::Unencrypted(mut c) => { Self::Unencrypted(mut c) => {
if let Some(features) = c.get_features().await? { let features = c.get_features().await?;
if features.contains(&StreamFeature::StartTls) { if features.contains(&StreamFeature::StartTls) {
Ok(c.starttls().await?) Ok(c.starttls().await?)
} else {
Err(JabberError::StartTlsUnavailable)
}
} else { } else {
Err(JabberError::NoFeatures) Err(JabberError::StartTlsUnavailable)
} }
} }
} }

View File

@ -50,12 +50,8 @@ impl<'j> JabberClient<'j> {
Ok(()) Ok(())
} }
pub async fn get_features(&mut self) -> Result<Option<Vec<StreamFeature>>> { pub async fn get_features(&mut self) -> Result<Vec<StreamFeature>> {
if let Some(features) = Element::read(&mut self.reader).await? { Element::read(&mut self.reader).await?.try_into()
Ok(Some(features.try_into()?))
} else {
Ok(None)
}
} }
pub async fn starttls(mut self) -> Result<super::encrypted::JabberClient<'j>> { pub async fn starttls(mut self) -> Result<super::encrypted::JabberClient<'j>> {

View File

@ -18,6 +18,7 @@ pub enum JabberError {
NoFeatures, NoFeatures,
UnknownNamespace, UnknownNamespace,
ParseError, ParseError,
UnexpectedEnd,
XML(quick_xml::Error), XML(quick_xml::Error),
SASL(SASLError), SASL(SASLError),
Element(ElementError<'static>), Element(ElementError<'static>),
@ -28,6 +29,8 @@ pub enum JabberError {
pub enum SASLError { pub enum SASLError {
SASL(rsasl::prelude::SASLError), SASL(rsasl::prelude::SASLError),
MechanismName(MechanismNameError), MechanismName(MechanismNameError),
NoChallenge,
NoSuccess,
} }
impl From<rsasl::prelude::SASLError> for JabberError { impl From<rsasl::prelude::SASLError> for JabberError {
@ -37,8 +40,14 @@ impl From<rsasl::prelude::SASLError> for JabberError {
} }
impl From<MechanismNameError> for JabberError { impl From<MechanismNameError> for JabberError {
fn from(value: MechanismNameError) -> Self { fn from(e: MechanismNameError) -> Self {
Self::SASL(SASLError::MechanismName(value)) Self::SASL(SASLError::MechanismName(e))
}
}
impl From<SASLError> for JabberError {
fn from(e: SASLError) -> Self {
Self::SASL(e)
} }
} }

View File

@ -24,7 +24,7 @@ pub struct Jabber<'j> {
impl<'j> Jabber<'j> { impl<'j> Jabber<'j> {
pub fn new(jid: JID, password: String) -> Result<Self> { pub fn new(jid: JID, password: String) -> Result<Self> {
let server = jid.domainpart.clone(); let server = jid.domainpart.clone();
let auth = SASLConfig::with_credentials(None, jid.as_bare().to_string(), password)?; let auth = SASLConfig::with_credentials(None, jid.localpart.clone().unwrap(), password)?;
println!("auth: {:?}", auth); println!("auth: {:?}", auth);
Ok(Self { Ok(Self {
jid, jid,

View File

@ -9,12 +9,12 @@ use quick_xml::events::Event;
use quick_xml::{Reader, Writer}; use quick_xml::{Reader, Writer};
use tokio::io::{AsyncBufRead, AsyncWrite}; use tokio::io::{AsyncBufRead, AsyncWrite};
use crate::Result; use crate::JabberError;
#[derive(Debug)] #[derive(Clone, Debug)]
pub struct Element<'e> { pub struct Element<'e> {
pub event: Event<'e>, pub event: Event<'e>,
pub content: Option<Vec<Element<'e>>>, pub children: Option<Vec<Element<'e>>>,
} }
impl<'e: 'async_recursion, 'async_recursion> Element<'e> { impl<'e: 'async_recursion, 'async_recursion> Element<'e> {
@ -23,7 +23,7 @@ impl<'e: 'async_recursion, 'async_recursion> Element<'e> {
writer: &'life0 mut Writer<W>, writer: &'life0 mut Writer<W>,
) -> ::core::pin::Pin< ) -> ::core::pin::Pin<
Box< Box<
dyn ::core::future::Future<Output = Result<()>> dyn ::core::future::Future<Output = Result<(), JabberError>>
+ 'async_recursion + 'async_recursion
+ ::core::marker::Send, + ::core::marker::Send,
>, >,
@ -36,9 +36,9 @@ impl<'e: 'async_recursion, 'async_recursion> Element<'e> {
match &self.event { match &self.event {
Event::Start(e) => { Event::Start(e) => {
writer.write_event_async(Event::Start(e.clone())).await?; writer.write_event_async(Event::Start(e.clone())).await?;
if let Some(content) = &self.content { if let Some(children) = &self.children {
for _e in content { for e in children {
self.write(writer).await?; e.write(writer).await?;
} }
} }
writer.write_event_async(Event::End(e.to_end())).await?; writer.write_event_async(Event::End(e.to_end())).await?;
@ -54,7 +54,7 @@ impl<'e> Element<'e> {
pub async fn write_start<W: AsyncWrite + Unpin + Send>( pub async fn write_start<W: AsyncWrite + Unpin + Send>(
&self, &self,
writer: &mut Writer<W>, writer: &mut Writer<W>,
) -> Result<()> { ) -> Result<(), JabberError> {
match self.event.as_ref() { match self.event.as_ref() {
Event::Start(e) => Ok(writer.write_event_async(Event::Start(e.clone())).await?), Event::Start(e) => Ok(writer.write_event_async(Event::Start(e.clone())).await?),
e => Err(ElementError::NotAStart(e.clone().into_owned()).into()), e => Err(ElementError::NotAStart(e.clone().into_owned()).into()),
@ -64,7 +64,7 @@ impl<'e> Element<'e> {
pub async fn write_end<W: AsyncWrite + Unpin + Send>( pub async fn write_end<W: AsyncWrite + Unpin + Send>(
&self, &self,
writer: &mut Writer<W>, writer: &mut Writer<W>,
) -> Result<()> { ) -> Result<(), JabberError> {
match self.event.as_ref() { match self.event.as_ref() {
Event::Start(e) => Ok(writer Event::Start(e) => Ok(writer
.write_event_async(Event::End(e.clone().to_end())) .write_event_async(Event::End(e.clone().to_end()))
@ -76,28 +76,38 @@ impl<'e> Element<'e> {
#[async_recursion] #[async_recursion]
pub async fn read<R: AsyncBufRead + Unpin + Send>( pub async fn read<R: AsyncBufRead + Unpin + Send>(
reader: &mut Reader<R>, reader: &mut Reader<R>,
) -> Result<Option<Self>> { ) -> Result<Self, JabberError> {
let element = Self::read_recursive(reader)
.await?
.ok_or(JabberError::UnexpectedEnd);
element
}
#[async_recursion]
async fn read_recursive<R: AsyncBufRead + Unpin + Send>(
reader: &mut Reader<R>,
) -> Result<Option<Self>, JabberError> {
let mut buf = Vec::new(); let mut buf = Vec::new();
let event = reader.read_event_into_async(&mut buf).await?; let event = reader.read_event_into_async(&mut buf).await?;
match event { match event {
Event::Start(e) => { Event::Start(e) => {
let mut content_vec = Vec::new(); let mut children_vec = Vec::new();
while let Some(sub_element) = Element::read(reader).await? { while let Some(sub_element) = Element::read_recursive(reader).await? {
content_vec.push(sub_element) children_vec.push(sub_element)
} }
let mut content = None; let mut children = None;
if !content_vec.is_empty() { if !children_vec.is_empty() {
content = Some(content_vec) children = Some(children_vec)
} }
Ok(Some(Self { Ok(Some(Self {
event: Event::Start(e.into_owned()), event: Event::Start(e.into_owned()),
content, children,
})) }))
} }
Event::End(_) => Ok(None), Event::End(_) => Ok(None),
e => Ok(Some(Self { e => Ok(Some(Self {
event: e.into_owned(), event: e.into_owned(),
content: None, children: None,
})), })),
} }
} }
@ -105,14 +115,14 @@ impl<'e> Element<'e> {
#[async_recursion] #[async_recursion]
pub async fn read_start<R: AsyncBufRead + Unpin + Send>( pub async fn read_start<R: AsyncBufRead + Unpin + Send>(
reader: &mut Reader<R>, reader: &mut Reader<R>,
) -> Result<Self> { ) -> Result<Self, JabberError> {
let mut buf = Vec::new(); let mut buf = Vec::new();
let event = reader.read_event_into_async(&mut buf).await?; let event = reader.read_event_into_async(&mut buf).await?;
match event { match event {
Event::Start(e) => { Event::Start(e) => {
return Ok(Self { return Ok(Self {
event: Event::Start(e.into_owned()), event: Event::Start(e.into_owned()),
content: None, children: None,
}) })
} }
e => Err(ElementError::NotAStart(e.into_owned()).into()), e => Err(ElementError::NotAStart(e.into_owned()).into()),
@ -120,7 +130,31 @@ impl<'e> Element<'e> {
} }
} }
/// if there is only one child in the vec of children, will return that element
pub fn child<'p, 'e>(element: &'p Element<'e>) -> Result<&'p Element<'e>, ElementError<'static>> {
if let Some(children) = &element.children {
if children.len() == 1 {
return Ok(&children[0]);
} else {
return Err(ElementError::MultipleChildren);
}
}
Err(ElementError::NoChildren)
}
/// returns reference to children
pub fn children<'p, 'e>(
element: &'p Element<'e>,
) -> Result<&'p Vec<Element<'e>>, ElementError<'e>> {
if let Some(children) = &element.children {
return Ok(children);
}
Err(ElementError::NoChildren)
}
#[derive(Debug)] #[derive(Debug)]
pub enum ElementError<'e> { pub enum ElementError<'e> {
NotAStart(Event<'e>), NotAStart(Event<'e>),
NoChildren,
MultipleChildren,
} }

View File

@ -1,8 +1,163 @@
pub struct Auth { use quick_xml::{
pub mechanism: String, events::{BytesStart, BytesText, Event},
pub sasl_data: Option<String>, name::QName,
};
use crate::error::SASLError;
use crate::JabberError;
use super::Element;
const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
#[derive(Debug)]
pub struct Auth<'e> {
pub mechanism: &'e str,
pub sasl_data: &'e str,
} }
pub struct Challenge { impl<'e> Auth<'e> {
pub sasl_data: String, fn event(&self) -> Event<'e> {
let mut start = BytesStart::new("auth");
start.push_attribute(("xmlns", XMLNS));
start.push_attribute(("mechanism", self.mechanism));
Event::Start(start)
}
fn children(&self) -> Option<Vec<Element<'e>>> {
let sasl = BytesText::from_escaped(self.sasl_data);
let sasl = Element {
event: Event::Text(sasl),
children: None,
};
Some(vec![sasl])
}
}
impl<'e> Into<Element<'e>> for Auth<'e> {
fn into(self) -> Element<'e> {
Element {
event: self.event(),
children: self.children(),
}
}
}
#[derive(Debug)]
pub struct Challenge {
pub sasl_data: Vec<u8>,
}
impl<'e> TryFrom<&Element<'e>> for Challenge {
type Error = JabberError;
fn try_from(element: &Element<'e>) -> Result<Challenge, Self::Error> {
if let Event::Start(start) = &element.event {
if start.name() == QName(b"challenge") {
let sasl_data: &Element<'_> = super::child(element)?;
if let Event::Text(sasl_data) = &sasl_data.event {
let s = sasl_data.clone();
let s = s.into_inner();
let s = s.to_vec();
return Ok(Challenge { sasl_data: s });
}
}
}
Err(SASLError::NoChallenge.into())
}
}
// impl<'e> TryFrom<Element<'e>> for Challenge {
// type Error = JabberError;
// fn try_from(element: Element<'e>) -> Result<Challenge, Self::Error> {
// if let Event::Start(start) = &element.event {
// if start.name() == QName(b"challenge") {
// println!("one");
// if let Some(children) = element.children.as_deref() {
// if children.len() == 1 {
// let sasl_data = children.first().unwrap();
// if let Event::Text(sasl_data) = &sasl_data.event {
// return Ok(Challenge {
// sasl_data: sasl_data.clone().into_inner().to_vec(),
// });
// } else {
// return Err(SASLError::NoChallenge.into());
// }
// } else {
// return Err(SASLError::NoChallenge.into());
// }
// } else {
// return Err(SASLError::NoChallenge.into());
// }
// }
// }
// Err(SASLError::NoChallenge.into())
// }
// }
#[derive(Debug)]
pub struct Response<'e> {
pub sasl_data: &'e str,
}
impl<'e> Response<'e> {
fn event(&self) -> Event<'e> {
let mut start = BytesStart::new("response");
start.push_attribute(("xmlns", XMLNS));
Event::Start(start)
}
fn children(&self) -> Option<Vec<Element<'e>>> {
let sasl = BytesText::from_escaped(self.sasl_data);
let sasl = Element {
event: Event::Text(sasl),
children: None,
};
Some(vec![sasl])
}
}
impl<'e> Into<Element<'e>> for Response<'e> {
fn into(self) -> Element<'e> {
Element {
event: self.event(),
children: self.children(),
}
}
}
#[derive(Debug)]
pub struct Success {
pub sasl_data: Option<Vec<u8>>,
}
impl<'e> TryFrom<&Element<'e>> for Success {
type Error = JabberError;
fn try_from(element: &Element<'e>) -> Result<Success, Self::Error> {
match &element.event {
Event::Start(start) => {
if start.name() == QName(b"success") {
match super::child(element) {
Ok(sasl_data) => {
if let Event::Text(sasl_data) = &sasl_data.event {
return Ok(Success {
sasl_data: Some(sasl_data.clone().into_inner().to_vec()),
});
}
}
Err(_) => return Ok(Success { sasl_data: None }),
};
}
}
Event::Empty(empty) => {
if empty.name() == QName(b"success") {
return Ok(Success { sasl_data: None });
}
}
_ => {}
}
Err(SASLError::NoSuccess.into())
}
} }

View File

@ -58,7 +58,7 @@ impl Stream {
} }
} }
fn build(&self) -> BytesStart { fn event(&self) -> Event<'static> {
let mut start = BytesStart::new("stream:stream"); let mut start = BytesStart::new("stream:stream");
if let Some(from) = &self.from { if let Some(from) = &self.from {
start.push_attribute(("from", from.to_string().as_str())); start.push_attribute(("from", from.to_string().as_str()));
@ -80,15 +80,15 @@ impl Stream {
XMLNS::Server => start.push_attribute(("xmlns", XMLNS::Server.into())), XMLNS::Server => start.push_attribute(("xmlns", XMLNS::Server.into())),
} }
start.push_attribute(("xmlns:stream", XMLNS_STREAM)); start.push_attribute(("xmlns:stream", XMLNS_STREAM));
start Event::Start(start)
} }
} }
impl<'e> Into<Element<'e>> for Stream { impl<'e> Into<Element<'e>> for Stream {
fn into(self) -> Element<'e> { fn into(self) -> Element<'e> {
Element { Element {
event: Event::Start(self.build().to_owned()), event: self.event(),
content: None, children: None,
} }
} }
} }
@ -153,17 +153,17 @@ impl<'e> TryFrom<Element<'e>> for Vec<StreamFeature> {
fn try_from(features_element: Element) -> Result<Self> { fn try_from(features_element: Element) -> Result<Self> {
let mut features = Vec::new(); let mut features = Vec::new();
if let Some(content) = features_element.content { if let Some(children) = features_element.children {
for feature_element in content { for feature_element in children {
match feature_element.event { match feature_element.event {
Event::Start(e) => match e.name() { Event::Start(e) => match e.name() {
QName(b"starttls") => features.push(StreamFeature::StartTls), QName(b"starttls") => features.push(StreamFeature::StartTls),
QName(b"mechanisms") => { QName(b"mechanisms") => {
let mut mechanisms = Vec::new(); let mut mechanisms = Vec::new();
if let Some(content) = feature_element.content { if let Some(children) = feature_element.children {
for mechanism_element in content { for mechanism_element in children {
if let Some(content) = mechanism_element.content { if let Some(children) = mechanism_element.children {
for mechanism_text in content { for mechanism_text in children {
match mechanism_text.event { match mechanism_text.event {
Event::Text(e) => mechanisms Event::Text(e) => mechanisms
.push(str::from_utf8(e.as_ref())?.to_owned()), .push(str::from_utf8(e.as_ref())?.to_owned()),