From 6a5e39c60ad74c1cba84daa7c845c8f0237a5d28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?cel=20=F0=9F=8C=B8?= Date: Mon, 19 Jun 2023 19:23:54 +0100 Subject: [PATCH] implement starttls --- Cargo.toml | 5 +- src/client/encrypted.rs | 59 ++++++++++++ src/client/mod.rs | 40 ++++++++ src/client/unencrypted.rs | 135 +++++++++++++++++++++++++++ src/error.rs | 7 ++ src/jabber.rs | 131 ++++++++++++++++++++++++++ src/lib.rs | 187 ++++++-------------------------------- src/stanza/mod.rs | 1 + src/stanza/stream.rs | 36 ++++++++ 9 files changed, 441 insertions(+), 160 deletions(-) create mode 100644 src/client/encrypted.rs create mode 100644 src/client/mod.rs create mode 100644 src/client/unencrypted.rs create mode 100644 src/error.rs create mode 100644 src/jabber.rs create mode 100644 src/stanza/mod.rs create mode 100644 src/stanza/stream.rs diff --git a/Cargo.toml b/Cargo.toml index 12a7f4e..472bbf7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,9 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -quick-xml = { version = "0.29.0", features = ["async-tokio"] } +async-trait = "0.1.68" +quick-xml = { version = "0.29.0", features = ["async-tokio", "serialize"] } +serde = { version = "1.0.164", features = ["derive"] } tokio = { version = "1.28", features = ["full"] } +tokio-native-tls = "0.3.1" trust-dns-resolver = "0.22.0" diff --git a/src/client/encrypted.rs b/src/client/encrypted.rs new file mode 100644 index 0000000..08439b2 --- /dev/null +++ b/src/client/encrypted.rs @@ -0,0 +1,59 @@ +use quick_xml::{ + events::{BytesDecl, BytesStart, Event}, + Reader, Writer, +}; +use tokio::io::{BufReader, ReadHalf, WriteHalf}; +use tokio::net::TcpStream; +use tokio_native_tls::TlsStream; + +use crate::Jabber; +use crate::Result; + +pub struct JabberClient<'j> { + reader: Reader>>>, + writer: Writer>>, + jabber: &'j mut Jabber<'j>, +} + +impl<'j> JabberClient<'j> { + pub fn new( + reader: Reader>>>, + writer: Writer>>, + jabber: &'j mut Jabber<'j>, + ) -> Self { + Self { + reader, + writer, + jabber, + } + } + + pub async fn start_stream(&mut self) -> Result<()> { + let declaration = BytesDecl::new("1.0", None, None); + let mut stream_element = BytesStart::new("stream:stream"); + stream_element.push_attribute(("from".as_bytes(), self.jabber.jid.to_string().as_bytes())); + stream_element.push_attribute(("to".as_bytes(), self.jabber.server.as_bytes())); + stream_element.push_attribute(("version", "1.0")); + stream_element.push_attribute(("xml:lang", "en")); + stream_element.push_attribute(("xmlns", "jabber:client")); + stream_element.push_attribute(("xmlns:stream", "http://etherx.jabber.org/streams")); + self.writer + .write_event_async(Event::Decl(declaration)) + .await; + self.writer + .write_event_async(Event::Start(stream_element)) + .await + .unwrap(); + let mut buf = Vec::new(); + loop { + match self.reader.read_event_into_async(&mut buf).await.unwrap() { + Event::Start(e) => { + println!("{:?}", e); + break; + } + e => println!("decl: {:?}", e), + }; + } + Ok(()) + } +} diff --git a/src/client/mod.rs b/src/client/mod.rs new file mode 100644 index 0000000..fe3dd34 --- /dev/null +++ b/src/client/mod.rs @@ -0,0 +1,40 @@ +pub mod encrypted; +pub mod unencrypted; + +// use async_trait::async_trait; + +use crate::stanza::stream::StreamFeature; +use crate::JabberError; +use crate::Result; + +pub enum JabberClientType<'j> { + Encrypted(encrypted::JabberClient<'j>), + Unencrypted(unencrypted::JabberClient<'j>), +} + +impl<'j> JabberClientType<'j> { + pub async fn ensure_tls(self) -> Result> { + match self { + Self::Encrypted(mut c) => { + c.start_stream(); + Ok(c) + } + Self::Unencrypted(mut c) => { + c.start_stream().await?; + let features = c.get_features().await?; + if features.contains(&StreamFeature::StartTls) { + Ok(c.starttls().await?) + } else { + Err(JabberError::StartTlsUnavailable) + } + } + } + } +} + +// TODO: jabber client trait over both client types +// #[async_trait] +// pub trait JabberTrait { +// async fn start_stream(&mut self) -> Result<()>; +// async fn get_features(&self) -> Result>; +// } diff --git a/src/client/unencrypted.rs b/src/client/unencrypted.rs new file mode 100644 index 0000000..7528b14 --- /dev/null +++ b/src/client/unencrypted.rs @@ -0,0 +1,135 @@ +use std::str; + +use quick_xml::{ + de::Deserializer, + events::{BytesDecl, BytesStart, Event}, + name::QName, + Reader, Writer, +}; +use serde::Deserialize; +use tokio::io::{BufReader, ReadHalf, WriteHalf}; +use tokio::net::TcpStream; +use tokio_native_tls::native_tls::TlsConnector; + +use crate::Result; +use crate::{error::JabberError, stanza::stream::StreamFeature}; +use crate::{stanza::stream::StreamFeatures, Jabber}; + +pub struct JabberClient<'j> { + reader: Reader>>, + writer: Writer>, + jabber: &'j mut Jabber<'j>, +} + +impl<'j> JabberClient<'j> { + pub fn new( + reader: Reader>>, + writer: Writer>, + jabber: &'j mut Jabber<'j>, + ) -> Self { + Self { + reader, + writer, + jabber, + } + } + + pub async fn start_stream(&mut self) -> Result<()> { + let declaration = BytesDecl::new("1.0", None, None); + let mut stream_element = BytesStart::new("stream:stream"); + stream_element.push_attribute(("from".as_bytes(), self.jabber.jid.to_string().as_bytes())); + stream_element.push_attribute(("to".as_bytes(), self.jabber.server.as_bytes())); + stream_element.push_attribute(("version", "1.0")); + stream_element.push_attribute(("xml:lang", "en")); + stream_element.push_attribute(("xmlns", "jabber:client")); + stream_element.push_attribute(("xmlns:stream", "http://etherx.jabber.org/streams")); + self.writer + .write_event_async(Event::Decl(declaration)) + .await; + self.writer + .write_event_async(Event::Start(stream_element)) + .await + .unwrap(); + let mut buf = Vec::new(); + loop { + match self.reader.read_event_into_async(&mut buf).await.unwrap() { + Event::Start(e) => { + println!("{:?}", e); + break; + } + Event::Decl(e) => println!("decl: {:?}", e), + _ => return Err(JabberError::BadStream), + } + } + Ok(()) + } + + pub async fn get_features(&mut self) -> Result> { + let mut buf = Vec::new(); + let mut txt = Vec::new(); + let mut loop_end = false; + while !loop_end { + match self.reader.read_event_into_async(&mut buf).await.unwrap() { + Event::End(e) => { + if e.name() == QName(b"stream:features") { + loop_end = true; + } + } + _ => (), + } + txt.push(b'<'); + txt = txt + .into_iter() + .chain(buf.to_owned()) + .chain(vec![b'>']) + .collect(); + buf.clear(); + } + println!("{:?}", txt); + let decoded = str::from_utf8(&txt).unwrap(); + println!("decoded: {:?}", decoded); + let mut deserializer = Deserializer::from_str(decoded); + // let mut deserializer = Deserializer::from_str(txt); + let features = StreamFeatures::deserialize(&mut deserializer).unwrap(); + println!("{:?}", features); + Ok(features.features) + } + + pub async fn starttls(mut self) -> Result> { + let mut starttls_element = BytesStart::new("starttls"); + starttls_element.push_attribute(("xmlns", "urn:ietf:params:xml:ns:xmpp-tls")); + self.writer + .write_event_async(Event::Empty(starttls_element)) + .await + .unwrap(); + let mut buf = Vec::new(); + match self.reader.read_event_into_async(&mut buf).await.unwrap() { + Event::Empty(e) => match e.name() { + QName(b"proceed") => { + let connector = TlsConnector::new().unwrap(); + let stream = self + .reader + .into_inner() + .into_inner() + .unsplit(self.writer.into_inner()); + if let Ok(tlsstream) = tokio_native_tls::TlsConnector::from(connector) + .connect(&self.jabber.server, stream) + .await + { + let (read, write) = tokio::io::split(tlsstream); + let reader = Reader::from_reader(BufReader::new(read)); + let writer = Writer::new(write); + return Ok(super::encrypted::JabberClient::new( + reader, + writer, + self.jabber, + )); + } + } + QName(_) => return Err(JabberError::TlsNegotiation), + }, + _ => return Err(JabberError::TlsNegotiation), + } + Err(JabberError::TlsNegotiation) + } +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..a632537 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,7 @@ +#[derive(Debug)] +pub enum JabberError { + ConnectionError, + BadStream, + StartTlsUnavailable, + TlsNegotiation, +} diff --git a/src/jabber.rs b/src/jabber.rs new file mode 100644 index 0000000..a1f6272 --- /dev/null +++ b/src/jabber.rs @@ -0,0 +1,131 @@ +use std::marker::PhantomData; +use std::net::{IpAddr, SocketAddr}; +use std::str::FromStr; + +use quick_xml::{Reader, Writer}; +use tokio::io::BufReader; +use tokio::net::TcpStream; +use tokio_native_tls::native_tls::TlsConnector; + +use crate::client; +use crate::client::JabberClientType; +use crate::jid::JID; +use crate::{JabberError, Result}; + +pub struct Jabber<'j> { + pub jid: JID, + pub password: String, + pub server: String, + _marker: PhantomData<&'j ()>, +} + +impl<'j> Jabber<'j> { + pub fn new(jid: JID, password: String) -> Self { + let server = jid.domainpart.clone(); + Self { + jid, + password, + server, + _marker: PhantomData, + } + } + + async fn get_sockets(&self) -> Vec<(SocketAddr, bool)> { + let mut socket_addrs = Vec::new(); + + // if it's a socket/ip then just return that + + // socket + if let Ok(socket_addr) = SocketAddr::from_str(&self.jid.domainpart) { + match socket_addr.port() { + 5223 => socket_addrs.push((socket_addr, true)), + _ => socket_addrs.push((socket_addr, false)), + } + + return socket_addrs; + } + // ip + if let Ok(ip) = IpAddr::from_str(&self.jid.domainpart) { + socket_addrs.push((SocketAddr::new(ip, 5222), false)); + socket_addrs.push((SocketAddr::new(ip, 5223), true)); + return socket_addrs; + } + + // otherwise resolve + if let Ok(resolver) = trust_dns_resolver::AsyncResolver::tokio_from_system_conf() { + if let Ok(lookup) = resolver + .srv_lookup(format!("_xmpp-client._tcp.{}", self.jid.domainpart)) + .await + { + for srv in lookup { + resolver + .lookup_ip(srv.target().to_owned()) + .await + .map(|ips| { + for ip in ips { + socket_addrs.push((SocketAddr::new(ip, srv.port()), false)) + } + }); + } + } + if let Ok(lookup) = resolver + .srv_lookup(format!("_xmpps-client._tcp.{}", self.jid.domainpart)) + .await + { + for srv in lookup { + resolver + .lookup_ip(srv.target().to_owned()) + .await + .map(|ips| { + for ip in ips { + socket_addrs.push((SocketAddr::new(ip, srv.port()), true)) + } + }); + } + } + + // in case cannot connect through SRV records + resolver.lookup_ip(&self.jid.domainpart).await.map(|ips| { + for ip in ips { + socket_addrs.push((SocketAddr::new(ip, 5222), false)); + socket_addrs.push((SocketAddr::new(ip, 5223), true)); + } + }); + } + socket_addrs + } + + pub async fn connect(&'j mut self) -> Result { + for (socket_addr, is_tls) in self.get_sockets().await { + println!("trying {}", socket_addr); + match is_tls { + true => { + let socket = TcpStream::connect(socket_addr).await.unwrap(); + let connector = TlsConnector::new().unwrap(); + if let Ok(stream) = tokio_native_tls::TlsConnector::from(connector) + .connect(&self.server, socket) + .await + { + let (read, write) = tokio::io::split(stream); + let reader = Reader::from_reader(BufReader::new(read)); + let writer = Writer::new(write); + return Ok(JabberClientType::Encrypted( + client::encrypted::JabberClient::new(reader, writer, self), + )); + } + } + false => { + if let Ok(stream) = TcpStream::connect(socket_addr).await { + let (read, write) = tokio::io::split(stream); + let reader = Reader::from_reader(BufReader::new(read)); + let writer = Writer::new(write); + return Ok(JabberClientType::Unencrypted( + client::unencrypted::JabberClient::new(reader, writer, self), + )); + } + } + } + } + Err(JabberError::ConnectionError) + } +} diff --git a/src/lib.rs b/src/lib.rs index 10c7172..7f1433d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,174 +1,43 @@ -// TODO: logging (dropped errors) #![allow(unused_must_use)] -use std::{ - net::{IpAddr, SocketAddr}, - str::FromStr, -}; - -use jid::JID; -use quick_xml::{Reader, Writer}; -use tokio::net::{ - tcp::{OwnedReadHalf, OwnedWriteHalf}, - TcpStream, -}; +// TODO: logging (dropped errors) +pub mod client; +pub mod error; +pub mod jabber; pub mod jid; +pub mod stanza; -pub struct JabberData { - jid: jid::JID, - password: String, -} +pub use client::encrypted::JabberClient; +pub use error::JabberError; +pub use jabber::Jabber; +pub use jid::JID; -impl JabberData { - pub fn new(jid: JID, password: String) -> Self { - Self { jid, password } - } - - async fn get_sockets(&self) -> Vec { - let mut socket_addrs = Vec::new(); - - // if it's a socket/ip then just return that - - // socket - if let Ok(socket_addr) = SocketAddr::from_str(&self.jid.domainpart) { - socket_addrs.push(socket_addr); - return socket_addrs; - } - // ip - if let Ok(ip) = IpAddr::from_str(&self.jid.domainpart) { - socket_addrs.push(SocketAddr::new(ip, 5222)); - socket_addrs.push(SocketAddr::new(ip, 5223)); - return socket_addrs; - } - - // if port specified return name resolutions with specified port - - // otherwise resolve - if let Ok(resolver) = trust_dns_resolver::AsyncResolver::tokio_from_system_conf() { - if let Ok(lookup) = resolver - .srv_lookup(format!("_xmpp-client._tcp.{}", self.jid.domainpart)) - .await - { - for srv in lookup { - resolver - .lookup_ip(srv.target().to_owned()) - .await - .map(|ips| { - for ip in ips { - socket_addrs.push(SocketAddr::new(ip, srv.port())) - } - }); - } - } - if let Ok(lookup) = resolver - .srv_lookup(format!("_xmpps-client._tcp.{}", self.jid.domainpart)) - .await - { - for srv in lookup { - resolver - .lookup_ip(srv.target().to_owned()) - .await - .map(|ips| { - for ip in ips { - socket_addrs.push(SocketAddr::new(ip, srv.port())) - } - }); - } - } - - // in case cannot connect through SRV records - resolver.lookup_ip(&self.jid.domainpart).await.map(|ips| { - for ip in ips { - socket_addrs.push(SocketAddr::new(ip, 5222)); - socket_addrs.push(SocketAddr::new(ip, 5223)); - } - }); - } - - socket_addrs - } -} - -pub struct Jabber { - reader: Reader, - writer: Writer, - data: JabberData, -} - -#[derive(Debug)] -pub enum JabberError { - NotConnected, -} - -impl Jabber { - pub async fn connect(data: JabberData) -> Result { - for socket_addr in data.get_sockets().await { - println!("trying {}", socket_addr); - if let Ok(stream) = TcpStream::connect(socket_addr).await { - println!("connected to {}", socket_addr); - let (read, write) = stream.into_split(); - return Ok(Self { - reader: Reader::from_reader(read), - writer: Writer::new(write), - data, - }); - } - } - Err(JabberError::NotConnected) - } - - async fn reconnect(&mut self) { - for socket_addr in self.data.get_sockets().await { - println!("trying {}", socket_addr); - if let Ok(stream) = TcpStream::connect(socket_addr).await { - println!("connected to {}", socket_addr); - let (read, write) = stream.into_split(); - self.reader = Reader::from_reader(read); - self.writer = Writer::new(write); - return; - } - } - println!("could not connect") - } - - async fn begin_stream(&mut self) -> Result<(), JabberError> { - todo!() - } - - async fn starttls() -> Result<(), JabberError> { - todo!() - } - - async fn directtls() -> Result<(), JabberError> { - todo!() - } - - async fn auth(&mut self) -> Result<(), JabberError> { - todo!() - } - - async fn close(&mut self) {} -} +pub type Result = std::result::Result; #[cfg(test)] mod tests { - use crate::jid::JID; + use std::str::FromStr; - use super::*; + use crate::Jabber; + use crate::JID; - #[tokio::test] - async fn get_sockets() { - let data = JabberData::new(JID::from_str("cel@blos.sm").unwrap(), "password".to_owned()); - println!("{:?}", data.get_sockets().await) - } + // #[tokio::test] + // async fn get_sockets() { + // let jabber = Jabber::new(JID::from_str("cel@blos.sm").unwrap(), "password".to_owned()); + // println!("{:?}", jabber.get_sockets().await) + // } #[tokio::test] async fn connect() { - Jabber::connect(JabberData::new( - JID::from_str("cel@blos.sm").unwrap(), - "password".to_owned(), - )) - .await - .unwrap(); + Jabber::new(JID::from_str("cel@blos.sm").unwrap(), "password".to_owned()) + .connect() + .await + .unwrap() + .ensure_tls() + .await + .unwrap() + .start_stream() + .await + .unwrap(); } } diff --git a/src/stanza/mod.rs b/src/stanza/mod.rs new file mode 100644 index 0000000..baf29e0 --- /dev/null +++ b/src/stanza/mod.rs @@ -0,0 +1 @@ +pub mod stream; diff --git a/src/stanza/stream.rs b/src/stanza/stream.rs new file mode 100644 index 0000000..dde741d --- /dev/null +++ b/src/stanza/stream.rs @@ -0,0 +1,36 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize)] +#[serde(rename = "stream:stream")] +struct Stream { + #[serde(rename = "@from")] + from: Option, + #[serde(rename = "@id")] + id: Option, + #[serde(rename = "@to")] + to: Option, + #[serde(rename = "@version")] + version: Option, + #[serde(rename = "@xml:lang")] + lang: Option, + #[serde(rename = "@xmlns")] + namespace: Option, + #[serde(rename = "@xmlns:stream")] + stream_namespace: Option, +} + +#[derive(Deserialize, Debug)] +#[serde(rename = "stream:features")] +pub struct StreamFeatures { + #[serde(rename = "$value")] + pub features: Vec, +} + +#[derive(Deserialize, PartialEq, Debug)] +pub enum StreamFeature { + #[serde(rename = "starttls")] + StartTls, + // TODO: other stream features + Sasl, + Bind, +}