Compare commits

..

No commits in common. "b6593389069903cc4c85e40611296d8a240f718d" and "668270429f9b59f71c21daa7f92a28d422503bb7" have entirely different histories.

6 changed files with 51 additions and 427 deletions

View File

@ -12,7 +12,7 @@ async-trait = "0.1.68"
lazy_static = "1.4.0" lazy_static = "1.4.0"
nanoid = "0.4.0" nanoid = "0.4.0"
# TODO: remove unneeded features # TODO: remove unneeded features
rsasl = { version = "2.0.1", default_features = false, features = ["provider_base64", "plain", "config_builder", "scram-sha-1"] } rsasl = { version = "2", default_features = true, features = ["provider_base64", "plain", "config_builder"] }
tokio = { version = "1.28", features = ["full"] } tokio = { version = "1.28", features = ["full"] }
tokio-native-tls = "0.3.1" tokio-native-tls = "0.3.1"
tracing = "0.1.40" tracing = "0.1.40"

View File

@ -1,18 +1,16 @@
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::str; use std::str;
use std::str::FromStr; use std::str::FromStr;
use std::sync::Arc;
use rsasl::config::SASLConfig;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_native_tls::native_tls::TlsConnector; use tokio_native_tls::native_tls::TlsConnector;
// TODO: use rustls // TODO: use rustls
use tokio_native_tls::TlsStream; use tokio_native_tls::TlsStream;
use tracing::{debug, info, instrument, trace}; use tracing::{debug, info, instrument, trace};
use crate::Error;
use crate::Jabber; use crate::Jabber;
use crate::Result; use crate::Result;
use crate::{Error, JID};
pub type Tls = TlsStream<TcpStream>; pub type Tls = TlsStream<TcpStream>;
pub type Unencrypted = TcpStream; pub type Unencrypted = TcpStream;
@ -25,7 +23,6 @@ pub enum Connection {
impl Connection { impl Connection {
#[instrument] #[instrument]
/// stream not started
pub async fn ensure_tls(self) -> Result<Jabber<Tls>> { pub async fn ensure_tls(self) -> Result<Jabber<Tls>> {
match self { match self {
Connection::Encrypted(j) => Ok(j), Connection::Encrypted(j) => Ok(j),
@ -39,20 +36,15 @@ impl Connection {
} }
} }
pub async fn connect_user(jid: impl AsRef<str>, password: String) -> Result<Self> { // pub async fn connect_user<J: TryInto<JID>>(jid: J, password: String) -> Result<Self> {
let jid: JID = JID::from_str(jid.as_ref())?; // let server = jid.domainpart.clone();
let server = jid.domainpart.clone(); // let auth = SASLConfig::with_credentials(None, jid.localpart.clone().unwrap(), password)?;
let auth = SASLConfig::with_credentials(None, jid.localpart.clone().unwrap(), password)?; // println!("auth: {:?}", auth);
println!("auth: {:?}", auth); // Self::connect(&server, jid.try_into()?, Some(auth)).await
Self::connect(&server, Some(jid), Some(auth)).await // }
}
#[instrument] #[instrument]
pub async fn connect( pub async fn connect(server: &str) -> Result<Self> {
server: &str,
jid: Option<JID>,
auth: Option<Arc<SASLConfig>>,
) -> Result<Self> {
info!("connecting to {}", server); info!("connecting to {}", server);
let sockets = Self::get_sockets(&server).await; let sockets = Self::get_sockets(&server).await;
debug!("discovered sockets: {:?}", sockets); debug!("discovered sockets: {:?}", sockets);
@ -65,8 +57,8 @@ impl Connection {
return Ok(Self::Encrypted(Jabber::new( return Ok(Self::Encrypted(Jabber::new(
readhalf, readhalf,
writehalf, writehalf,
jid, None,
auth, None,
server.to_owned(), server.to_owned(),
))); )));
} }
@ -78,8 +70,8 @@ impl Connection {
return Ok(Self::Unencrypted(Jabber::new( return Ok(Self::Unencrypted(Jabber::new(
readhalf, readhalf,
writehalf, writehalf,
jid, None,
auth, None,
server.to_owned(), server.to_owned(),
))); )));
} }
@ -188,12 +180,12 @@ mod tests {
#[test(tokio::test)] #[test(tokio::test)]
async fn connect() { async fn connect() {
Connection::connect("blos.sm", None, None).await.unwrap(); Connection::connect("blos.sm").await.unwrap();
} }
#[test(tokio::test)] #[test(tokio::test)]
async fn test_tls() { async fn test_tls() {
Connection::connect("blos.sm", None, None) Connection::connect("blos.sm")
.await .await
.unwrap() .unwrap()
.ensure_tls() .ensure_tls()

View File

@ -19,8 +19,6 @@ pub enum Error {
IDMismatch, IDMismatch,
BindError, BindError,
ParseError, ParseError,
Negotiation,
TlsRequired,
UnexpectedEnd, UnexpectedEnd,
UnexpectedElement, UnexpectedElement,
UnexpectedText, UnexpectedText,

View File

@ -1,26 +1,26 @@
use std::str; use std::str;
use std::sync::Arc; use std::sync::Arc;
use async_recursion::async_recursion;
use peanuts::element::{FromElement, IntoElement}; use peanuts::element::{FromElement, IntoElement};
use peanuts::{Reader, Writer}; use peanuts::{Reader, Writer};
use rsasl::prelude::{Mechname, SASLClient, SASLConfig}; use rsasl::prelude::SASLConfig;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
use tokio::time::timeout;
use tokio_native_tls::native_tls::TlsConnector; use tokio_native_tls::native_tls::TlsConnector;
use tracing::{debug, info, instrument, trace}; use tracing::{debug, info, trace};
use trust_dns_resolver::proto::rr::domain::IntoLabel; use trust_dns_resolver::proto::rr::domain::IntoLabel;
use crate::connection::{Tls, Unencrypted}; use crate::connection::{Tls, Unencrypted};
use crate::error::Error; use crate::error::Error;
use crate::stanza::sasl::{Auth, Challenge, Mechanisms, Response, ServerResponse};
use crate::stanza::starttls::{Proceed, StartTls}; use crate::stanza::starttls::{Proceed, StartTls};
use crate::stanza::stream::{Feature, Features, Stream}; use crate::stanza::stream::{Features, Stream};
use crate::stanza::XML_VERSION; use crate::stanza::XML_VERSION;
use crate::Result;
use crate::JID; use crate::JID;
use crate::{Connection, Result};
pub struct Jabber<S> { pub struct Jabber<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
reader: Reader<ReadHalf<S>>, reader: Reader<ReadHalf<S>>,
writer: Writer<WriteHalf<S>>, writer: Writer<WriteHalf<S>>,
jid: Option<JID>, jid: Option<JID>,
@ -54,93 +54,9 @@ where
impl<S> Jabber<S> impl<S> Jabber<S>
where where
S: AsyncRead + AsyncWrite + Unpin + Send, S: AsyncRead + AsyncWrite + Unpin + Send,
Jabber<S>: std::fmt::Debug,
{ {
pub async fn sasl( // pub async fn negotiate(self) -> Result<Jabber<S>> {}
&mut self,
mechanisms: Mechanisms,
sasl_config: Arc<SASLConfig>,
) -> Result<()> {
let sasl = SASLClient::new(sasl_config);
let mut offered_mechs: Vec<&Mechname> = Vec::new();
for mechanism in &mechanisms.mechanisms {
offered_mechs.push(Mechname::parse(mechanism.as_bytes())?)
}
debug!("{:?}", offered_mechs);
let mut session = sasl.start_suggested(&offered_mechs)?;
let selected_mechanism = session.get_mechname().as_str().to_owned();
debug!("selected mech: {:?}", selected_mechanism);
let mut data: Option<Vec<u8>> = None;
if !session.are_we_first() {
// if not first mention the mechanism then get challenge data
// mention mechanism
let auth = Auth {
mechanism: selected_mechanism,
sasl_data: "=".to_string(),
};
self.writer.write_full(&auth).await?;
// get challenge data
let challenge: Challenge = self.reader.read().await?;
debug!("challenge: {:?}", challenge);
data = Some((*challenge).as_bytes().to_vec());
debug!("we didn't go first");
} else {
// if first, mention mechanism and send data
let mut sasl_data = Vec::new();
session.step64(None, &mut sasl_data).unwrap();
let auth = Auth {
mechanism: selected_mechanism,
sasl_data: str::from_utf8(&sasl_data)?.to_string(),
};
debug!("{:?}", auth);
self.writer.write_full(&auth).await?;
let server_response: ServerResponse = self.reader.read().await?;
debug!("server_response: {:#?}", server_response);
match server_response {
ServerResponse::Challenge(challenge) => {
data = Some((*challenge).as_bytes().to_vec())
}
ServerResponse::Success(success) => data = Some((*success).as_bytes().to_vec()),
}
debug!("we went first");
}
// stepping the authentication exchange to completion
if data != None {
debug!("data: {:?}", data);
let mut sasl_data = Vec::new();
while {
// decide if need to send more data over
let state = session
.step64(data.as_deref(), &mut sasl_data)
.expect("step errored!");
state.is_running()
} {
// While we aren't finished, receive more data from the other party
let response = Response::new(str::from_utf8(&sasl_data)?.to_string());
debug!("response: {:?}", response);
self.writer.write_full(&response).await?;
let server_response: ServerResponse = self.reader.read().await?;
debug!("server_response: {:#?}", server_response);
match server_response {
ServerResponse::Challenge(challenge) => {
data = Some((*challenge).as_bytes().to_vec())
}
ServerResponse::Success(success) => data = Some((*success).as_bytes().to_vec()),
}
}
}
Ok(())
}
pub async fn bind(&mut self) -> Result<()> {
todo!()
}
#[instrument]
pub async fn start_stream(&mut self) -> Result<()> { pub async fn start_stream(&mut self) -> Result<()> {
// client to server // client to server
@ -158,8 +74,6 @@ where
let decl = self.reader.read_prolog().await?; let decl = self.reader.read_prolog().await?;
// receive stream element and validate // receive stream element and validate
let text = str::from_utf8(self.reader.buffer.data()).unwrap();
debug!("data: {}", text);
let stream: Stream = self.reader.read_start().await?; let stream: Stream = self.reader.read_start().await?;
debug!("got stream: {:?}", stream); debug!("got stream: {:?}", stream);
if let Some(from) = stream.from { if let Some(from) = stream.from {
@ -181,87 +95,6 @@ where
} }
} }
impl Jabber<Unencrypted> {
pub async fn negotiate<S: AsyncRead + AsyncWrite + Unpin>(mut self) -> Result<Jabber<Tls>> {
self.start_stream().await?;
// TODO: timeout
let features = self.get_features().await?.features;
if let Some(Feature::StartTls(_)) = features
.iter()
.find(|feature| matches!(feature, Feature::StartTls(_s)))
{
let jabber = self.starttls().await?;
let jabber = jabber.negotiate().await?;
return Ok(jabber);
} else {
// TODO: better error
return Err(Error::TlsRequired);
}
}
#[async_recursion]
pub async fn negotiate_tls_optional(mut self) -> Result<Connection> {
self.start_stream().await?;
// TODO: timeout
let features = self.get_features().await?.features;
if let Some(Feature::StartTls(_)) = features
.iter()
.find(|feature| matches!(feature, Feature::StartTls(_s)))
{
let jabber = self.starttls().await?;
let jabber = jabber.negotiate().await?;
return Ok(Connection::Encrypted(jabber));
} else if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = (
self.auth.clone(),
features
.iter()
.find(|feature| matches!(feature, Feature::Sasl(_))),
) {
self.sasl(mechanisms.clone(), sasl_config).await?;
let jabber = self.negotiate_tls_optional().await?;
Ok(jabber)
} else if let Some(Feature::Bind) = features
.iter()
.find(|feature| matches!(feature, Feature::Bind))
{
self.bind().await?;
Ok(Connection::Unencrypted(self))
} else {
// TODO: better error
return Err(Error::Negotiation);
}
}
}
impl Jabber<Tls> {
#[async_recursion]
pub async fn negotiate(mut self) -> Result<Jabber<Tls>> {
self.start_stream().await?;
let features = self.get_features().await?.features;
if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = (
self.auth.clone(),
features
.iter()
.find(|feature| matches!(feature, Feature::Sasl(_))),
) {
// TODO: avoid clone
self.sasl(mechanisms.clone(), sasl_config).await?;
let jabber = self.negotiate().await?;
Ok(jabber)
} else if let Some(Feature::Bind) = features
.iter()
.find(|feature| matches!(feature, Feature::Bind))
{
self.bind().await?;
Ok(self)
} else {
// TODO: better error
return Err(Error::Negotiation);
}
}
}
impl Jabber<Unencrypted> { impl Jabber<Unencrypted> {
pub async fn starttls(mut self) -> Result<Jabber<Tls>> { pub async fn starttls(mut self) -> Result<Jabber<Tls>> {
self.writer self.writer
@ -276,13 +109,14 @@ impl Jabber<Unencrypted> {
.await .await
{ {
let (read, write) = tokio::io::split(tlsstream); let (read, write) = tokio::io::split(tlsstream);
let client = Jabber::new( let mut client = Jabber::new(
read, read,
write, write,
self.jid.to_owned(), self.jid.to_owned(),
self.auth.to_owned(), self.auth.to_owned(),
self.server.to_owned(), self.server.to_owned(),
); );
client.start_stream().await?;
return Ok(client); return Ok(client);
} else { } else {
return Err(Error::Connection); return Err(Error::Connection);
@ -320,47 +154,10 @@ mod tests {
#[test(tokio::test)] #[test(tokio::test)]
async fn start_stream() { async fn start_stream() {
let connection = Connection::connect("blos.sm", None, None).await.unwrap(); let connection = Connection::connect("blos.sm").await.unwrap();
match connection { match connection {
Connection::Encrypted(mut c) => c.start_stream().await.unwrap(), Connection::Encrypted(mut c) => c.start_stream().await.unwrap(),
Connection::Unencrypted(mut c) => c.start_stream().await.unwrap(), Connection::Unencrypted(mut c) => c.start_stream().await.unwrap(),
} }
} }
#[test(tokio::test)]
async fn sasl() {
let mut jabber = Connection::connect_user("test@blos.sm", "slayed".to_string())
.await
.unwrap()
.ensure_tls()
.await
.unwrap();
let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
println!("data: {}", text);
jabber.start_stream().await.unwrap();
let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
println!("data: {}", text);
jabber.reader.read_buf().await.unwrap();
let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
println!("data: {}", text);
let features = jabber.get_features().await.unwrap();
let (sasl_config, feature) = (
jabber.auth.clone().unwrap(),
features
.features
.iter()
.find(|feature| matches!(feature, Feature::Sasl(_)))
.unwrap(),
);
match feature {
Feature::StartTls(_start_tls) => todo!(),
Feature::Sasl(mechanisms) => {
jabber.sasl(mechanisms.clone(), sasl_config).await.unwrap();
}
Feature::Bind => todo!(),
Feature::Unknown => todo!(),
}
}
} }

View File

@ -1,170 +1 @@
use std::ops::Deref;
use peanuts::{
element::{FromElement, IntoElement},
DeserializeError, Element,
};
use tracing::debug;
pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
#[derive(Debug, Clone)]
pub struct Mechanisms {
pub mechanisms: Vec<String>,
}
impl FromElement for Mechanisms {
fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> {
element.check_name("mechanisms")?;
element.check_namespace(XMLNS)?;
debug!("getting mechanisms");
let mechanisms: Vec<Mechanism> = element.pop_children()?;
debug!("gottting mechanisms");
let mechanisms = mechanisms
.into_iter()
.map(|Mechanism(mechanism)| mechanism)
.collect();
debug!("gottting mechanisms");
Ok(Mechanisms { mechanisms })
}
}
impl IntoElement for Mechanisms {
fn builder(&self) -> peanuts::element::ElementBuilder {
Element::builder("mechanisms", Some(XMLNS)).push_children(
self.mechanisms
.iter()
.map(|mechanism| Mechanism(mechanism.to_string()))
.collect(),
)
}
}
pub struct Mechanism(String);
impl FromElement for Mechanism {
fn from_element(mut element: peanuts::Element) -> peanuts::element::DeserializeResult<Self> {
element.check_name("mechanism")?;
element.check_namespace(XMLNS)?;
let mechanism = element.pop_value()?;
Ok(Mechanism(mechanism))
}
}
impl IntoElement for Mechanism {
fn builder(&self) -> peanuts::element::ElementBuilder {
Element::builder("mechanism", Some(XMLNS)).push_text(self.0.clone())
}
}
impl Deref for Mechanism {
type Target = str;
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[derive(Debug)]
pub struct Auth {
pub mechanism: String,
pub sasl_data: String,
}
impl IntoElement for Auth {
fn builder(&self) -> peanuts::element::ElementBuilder {
Element::builder("auth", Some(XMLNS))
.push_attribute("mechanism", self.mechanism.clone())
.push_text(self.sasl_data.clone())
}
}
#[derive(Debug)]
pub struct Challenge(String);
impl Deref for Challenge {
type Target = str;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl FromElement for Challenge {
fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> {
element.check_name("challenge")?;
element.check_namespace(XMLNS)?;
let sasl_data = element.value()?;
Ok(Challenge(sasl_data))
}
}
#[derive(Debug)]
pub struct Success(String);
impl Deref for Success {
type Target = str;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl FromElement for Success {
fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> {
element.check_name("success")?;
element.check_namespace(XMLNS)?;
let sasl_data = element.value()?;
Ok(Success(sasl_data))
}
}
#[derive(Debug)]
pub enum ServerResponse {
Challenge(Challenge),
Success(Success),
}
impl FromElement for ServerResponse {
fn from_element(element: Element) -> peanuts::element::DeserializeResult<Self> {
match element.identify() {
(Some(XMLNS), "challenge") => {
Ok(ServerResponse::Challenge(Challenge::from_element(element)?))
}
(Some(XMLNS), "success") => {
Ok(ServerResponse::Success(Success::from_element(element)?))
}
_ => Err(DeserializeError::UnexpectedElement(element)),
}
}
}
#[derive(Debug)]
pub struct Response(String);
impl Response {
pub fn new(response: String) -> Self {
Self(response)
}
}
impl Deref for Response {
type Target = str;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl IntoElement for Response {
fn builder(&self) -> peanuts::element::ElementBuilder {
Element::builder("reponse", Some(XMLNS)).push_text(self.0.clone())
}
}

View File

@ -3,11 +3,9 @@ use std::collections::{HashMap, HashSet};
use peanuts::element::{Content, ElementBuilder, FromElement, IntoElement, NamespaceDeclaration}; use peanuts::element::{Content, ElementBuilder, FromElement, IntoElement, NamespaceDeclaration};
use peanuts::XML_NS; use peanuts::XML_NS;
use peanuts::{element::Name, Element}; use peanuts::{element::Name, Element};
use tracing::debug;
use crate::{Error, JID}; use crate::{Error, JID};
use super::sasl::{self, Mechanisms};
use super::starttls::{self, StartTls}; use super::starttls::{self, StartTls};
pub const XMLNS: &str = "http://etherx.jabber.org/streams"; pub const XMLNS: &str = "http://etherx.jabber.org/streams";
@ -94,12 +92,32 @@ impl<'s> Stream {
#[derive(Debug)] #[derive(Debug)]
pub struct Features { pub struct Features {
pub features: Vec<Feature>, features: Vec<Feature>,
} }
impl IntoElement for Features { impl IntoElement for Features {
fn builder(&self) -> ElementBuilder { fn builder(&self) -> ElementBuilder {
Element::builder("features", Some(XMLNS)).push_children(self.features.clone()) Element::builder("features", Some(XMLNS)).push_children(self.features.clone())
// let mut content = Vec::new();
// for feature in &self.features {
// match feature {
// Feature::StartTls(start_tls) => {
// content.push(Content::Element(start_tls.into_element()))
// }
// Feature::Sasl => {}
// Feature::Bind => {}
// Feature::Unknown => {}
// }
// }
// Element {
// name: Name {
// namespace: Some(XMLNS.to_string()),
// local_name: "features".to_string(),
// },
// namespace_declaration_overrides: HashSet::new(),
// attributes: HashMap::new(),
// content,
// }
} }
} }
@ -110,9 +128,7 @@ impl FromElement for Features {
element.check_namespace(XMLNS)?; element.check_namespace(XMLNS)?;
element.check_name("features")?; element.check_name("features")?;
debug!("got features stanza");
let features = element.children()?; let features = element.children()?;
debug!("got features period");
Ok(Features { features }) Ok(Features { features })
} }
@ -121,7 +137,7 @@ impl FromElement for Features {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum Feature { pub enum Feature {
StartTls(StartTls), StartTls(StartTls),
Sasl(Mechanisms), Sasl,
Bind, Bind,
Unknown, Unknown,
} }
@ -130,7 +146,7 @@ impl IntoElement for Feature {
fn builder(&self) -> ElementBuilder { fn builder(&self) -> ElementBuilder {
match self { match self {
Feature::StartTls(start_tls) => start_tls.builder(), Feature::StartTls(start_tls) => start_tls.builder(),
Feature::Sasl(mechanisms) => mechanisms.builder(), Feature::Sasl => todo!(),
Feature::Bind => todo!(), Feature::Bind => todo!(),
Feature::Unknown => todo!(), Feature::Unknown => todo!(),
} }
@ -139,21 +155,11 @@ impl IntoElement for Feature {
impl FromElement for Feature { impl FromElement for Feature {
fn from_element(element: Element) -> peanuts::element::DeserializeResult<Self> { fn from_element(element: Element) -> peanuts::element::DeserializeResult<Self> {
let identity = element.identify();
debug!("identity: {:?}", identity);
match element.identify() { match element.identify() {
(Some(starttls::XMLNS), "starttls") => { (Some(starttls::XMLNS), "starttls") => {
debug!("identified starttls");
Ok(Feature::StartTls(StartTls::from_element(element)?)) Ok(Feature::StartTls(StartTls::from_element(element)?))
} }
(Some(sasl::XMLNS), "mechanisms") => { _ => Ok(Feature::Unknown),
debug!("identified mechanisms");
Ok(Feature::Sasl(Mechanisms::from_element(element)?))
}
_ => {
debug!("identified unknown feature");
Ok(Feature::Unknown)
}
} }
} }
} }