This commit is contained in:
cel 🌸 2023-07-04 21:27:15 +01:00
parent c0a7116eef
commit 143a0365d0
Signed by: cel
GPG Key ID: 48E29AF13B5F1349
10 changed files with 329 additions and 44 deletions

View File

@ -8,7 +8,9 @@ edition = "2021"
[dependencies]
async-trait = "0.1.68"
quick-xml = { version = "0.29.0", features = ["async-tokio", "serialize"] }
quick-xml = { git = "https://github.com/tafia/quick-xml.git", features = ["async-tokio", "serialize"] }
# TODO: remove unneeded features
rsasl = { version = "2", default_features = false, features = ["provider_base64", "plain", "config_builder"] }
serde = { version = "1.0.164", features = ["derive"] }
tokio = { version = "1.28", features = ["full"] }
tokio-native-tls = "0.3.1"

View File

@ -1,24 +1,35 @@
use std::str;
use quick_xml::{
de::Deserializer,
events::{BytesDecl, BytesStart, Event},
name::QName,
se::Serializer,
Reader, Writer,
};
use tokio::io::{BufReader, ReadHalf, WriteHalf};
use rsasl::prelude::{Mechname, SASLClient};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio_native_tls::TlsStream;
use crate::stanza::{
sasl::{Auth, Challenge, Mechanisms},
stream::{StreamFeature, StreamFeatures},
};
use crate::Jabber;
use crate::Result;
pub struct JabberClient<'j> {
reader: Reader<BufReader<ReadHalf<TlsStream<TcpStream>>>>,
writer: Writer<WriteHalf<TlsStream<TcpStream>>>,
writer: WriteHalf<TlsStream<TcpStream>>,
jabber: &'j mut Jabber<'j>,
}
impl<'j> JabberClient<'j> {
pub fn new(
reader: Reader<BufReader<ReadHalf<TlsStream<TcpStream>>>>,
writer: Writer<WriteHalf<TlsStream<TcpStream>>>,
writer: WriteHalf<TlsStream<TcpStream>>,
jabber: &'j mut Jabber<'j>,
) -> Self {
Self {
@ -37,13 +48,9 @@ impl<'j> JabberClient<'j> {
stream_element.push_attribute(("xml:lang", "en"));
stream_element.push_attribute(("xmlns", "jabber:client"));
stream_element.push_attribute(("xmlns:stream", "http://etherx.jabber.org/streams"));
self.writer
.write_event_async(Event::Decl(declaration))
.await;
self.writer
.write_event_async(Event::Start(stream_element))
.await
.unwrap();
let mut writer = Writer::new(&mut self.writer);
writer.write_event_async(Event::Decl(declaration)).await;
writer.write_event_async(Event::Start(stream_element)).await;
let mut buf = Vec::new();
loop {
match self.reader.read_event_into_async(&mut buf).await.unwrap() {
@ -56,4 +63,166 @@ impl<'j> JabberClient<'j> {
}
Ok(())
}
pub async fn get_node<'a>(&mut self) -> Result<String> {
let mut buf = Vec::new();
let mut txt = Vec::new();
let mut qname_set = false;
let mut qname: Option<Vec<u8>> = None;
loop {
match self.reader.read_event_into_async(&mut buf).await? {
Event::Start(e) => {
if !qname_set {
qname = Some(e.name().into_inner().to_owned());
qname_set = true;
}
txt.push(b'<');
txt = txt
.into_iter()
.chain(buf.to_owned())
.chain(vec![b'>'])
.collect();
}
Event::End(e) => {
let mut end = false;
if e.name() == QName(qname.as_deref().unwrap()) {
end = true;
}
txt.push(b'<');
txt = txt
.into_iter()
.chain(buf.to_owned())
.chain(vec![b'>'])
.collect();
if end {
break;
}
}
Event::Text(_e) => {
txt = txt.into_iter().chain(buf.to_owned()).collect();
}
_ => {
txt.push(b'<');
txt = txt
.into_iter()
.chain(buf.to_owned())
.chain(vec![b'>'])
.collect();
}
}
buf.clear();
}
println!("{:?}", txt);
let decoded = str::from_utf8(&txt)?.to_owned();
println!("{:?}", decoded);
Ok(decoded)
}
pub async fn get_features(&mut self) -> Result<Vec<StreamFeature>> {
let node = self.get_node().await?;
let mut deserializer = Deserializer::from_str(&node);
let features = StreamFeatures::deserialize(&mut deserializer).unwrap();
println!("{:?}", features);
Ok(features.features)
}
pub async fn negotiate(&mut self) -> Result<()> {
loop {
println!("loop");
let features = &self.get_features().await?;
println!("{:?}", features);
match &features[0] {
StreamFeature::Sasl(sasl) => {
println!("{:?}", sasl);
self.sasl(&sasl).await?;
}
StreamFeature::Bind => todo!(),
x => println!("{:?}", x),
}
}
}
pub async fn sasl(&mut self, mechanisms: &Mechanisms) -> Result<()> {
println!("{:?}", mechanisms);
let sasl = SASLClient::new(self.jabber.auth.clone());
let mut offered_mechs: Vec<&Mechname> = Vec::new();
for mechanism in &mechanisms.mechanisms {
offered_mechs.push(Mechname::parse(&mechanism.mechanism.as_bytes())?)
}
println!("{:?}", offered_mechs);
let mut session = sasl.start_suggested(&offered_mechs)?;
let selected_mechanism = session.get_mechname().as_str().to_owned();
println!("selected mech: {:?}", selected_mechanism);
let mut data: Option<Vec<u8>> = None;
if !session.are_we_first() {
// if not first mention the mechanism then get challenge data
// mention mechanism
let auth = Auth {
ns: "urn:ietf:params:xml:ns:xmpp-sasl".to_owned(),
mechanism: selected_mechanism.clone(),
sasl_data: Some("=".to_owned()),
};
let mut buffer = String::new();
let ser = Serializer::new(&mut buffer);
auth.serialize(ser).unwrap();
self.writer.write_all(buffer.as_bytes());
// get challenge data
let node = self.get_node().await?;
let mut deserializer = Deserializer::from_str(&node);
let challenge = Challenge::deserialize(&mut deserializer).unwrap();
println!("challenge: {:?}", challenge);
data = Some(challenge.sasl_data.as_bytes().to_owned());
println!("we didn't go first");
} else {
// if first, mention mechanism and send data
let mut sasl_data = Vec::new();
session.step64(None, &mut sasl_data).unwrap();
let auth = Auth {
ns: "urn:ietf:params:xml:ns:xmpp-sasl".to_owned(),
mechanism: selected_mechanism.clone(),
sasl_data: Some(str::from_utf8(&sasl_data).unwrap().to_owned()),
};
let mut buffer = String::new();
let ser = Serializer::new(&mut buffer);
auth.serialize(ser).unwrap();
println!("node: {:?}", buffer);
self.writer.write_all(buffer.as_bytes()).await;
println!("we went first");
// get challenge data
// TODO: check if needed
// let node = self.get_node().await?;
// println!("node: {:?}", node);
// let mut deserializer = Deserializer::from_str(&node);
// let challenge = Challenge::deserialize(&mut deserializer).unwrap();
// println!("challenge: {:?}", challenge);
// data = Some(challenge.sasl_data.as_bytes().to_owned());
}
// stepping the authentication exchange to completion
let mut sasl_data = Vec::new();
while {
// decide if need to send more data over
let state = session
.step64(data.as_deref(), &mut sasl_data)
.expect("step errored!");
state.is_running()
} {
// While we aren't finished, receive more data from the other party
let auth = Auth {
ns: "urn:ietf:params:xml:ns:xmpp-sasl".to_owned(),
mechanism: selected_mechanism.clone(),
sasl_data: Some(str::from_utf8(&sasl_data).unwrap().to_owned()),
};
let mut buffer = String::new();
let ser = Serializer::new(&mut buffer);
auth.serialize(ser).unwrap();
self.writer.write_all(buffer.as_bytes());
let node = self.get_node().await?;
let mut deserializer = Deserializer::from_str(&node);
let challenge = Challenge::deserialize(&mut deserializer).unwrap();
data = Some(challenge.sasl_data.as_bytes().to_owned());
}
self.start_stream().await?;
Ok(())
}
}

View File

@ -115,14 +115,12 @@ impl<'j> JabberClient<'j> {
.connect(&self.jabber.server, stream)
.await
{
let (read, write) = tokio::io::split(tlsstream);
let (read, writer) = tokio::io::split(tlsstream);
let reader = Reader::from_reader(BufReader::new(read));
let writer = Writer::new(write);
return Ok(super::encrypted::JabberClient::new(
reader,
writer,
self.jabber,
));
let mut client =
super::encrypted::JabberClient::new(reader, writer, self.jabber);
client.start_stream().await?;
return Ok(client);
}
}
QName(_) => return Err(JabberError::TlsNegotiation),

View File

@ -1,7 +1,44 @@
use std::str::Utf8Error;
use rsasl::mechname::MechanismNameError;
#[derive(Debug)]
pub enum JabberError {
ConnectionError,
Connection,
BadStream,
StartTlsUnavailable,
TlsNegotiation,
Utf8Decode,
XML(quick_xml::Error),
SASL(SASLError),
}
#[derive(Debug)]
pub enum SASLError {
SASL(rsasl::prelude::SASLError),
MechanismName(MechanismNameError),
}
impl From<rsasl::prelude::SASLError> for JabberError {
fn from(e: rsasl::prelude::SASLError) -> Self {
Self::SASL(SASLError::SASL(e))
}
}
impl From<MechanismNameError> for JabberError {
fn from(value: MechanismNameError) -> Self {
Self::SASL(SASLError::MechanismName(value))
}
}
impl From<Utf8Error> for JabberError {
fn from(e: Utf8Error) -> Self {
Self::Utf8Decode
}
}
impl From<quick_xml::Error> for JabberError {
fn from(e: quick_xml::Error) -> Self {
Self::XML(e)
}
}

View File

@ -1,33 +1,44 @@
use std::marker::PhantomData;
use std::net::{IpAddr, SocketAddr};
use std::str::FromStr;
use std::sync::Arc;
use quick_xml::{Reader, Writer};
use rsasl::prelude::SASLConfig;
use tokio::io::BufReader;
use tokio::net::TcpStream;
use tokio_native_tls::native_tls::TlsConnector;
use crate::client;
use crate::client::JabberClientType;
use crate::jid::JID;
use crate::{client, JabberClient};
use crate::{JabberError, Result};
pub struct Jabber<'j> {
pub jid: JID,
pub password: String,
pub auth: Arc<SASLConfig>,
pub server: String,
_marker: PhantomData<&'j ()>,
}
impl<'j> Jabber<'j> {
pub fn new(jid: JID, password: String) -> Self {
pub fn new(jid: JID, password: String) -> Result<Self> {
let server = jid.domainpart.clone();
Self {
let auth = SASLConfig::with_credentials(None, jid.as_bare().to_string(), password)?;
println!("auth: {:?}", auth);
Ok(Self {
jid,
password,
auth,
server,
_marker: PhantomData,
}
})
}
pub async fn login(&'j mut self) -> Result<JabberClient<'j>> {
let mut client = self.connect().await?.ensure_tls().await?;
println!("negotiation");
client.negotiate().await?;
Ok(client)
}
async fn get_sockets(&self) -> Vec<(SocketAddr, bool)> {
@ -106,9 +117,8 @@ impl<'j> Jabber<'j> {
.connect(&self.server, socket)
.await
{
let (read, write) = tokio::io::split(stream);
let (read, writer) = tokio::io::split(stream);
let reader = Reader::from_reader(BufReader::new(read));
let writer = Writer::new(write);
return Ok(JabberClientType::Encrypted(
client::encrypted::JabberClient::new(reader, writer, self),
));
@ -126,6 +136,6 @@ impl<'j> Jabber<'j> {
}
}
}
Err(JabberError::ConnectionError)
Err(JabberError::Connection)
}
}

View File

@ -8,8 +8,13 @@ pub struct JID {
pub resourcepart: Option<String>,
}
pub enum JIDError {
NoResourcePart,
ParseError(ParseError),
}
#[derive(Debug)]
pub enum JIDParseError {
pub enum ParseError {
Empty,
Malformed,
}
@ -26,15 +31,31 @@ impl JID {
resourcepart,
}
}
pub fn as_bare(&self) -> Self {
Self {
localpart: self.localpart.clone(),
domainpart: self.domainpart.clone(),
resourcepart: None,
}
}
pub fn as_full(&self) -> Result<&Self, JIDError> {
if let Some(_) = self.resourcepart {
Ok(&self)
} else {
Err(JIDError::NoResourcePart)
}
}
}
impl FromStr for JID {
type Err = JIDParseError;
type Err = ParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let split: Vec<&str> = s.split('@').collect();
match split.len() {
0 => Err(JIDParseError::Empty),
0 => Err(ParseError::Empty),
1 => {
let split: Vec<&str> = split[0].split('/').collect();
match split.len() {
@ -44,7 +65,7 @@ impl FromStr for JID {
split[0].to_string(),
Some(split[1].to_string()),
)),
_ => Err(JIDParseError::Malformed),
_ => Err(ParseError::Malformed),
}
}
2 => {
@ -60,16 +81,16 @@ impl FromStr for JID {
split2[0].to_string(),
Some(split2[1].to_string()),
)),
_ => Err(JIDParseError::Malformed),
_ => Err(ParseError::Malformed),
}
}
_ => Err(JIDParseError::Malformed),
_ => Err(ParseError::Malformed),
}
}
}
impl TryFrom<String> for JID {
type Error = JIDParseError;
type Error = ParseError;
fn try_from(value: String) -> Result<Self, Self::Error> {
value.parse()

View File

@ -27,16 +27,26 @@ mod tests {
// println!("{:?}", jabber.get_sockets().await)
// }
// #[tokio::test]
// async fn connect() {
// Jabber::new(JID::from_str("cel@blos.sm").unwrap(), "password".to_owned())
// .unwrap()
// .connect()
// .await
// .unwrap()
// .ensure_tls()
// .await
// .unwrap()
// .start_stream()
// .await
// .unwrap();
// }
#[tokio::test]
async fn connect() {
Jabber::new(JID::from_str("cel@blos.sm").unwrap(), "password".to_owned())
.connect()
.await
async fn login() {
Jabber::new(JID::from_str("test@blos.sm").unwrap(), "slayed".to_owned())
.unwrap()
.ensure_tls()
.await
.unwrap()
.start_stream()
.login()
.await
.unwrap();
}

View File

@ -1 +1,2 @@
pub mod sasl;
pub mod stream;

32
src/stanza/sasl.rs Normal file
View File

@ -0,0 +1,32 @@
use serde::{Deserialize, Serialize};
#[derive(Deserialize, PartialEq, Debug)]
pub struct Mechanisms {
#[serde(rename = "$value")]
pub mechanisms: Vec<Mechanism>,
}
#[derive(Deserialize, PartialEq, Debug)]
pub struct Mechanism {
#[serde(rename = "$text")]
pub mechanism: String,
}
#[derive(Serialize, Debug)]
#[serde(rename = "auth")]
pub struct Auth {
#[serde(rename = "@xmlns")]
pub ns: String,
#[serde(rename = "@mechanism")]
pub mechanism: String,
#[serde(rename = "$text")]
pub sasl_data: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct Challenge {
#[serde(rename = "@xmlns")]
pub ns: String,
#[serde(rename = "$text")]
pub sasl_data: String,
}

View File

@ -1,5 +1,7 @@
use serde::{Deserialize, Serialize};
use super::sasl::Mechanisms;
#[derive(Serialize, Deserialize)]
#[serde(rename = "stream:stream")]
struct Stream {
@ -31,6 +33,9 @@ pub enum StreamFeature {
#[serde(rename = "starttls")]
StartTls,
// TODO: other stream features
Sasl,
#[serde(rename = "mechanisms")]
Sasl(Mechanisms),
Bind,
#[serde(other)]
Unknown,
}