diff --git a/Cargo.toml b/Cargo.toml index f136e90..326e45e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ async-trait = "0.1.68" lazy_static = "1.4.0" nanoid = "0.4.0" # TODO: remove unneeded features -rsasl = { version = "2", default_features = true, features = ["provider_base64", "plain", "config_builder"] } +rsasl = { version = "2.0.1", default_features = false, features = ["provider_base64", "plain", "config_builder", "scram-sha-1"] } tokio = { version = "1.28", features = ["full"] } tokio-native-tls = "0.3.1" tracing = "0.1.40" diff --git a/src/connection.rs b/src/connection.rs index 65e9383..9e485d3 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,16 +1,18 @@ use std::net::{IpAddr, SocketAddr}; use std::str; use std::str::FromStr; +use std::sync::Arc; +use rsasl::config::SASLConfig; use tokio::net::TcpStream; use tokio_native_tls::native_tls::TlsConnector; // TODO: use rustls use tokio_native_tls::TlsStream; use tracing::{debug, info, instrument, trace}; -use crate::Error; use crate::Jabber; use crate::Result; +use crate::{Error, JID}; pub type Tls = TlsStream; pub type Unencrypted = TcpStream; @@ -37,15 +39,20 @@ impl Connection { } } - // pub async fn connect_user>(jid: J, password: String) -> Result { - // let server = jid.domainpart.clone(); - // let auth = SASLConfig::with_credentials(None, jid.localpart.clone().unwrap(), password)?; - // println!("auth: {:?}", auth); - // Self::connect(&server, jid.try_into()?, Some(auth)).await - // } + pub async fn connect_user(jid: impl AsRef, password: String) -> Result { + let jid: JID = JID::from_str(jid.as_ref())?; + let server = jid.domainpart.clone(); + let auth = SASLConfig::with_credentials(None, jid.localpart.clone().unwrap(), password)?; + println!("auth: {:?}", auth); + Self::connect(&server, Some(jid), Some(auth)).await + } #[instrument] - pub async fn connect(server: &str) -> Result { + pub async fn connect( + server: &str, + jid: Option, + auth: Option>, + ) -> Result { info!("connecting to {}", server); let sockets = Self::get_sockets(&server).await; debug!("discovered sockets: {:?}", sockets); @@ -58,8 +65,8 @@ impl Connection { return Ok(Self::Encrypted(Jabber::new( readhalf, writehalf, - None, - None, + jid, + auth, server.to_owned(), ))); } @@ -71,8 +78,8 @@ impl Connection { return Ok(Self::Unencrypted(Jabber::new( readhalf, writehalf, - None, - None, + jid, + auth, server.to_owned(), ))); } @@ -181,12 +188,12 @@ mod tests { #[test(tokio::test)] async fn connect() { - Connection::connect("blos.sm").await.unwrap(); + Connection::connect("blos.sm", None, None).await.unwrap(); } #[test(tokio::test)] async fn test_tls() { - Connection::connect("blos.sm") + Connection::connect("blos.sm", None, None) .await .unwrap() .ensure_tls() diff --git a/src/error.rs b/src/error.rs index c7c867c..8ee9077 100644 --- a/src/error.rs +++ b/src/error.rs @@ -19,6 +19,8 @@ pub enum Error { IDMismatch, BindError, ParseError, + Negotiation, + TlsRequired, UnexpectedEnd, UnexpectedElement, UnexpectedText, diff --git a/src/jabber.rs b/src/jabber.rs index a56c65c..9e7f9d8 100644 --- a/src/jabber.rs +++ b/src/jabber.rs @@ -1,26 +1,26 @@ use std::str; use std::sync::Arc; +use async_recursion::async_recursion; use peanuts::element::{FromElement, IntoElement}; use peanuts::{Reader, Writer}; -use rsasl::prelude::SASLConfig; +use rsasl::prelude::{Mechname, SASLClient, SASLConfig}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf}; +use tokio::time::timeout; use tokio_native_tls::native_tls::TlsConnector; use tracing::{debug, info, instrument, trace}; use trust_dns_resolver::proto::rr::domain::IntoLabel; use crate::connection::{Tls, Unencrypted}; use crate::error::Error; +use crate::stanza::sasl::{Auth, Challenge, Mechanisms, Response, ServerResponse}; use crate::stanza::starttls::{Proceed, StartTls}; -use crate::stanza::stream::{Features, Stream}; +use crate::stanza::stream::{Feature, Features, Stream}; use crate::stanza::XML_VERSION; -use crate::Result; use crate::JID; +use crate::{Connection, Result}; -pub struct Jabber -where - S: AsyncRead + AsyncWrite + Unpin, -{ +pub struct Jabber { reader: Reader>, writer: Writer>, jid: Option, @@ -56,7 +56,89 @@ where S: AsyncRead + AsyncWrite + Unpin + Send, Jabber: std::fmt::Debug, { - // pub async fn negotiate(self) -> Result> {} + pub async fn sasl( + &mut self, + mechanisms: Mechanisms, + sasl_config: Arc, + ) -> Result<()> { + let sasl = SASLClient::new(sasl_config); + let mut offered_mechs: Vec<&Mechname> = Vec::new(); + for mechanism in &mechanisms.mechanisms { + offered_mechs.push(Mechname::parse(mechanism.as_bytes())?) + } + debug!("{:?}", offered_mechs); + let mut session = sasl.start_suggested(&offered_mechs)?; + let selected_mechanism = session.get_mechname().as_str().to_owned(); + debug!("selected mech: {:?}", selected_mechanism); + let mut data: Option> = None; + + if !session.are_we_first() { + // if not first mention the mechanism then get challenge data + // mention mechanism + let auth = Auth { + mechanism: selected_mechanism, + sasl_data: "=".to_string(), + }; + self.writer.write_full(&auth).await?; + // get challenge data + let challenge: Challenge = self.reader.read().await?; + debug!("challenge: {:?}", challenge); + data = Some((*challenge).as_bytes().to_vec()); + debug!("we didn't go first"); + } else { + // if first, mention mechanism and send data + let mut sasl_data = Vec::new(); + session.step64(None, &mut sasl_data).unwrap(); + let auth = Auth { + mechanism: selected_mechanism, + sasl_data: str::from_utf8(&sasl_data)?.to_string(), + }; + debug!("{:?}", auth); + self.writer.write_full(&auth).await?; + + let server_response: ServerResponse = self.reader.read().await?; + debug!("server_response: {:#?}", server_response); + match server_response { + ServerResponse::Challenge(challenge) => { + data = Some((*challenge).as_bytes().to_vec()) + } + ServerResponse::Success(success) => data = Some((*success).as_bytes().to_vec()), + } + debug!("we went first"); + } + + // stepping the authentication exchange to completion + if data != None { + debug!("data: {:?}", data); + let mut sasl_data = Vec::new(); + while { + // decide if need to send more data over + let state = session + .step64(data.as_deref(), &mut sasl_data) + .expect("step errored!"); + state.is_running() + } { + // While we aren't finished, receive more data from the other party + let response = Response::new(str::from_utf8(&sasl_data)?.to_string()); + debug!("response: {:?}", response); + self.writer.write_full(&response).await?; + + let server_response: ServerResponse = self.reader.read().await?; + debug!("server_response: {:#?}", server_response); + match server_response { + ServerResponse::Challenge(challenge) => { + data = Some((*challenge).as_bytes().to_vec()) + } + ServerResponse::Success(success) => data = Some((*success).as_bytes().to_vec()), + } + } + } + Ok(()) + } + + pub async fn bind(&mut self) -> Result<()> { + todo!() + } #[instrument] pub async fn start_stream(&mut self) -> Result<()> { @@ -76,6 +158,8 @@ where let decl = self.reader.read_prolog().await?; // receive stream element and validate + let text = str::from_utf8(self.reader.buffer.data()).unwrap(); + debug!("data: {}", text); let stream: Stream = self.reader.read_start().await?; debug!("got stream: {:?}", stream); if let Some(from) = stream.from { @@ -97,6 +181,87 @@ where } } +impl Jabber { + pub async fn negotiate(mut self) -> Result> { + self.start_stream().await?; + // TODO: timeout + let features = self.get_features().await?.features; + if let Some(Feature::StartTls(_)) = features + .iter() + .find(|feature| matches!(feature, Feature::StartTls(_s))) + { + let jabber = self.starttls().await?; + let jabber = jabber.negotiate().await?; + return Ok(jabber); + } else { + // TODO: better error + return Err(Error::TlsRequired); + } + } + + #[async_recursion] + pub async fn negotiate_tls_optional(mut self) -> Result { + self.start_stream().await?; + // TODO: timeout + let features = self.get_features().await?.features; + if let Some(Feature::StartTls(_)) = features + .iter() + .find(|feature| matches!(feature, Feature::StartTls(_s))) + { + let jabber = self.starttls().await?; + let jabber = jabber.negotiate().await?; + return Ok(Connection::Encrypted(jabber)); + } else if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = ( + self.auth.clone(), + features + .iter() + .find(|feature| matches!(feature, Feature::Sasl(_))), + ) { + self.sasl(mechanisms.clone(), sasl_config).await?; + let jabber = self.negotiate_tls_optional().await?; + Ok(jabber) + } else if let Some(Feature::Bind) = features + .iter() + .find(|feature| matches!(feature, Feature::Bind)) + { + self.bind().await?; + Ok(Connection::Unencrypted(self)) + } else { + // TODO: better error + return Err(Error::Negotiation); + } + } +} + +impl Jabber { + #[async_recursion] + pub async fn negotiate(mut self) -> Result> { + self.start_stream().await?; + let features = self.get_features().await?.features; + + if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = ( + self.auth.clone(), + features + .iter() + .find(|feature| matches!(feature, Feature::Sasl(_))), + ) { + // TODO: avoid clone + self.sasl(mechanisms.clone(), sasl_config).await?; + let jabber = self.negotiate().await?; + Ok(jabber) + } else if let Some(Feature::Bind) = features + .iter() + .find(|feature| matches!(feature, Feature::Bind)) + { + self.bind().await?; + Ok(self) + } else { + // TODO: better error + return Err(Error::Negotiation); + } + } +} + impl Jabber { pub async fn starttls(mut self) -> Result> { self.writer @@ -155,10 +320,47 @@ mod tests { #[test(tokio::test)] async fn start_stream() { - let connection = Connection::connect("blos.sm").await.unwrap(); + let connection = Connection::connect("blos.sm", None, None).await.unwrap(); match connection { Connection::Encrypted(mut c) => c.start_stream().await.unwrap(), Connection::Unencrypted(mut c) => c.start_stream().await.unwrap(), } } + + #[test(tokio::test)] + async fn sasl() { + let mut jabber = Connection::connect_user("test@blos.sm", "slayed".to_string()) + .await + .unwrap() + .ensure_tls() + .await + .unwrap(); + let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); + println!("data: {}", text); + jabber.start_stream().await.unwrap(); + + let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); + println!("data: {}", text); + jabber.reader.read_buf().await.unwrap(); + let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); + println!("data: {}", text); + + let features = jabber.get_features().await.unwrap(); + let (sasl_config, feature) = ( + jabber.auth.clone().unwrap(), + features + .features + .iter() + .find(|feature| matches!(feature, Feature::Sasl(_))) + .unwrap(), + ); + match feature { + Feature::StartTls(_start_tls) => todo!(), + Feature::Sasl(mechanisms) => { + jabber.sasl(mechanisms.clone(), sasl_config).await.unwrap(); + } + Feature::Bind => todo!(), + Feature::Unknown => todo!(), + } + } } diff --git a/src/stanza/sasl.rs b/src/stanza/sasl.rs index 8b13789..6ac4fc9 100644 --- a/src/stanza/sasl.rs +++ b/src/stanza/sasl.rs @@ -1 +1,170 @@ +use std::ops::Deref; +use peanuts::{ + element::{FromElement, IntoElement}, + DeserializeError, Element, +}; +use tracing::debug; + +pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-sasl"; + +#[derive(Debug, Clone)] +pub struct Mechanisms { + pub mechanisms: Vec, +} + +impl FromElement for Mechanisms { + fn from_element(mut element: Element) -> peanuts::element::DeserializeResult { + element.check_name("mechanisms")?; + element.check_namespace(XMLNS)?; + debug!("getting mechanisms"); + let mechanisms: Vec = element.pop_children()?; + debug!("gottting mechanisms"); + let mechanisms = mechanisms + .into_iter() + .map(|Mechanism(mechanism)| mechanism) + .collect(); + debug!("gottting mechanisms"); + + Ok(Mechanisms { mechanisms }) + } +} + +impl IntoElement for Mechanisms { + fn builder(&self) -> peanuts::element::ElementBuilder { + Element::builder("mechanisms", Some(XMLNS)).push_children( + self.mechanisms + .iter() + .map(|mechanism| Mechanism(mechanism.to_string())) + .collect(), + ) + } +} + +pub struct Mechanism(String); + +impl FromElement for Mechanism { + fn from_element(mut element: peanuts::Element) -> peanuts::element::DeserializeResult { + element.check_name("mechanism")?; + element.check_namespace(XMLNS)?; + + let mechanism = element.pop_value()?; + + Ok(Mechanism(mechanism)) + } +} + +impl IntoElement for Mechanism { + fn builder(&self) -> peanuts::element::ElementBuilder { + Element::builder("mechanism", Some(XMLNS)).push_text(self.0.clone()) + } +} + +impl Deref for Mechanism { + type Target = str; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[derive(Debug)] +pub struct Auth { + pub mechanism: String, + pub sasl_data: String, +} + +impl IntoElement for Auth { + fn builder(&self) -> peanuts::element::ElementBuilder { + Element::builder("auth", Some(XMLNS)) + .push_attribute("mechanism", self.mechanism.clone()) + .push_text(self.sasl_data.clone()) + } +} + +#[derive(Debug)] +pub struct Challenge(String); + +impl Deref for Challenge { + type Target = str; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl FromElement for Challenge { + fn from_element(mut element: Element) -> peanuts::element::DeserializeResult { + element.check_name("challenge")?; + element.check_namespace(XMLNS)?; + + let sasl_data = element.value()?; + + Ok(Challenge(sasl_data)) + } +} + +#[derive(Debug)] +pub struct Success(String); + +impl Deref for Success { + type Target = str; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl FromElement for Success { + fn from_element(mut element: Element) -> peanuts::element::DeserializeResult { + element.check_name("success")?; + element.check_namespace(XMLNS)?; + + let sasl_data = element.value()?; + + Ok(Success(sasl_data)) + } +} + +#[derive(Debug)] +pub enum ServerResponse { + Challenge(Challenge), + Success(Success), +} + +impl FromElement for ServerResponse { + fn from_element(element: Element) -> peanuts::element::DeserializeResult { + match element.identify() { + (Some(XMLNS), "challenge") => { + Ok(ServerResponse::Challenge(Challenge::from_element(element)?)) + } + (Some(XMLNS), "success") => { + Ok(ServerResponse::Success(Success::from_element(element)?)) + } + _ => Err(DeserializeError::UnexpectedElement(element)), + } + } +} + +#[derive(Debug)] +pub struct Response(String); + +impl Response { + pub fn new(response: String) -> Self { + Self(response) + } +} + +impl Deref for Response { + type Target = str; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl IntoElement for Response { + fn builder(&self) -> peanuts::element::ElementBuilder { + Element::builder("reponse", Some(XMLNS)).push_text(self.0.clone()) + } +} diff --git a/src/stanza/stream.rs b/src/stanza/stream.rs index 40f6ba0..fecace5 100644 --- a/src/stanza/stream.rs +++ b/src/stanza/stream.rs @@ -3,9 +3,11 @@ use std::collections::{HashMap, HashSet}; use peanuts::element::{Content, ElementBuilder, FromElement, IntoElement, NamespaceDeclaration}; use peanuts::XML_NS; use peanuts::{element::Name, Element}; +use tracing::debug; use crate::{Error, JID}; +use super::sasl::{self, Mechanisms}; use super::starttls::{self, StartTls}; pub const XMLNS: &str = "http://etherx.jabber.org/streams"; @@ -92,32 +94,12 @@ impl<'s> Stream { #[derive(Debug)] pub struct Features { - features: Vec, + pub features: Vec, } impl IntoElement for Features { fn builder(&self) -> ElementBuilder { Element::builder("features", Some(XMLNS)).push_children(self.features.clone()) - // let mut content = Vec::new(); - // for feature in &self.features { - // match feature { - // Feature::StartTls(start_tls) => { - // content.push(Content::Element(start_tls.into_element())) - // } - // Feature::Sasl => {} - // Feature::Bind => {} - // Feature::Unknown => {} - // } - // } - // Element { - // name: Name { - // namespace: Some(XMLNS.to_string()), - // local_name: "features".to_string(), - // }, - // namespace_declaration_overrides: HashSet::new(), - // attributes: HashMap::new(), - // content, - // } } } @@ -128,7 +110,9 @@ impl FromElement for Features { element.check_namespace(XMLNS)?; element.check_name("features")?; + debug!("got features stanza"); let features = element.children()?; + debug!("got features period"); Ok(Features { features }) } @@ -137,7 +121,7 @@ impl FromElement for Features { #[derive(Debug, Clone)] pub enum Feature { StartTls(StartTls), - Sasl, + Sasl(Mechanisms), Bind, Unknown, } @@ -146,7 +130,7 @@ impl IntoElement for Feature { fn builder(&self) -> ElementBuilder { match self { Feature::StartTls(start_tls) => start_tls.builder(), - Feature::Sasl => todo!(), + Feature::Sasl(mechanisms) => mechanisms.builder(), Feature::Bind => todo!(), Feature::Unknown => todo!(), } @@ -155,11 +139,21 @@ impl IntoElement for Feature { impl FromElement for Feature { fn from_element(element: Element) -> peanuts::element::DeserializeResult { + let identity = element.identify(); + debug!("identity: {:?}", identity); match element.identify() { (Some(starttls::XMLNS), "starttls") => { + debug!("identified starttls"); Ok(Feature::StartTls(StartTls::from_element(element)?)) } - _ => Ok(Feature::Unknown), + (Some(sasl::XMLNS), "mechanisms") => { + debug!("identified mechanisms"); + Ok(Feature::Sasl(Mechanisms::from_element(element)?)) + } + _ => { + debug!("identified unknown feature"); + Ok(Feature::Unknown) + } } } }