implement starttls

This commit is contained in:
cel 🌸 2023-06-19 19:23:54 +01:00
parent abc3ffa736
commit 6a5e39c60a
Signed by: cel
GPG Key ID: 48E29AF13B5F1349
9 changed files with 441 additions and 160 deletions

View File

@ -7,6 +7,9 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
quick-xml = { version = "0.29.0", features = ["async-tokio"] } async-trait = "0.1.68"
quick-xml = { version = "0.29.0", features = ["async-tokio", "serialize"] }
serde = { version = "1.0.164", features = ["derive"] }
tokio = { version = "1.28", features = ["full"] } tokio = { version = "1.28", features = ["full"] }
tokio-native-tls = "0.3.1"
trust-dns-resolver = "0.22.0" trust-dns-resolver = "0.22.0"

59
src/client/encrypted.rs Normal file
View File

@ -0,0 +1,59 @@
use quick_xml::{
events::{BytesDecl, BytesStart, Event},
Reader, Writer,
};
use tokio::io::{BufReader, ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio_native_tls::TlsStream;
use crate::Jabber;
use crate::Result;
pub struct JabberClient<'j> {
reader: Reader<BufReader<ReadHalf<TlsStream<TcpStream>>>>,
writer: Writer<WriteHalf<TlsStream<TcpStream>>>,
jabber: &'j mut Jabber<'j>,
}
impl<'j> JabberClient<'j> {
pub fn new(
reader: Reader<BufReader<ReadHalf<TlsStream<TcpStream>>>>,
writer: Writer<WriteHalf<TlsStream<TcpStream>>>,
jabber: &'j mut Jabber<'j>,
) -> Self {
Self {
reader,
writer,
jabber,
}
}
pub async fn start_stream(&mut self) -> Result<()> {
let declaration = BytesDecl::new("1.0", None, None);
let mut stream_element = BytesStart::new("stream:stream");
stream_element.push_attribute(("from".as_bytes(), self.jabber.jid.to_string().as_bytes()));
stream_element.push_attribute(("to".as_bytes(), self.jabber.server.as_bytes()));
stream_element.push_attribute(("version", "1.0"));
stream_element.push_attribute(("xml:lang", "en"));
stream_element.push_attribute(("xmlns", "jabber:client"));
stream_element.push_attribute(("xmlns:stream", "http://etherx.jabber.org/streams"));
self.writer
.write_event_async(Event::Decl(declaration))
.await;
self.writer
.write_event_async(Event::Start(stream_element))
.await
.unwrap();
let mut buf = Vec::new();
loop {
match self.reader.read_event_into_async(&mut buf).await.unwrap() {
Event::Start(e) => {
println!("{:?}", e);
break;
}
e => println!("decl: {:?}", e),
};
}
Ok(())
}
}

40
src/client/mod.rs Normal file
View File

@ -0,0 +1,40 @@
pub mod encrypted;
pub mod unencrypted;
// use async_trait::async_trait;
use crate::stanza::stream::StreamFeature;
use crate::JabberError;
use crate::Result;
pub enum JabberClientType<'j> {
Encrypted(encrypted::JabberClient<'j>),
Unencrypted(unencrypted::JabberClient<'j>),
}
impl<'j> JabberClientType<'j> {
pub async fn ensure_tls(self) -> Result<encrypted::JabberClient<'j>> {
match self {
Self::Encrypted(mut c) => {
c.start_stream();
Ok(c)
}
Self::Unencrypted(mut c) => {
c.start_stream().await?;
let features = c.get_features().await?;
if features.contains(&StreamFeature::StartTls) {
Ok(c.starttls().await?)
} else {
Err(JabberError::StartTlsUnavailable)
}
}
}
}
}
// TODO: jabber client trait over both client types
// #[async_trait]
// pub trait JabberTrait {
// async fn start_stream(&mut self) -> Result<()>;
// async fn get_features(&self) -> Result<Vec<StreamFeatures>>;
// }

135
src/client/unencrypted.rs Normal file
View File

@ -0,0 +1,135 @@
use std::str;
use quick_xml::{
de::Deserializer,
events::{BytesDecl, BytesStart, Event},
name::QName,
Reader, Writer,
};
use serde::Deserialize;
use tokio::io::{BufReader, ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio_native_tls::native_tls::TlsConnector;
use crate::Result;
use crate::{error::JabberError, stanza::stream::StreamFeature};
use crate::{stanza::stream::StreamFeatures, Jabber};
pub struct JabberClient<'j> {
reader: Reader<BufReader<ReadHalf<TcpStream>>>,
writer: Writer<WriteHalf<TcpStream>>,
jabber: &'j mut Jabber<'j>,
}
impl<'j> JabberClient<'j> {
pub fn new(
reader: Reader<BufReader<ReadHalf<TcpStream>>>,
writer: Writer<WriteHalf<TcpStream>>,
jabber: &'j mut Jabber<'j>,
) -> Self {
Self {
reader,
writer,
jabber,
}
}
pub async fn start_stream(&mut self) -> Result<()> {
let declaration = BytesDecl::new("1.0", None, None);
let mut stream_element = BytesStart::new("stream:stream");
stream_element.push_attribute(("from".as_bytes(), self.jabber.jid.to_string().as_bytes()));
stream_element.push_attribute(("to".as_bytes(), self.jabber.server.as_bytes()));
stream_element.push_attribute(("version", "1.0"));
stream_element.push_attribute(("xml:lang", "en"));
stream_element.push_attribute(("xmlns", "jabber:client"));
stream_element.push_attribute(("xmlns:stream", "http://etherx.jabber.org/streams"));
self.writer
.write_event_async(Event::Decl(declaration))
.await;
self.writer
.write_event_async(Event::Start(stream_element))
.await
.unwrap();
let mut buf = Vec::new();
loop {
match self.reader.read_event_into_async(&mut buf).await.unwrap() {
Event::Start(e) => {
println!("{:?}", e);
break;
}
Event::Decl(e) => println!("decl: {:?}", e),
_ => return Err(JabberError::BadStream),
}
}
Ok(())
}
pub async fn get_features(&mut self) -> Result<Vec<StreamFeature>> {
let mut buf = Vec::new();
let mut txt = Vec::new();
let mut loop_end = false;
while !loop_end {
match self.reader.read_event_into_async(&mut buf).await.unwrap() {
Event::End(e) => {
if e.name() == QName(b"stream:features") {
loop_end = true;
}
}
_ => (),
}
txt.push(b'<');
txt = txt
.into_iter()
.chain(buf.to_owned())
.chain(vec![b'>'])
.collect();
buf.clear();
}
println!("{:?}", txt);
let decoded = str::from_utf8(&txt).unwrap();
println!("decoded: {:?}", decoded);
let mut deserializer = Deserializer::from_str(decoded);
// let mut deserializer = Deserializer::from_str(txt);
let features = StreamFeatures::deserialize(&mut deserializer).unwrap();
println!("{:?}", features);
Ok(features.features)
}
pub async fn starttls(mut self) -> Result<super::encrypted::JabberClient<'j>> {
let mut starttls_element = BytesStart::new("starttls");
starttls_element.push_attribute(("xmlns", "urn:ietf:params:xml:ns:xmpp-tls"));
self.writer
.write_event_async(Event::Empty(starttls_element))
.await
.unwrap();
let mut buf = Vec::new();
match self.reader.read_event_into_async(&mut buf).await.unwrap() {
Event::Empty(e) => match e.name() {
QName(b"proceed") => {
let connector = TlsConnector::new().unwrap();
let stream = self
.reader
.into_inner()
.into_inner()
.unsplit(self.writer.into_inner());
if let Ok(tlsstream) = tokio_native_tls::TlsConnector::from(connector)
.connect(&self.jabber.server, stream)
.await
{
let (read, write) = tokio::io::split(tlsstream);
let reader = Reader::from_reader(BufReader::new(read));
let writer = Writer::new(write);
return Ok(super::encrypted::JabberClient::new(
reader,
writer,
self.jabber,
));
}
}
QName(_) => return Err(JabberError::TlsNegotiation),
},
_ => return Err(JabberError::TlsNegotiation),
}
Err(JabberError::TlsNegotiation)
}
}

7
src/error.rs Normal file
View File

@ -0,0 +1,7 @@
#[derive(Debug)]
pub enum JabberError {
ConnectionError,
BadStream,
StartTlsUnavailable,
TlsNegotiation,
}

131
src/jabber.rs Normal file
View File

@ -0,0 +1,131 @@
use std::marker::PhantomData;
use std::net::{IpAddr, SocketAddr};
use std::str::FromStr;
use quick_xml::{Reader, Writer};
use tokio::io::BufReader;
use tokio::net::TcpStream;
use tokio_native_tls::native_tls::TlsConnector;
use crate::client;
use crate::client::JabberClientType;
use crate::jid::JID;
use crate::{JabberError, Result};
pub struct Jabber<'j> {
pub jid: JID,
pub password: String,
pub server: String,
_marker: PhantomData<&'j ()>,
}
impl<'j> Jabber<'j> {
pub fn new(jid: JID, password: String) -> Self {
let server = jid.domainpart.clone();
Self {
jid,
password,
server,
_marker: PhantomData,
}
}
async fn get_sockets(&self) -> Vec<(SocketAddr, bool)> {
let mut socket_addrs = Vec::new();
// if it's a socket/ip then just return that
// socket
if let Ok(socket_addr) = SocketAddr::from_str(&self.jid.domainpart) {
match socket_addr.port() {
5223 => socket_addrs.push((socket_addr, true)),
_ => socket_addrs.push((socket_addr, false)),
}
return socket_addrs;
}
// ip
if let Ok(ip) = IpAddr::from_str(&self.jid.domainpart) {
socket_addrs.push((SocketAddr::new(ip, 5222), false));
socket_addrs.push((SocketAddr::new(ip, 5223), true));
return socket_addrs;
}
// otherwise resolve
if let Ok(resolver) = trust_dns_resolver::AsyncResolver::tokio_from_system_conf() {
if let Ok(lookup) = resolver
.srv_lookup(format!("_xmpp-client._tcp.{}", self.jid.domainpart))
.await
{
for srv in lookup {
resolver
.lookup_ip(srv.target().to_owned())
.await
.map(|ips| {
for ip in ips {
socket_addrs.push((SocketAddr::new(ip, srv.port()), false))
}
});
}
}
if let Ok(lookup) = resolver
.srv_lookup(format!("_xmpps-client._tcp.{}", self.jid.domainpart))
.await
{
for srv in lookup {
resolver
.lookup_ip(srv.target().to_owned())
.await
.map(|ips| {
for ip in ips {
socket_addrs.push((SocketAddr::new(ip, srv.port()), true))
}
});
}
}
// in case cannot connect through SRV records
resolver.lookup_ip(&self.jid.domainpart).await.map(|ips| {
for ip in ips {
socket_addrs.push((SocketAddr::new(ip, 5222), false));
socket_addrs.push((SocketAddr::new(ip, 5223), true));
}
});
}
socket_addrs
}
pub async fn connect(&'j mut self) -> Result<JabberClientType> {
for (socket_addr, is_tls) in self.get_sockets().await {
println!("trying {}", socket_addr);
match is_tls {
true => {
let socket = TcpStream::connect(socket_addr).await.unwrap();
let connector = TlsConnector::new().unwrap();
if let Ok(stream) = tokio_native_tls::TlsConnector::from(connector)
.connect(&self.server, socket)
.await
{
let (read, write) = tokio::io::split(stream);
let reader = Reader::from_reader(BufReader::new(read));
let writer = Writer::new(write);
return Ok(JabberClientType::Encrypted(
client::encrypted::JabberClient::new(reader, writer, self),
));
}
}
false => {
if let Ok(stream) = TcpStream::connect(socket_addr).await {
let (read, write) = tokio::io::split(stream);
let reader = Reader::from_reader(BufReader::new(read));
let writer = Writer::new(write);
return Ok(JabberClientType::Unencrypted(
client::unencrypted::JabberClient::new(reader, writer, self),
));
}
}
}
}
Err(JabberError::ConnectionError)
}
}

View File

@ -1,174 +1,43 @@
// TODO: logging (dropped errors)
#![allow(unused_must_use)] #![allow(unused_must_use)]
use std::{
net::{IpAddr, SocketAddr},
str::FromStr,
};
use jid::JID;
use quick_xml::{Reader, Writer};
use tokio::net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
};
// TODO: logging (dropped errors)
pub mod client;
pub mod error;
pub mod jabber;
pub mod jid; pub mod jid;
pub mod stanza;
pub struct JabberData { pub use client::encrypted::JabberClient;
jid: jid::JID, pub use error::JabberError;
password: String, pub use jabber::Jabber;
} pub use jid::JID;
impl JabberData { pub type Result<T> = std::result::Result<T, JabberError>;
pub fn new(jid: JID, password: String) -> Self {
Self { jid, password }
}
async fn get_sockets(&self) -> Vec<SocketAddr> {
let mut socket_addrs = Vec::new();
// if it's a socket/ip then just return that
// socket
if let Ok(socket_addr) = SocketAddr::from_str(&self.jid.domainpart) {
socket_addrs.push(socket_addr);
return socket_addrs;
}
// ip
if let Ok(ip) = IpAddr::from_str(&self.jid.domainpart) {
socket_addrs.push(SocketAddr::new(ip, 5222));
socket_addrs.push(SocketAddr::new(ip, 5223));
return socket_addrs;
}
// if port specified return name resolutions with specified port
// otherwise resolve
if let Ok(resolver) = trust_dns_resolver::AsyncResolver::tokio_from_system_conf() {
if let Ok(lookup) = resolver
.srv_lookup(format!("_xmpp-client._tcp.{}", self.jid.domainpart))
.await
{
for srv in lookup {
resolver
.lookup_ip(srv.target().to_owned())
.await
.map(|ips| {
for ip in ips {
socket_addrs.push(SocketAddr::new(ip, srv.port()))
}
});
}
}
if let Ok(lookup) = resolver
.srv_lookup(format!("_xmpps-client._tcp.{}", self.jid.domainpart))
.await
{
for srv in lookup {
resolver
.lookup_ip(srv.target().to_owned())
.await
.map(|ips| {
for ip in ips {
socket_addrs.push(SocketAddr::new(ip, srv.port()))
}
});
}
}
// in case cannot connect through SRV records
resolver.lookup_ip(&self.jid.domainpart).await.map(|ips| {
for ip in ips {
socket_addrs.push(SocketAddr::new(ip, 5222));
socket_addrs.push(SocketAddr::new(ip, 5223));
}
});
}
socket_addrs
}
}
pub struct Jabber {
reader: Reader<OwnedReadHalf>,
writer: Writer<OwnedWriteHalf>,
data: JabberData,
}
#[derive(Debug)]
pub enum JabberError {
NotConnected,
}
impl Jabber {
pub async fn connect(data: JabberData) -> Result<Self, JabberError> {
for socket_addr in data.get_sockets().await {
println!("trying {}", socket_addr);
if let Ok(stream) = TcpStream::connect(socket_addr).await {
println!("connected to {}", socket_addr);
let (read, write) = stream.into_split();
return Ok(Self {
reader: Reader::from_reader(read),
writer: Writer::new(write),
data,
});
}
}
Err(JabberError::NotConnected)
}
async fn reconnect(&mut self) {
for socket_addr in self.data.get_sockets().await {
println!("trying {}", socket_addr);
if let Ok(stream) = TcpStream::connect(socket_addr).await {
println!("connected to {}", socket_addr);
let (read, write) = stream.into_split();
self.reader = Reader::from_reader(read);
self.writer = Writer::new(write);
return;
}
}
println!("could not connect")
}
async fn begin_stream(&mut self) -> Result<(), JabberError> {
todo!()
}
async fn starttls() -> Result<(), JabberError> {
todo!()
}
async fn directtls() -> Result<(), JabberError> {
todo!()
}
async fn auth(&mut self) -> Result<(), JabberError> {
todo!()
}
async fn close(&mut self) {}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::jid::JID; use std::str::FromStr;
use super::*; use crate::Jabber;
use crate::JID;
#[tokio::test] // #[tokio::test]
async fn get_sockets() { // async fn get_sockets() {
let data = JabberData::new(JID::from_str("cel@blos.sm").unwrap(), "password".to_owned()); // let jabber = Jabber::new(JID::from_str("cel@blos.sm").unwrap(), "password".to_owned());
println!("{:?}", data.get_sockets().await) // println!("{:?}", jabber.get_sockets().await)
} // }
#[tokio::test] #[tokio::test]
async fn connect() { async fn connect() {
Jabber::connect(JabberData::new( Jabber::new(JID::from_str("cel@blos.sm").unwrap(), "password".to_owned())
JID::from_str("cel@blos.sm").unwrap(), .connect()
"password".to_owned(), .await
)) .unwrap()
.await .ensure_tls()
.unwrap(); .await
.unwrap()
.start_stream()
.await
.unwrap();
} }
} }

1
src/stanza/mod.rs Normal file
View File

@ -0,0 +1 @@
pub mod stream;

36
src/stanza/stream.rs Normal file
View File

@ -0,0 +1,36 @@
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize)]
#[serde(rename = "stream:stream")]
struct Stream {
#[serde(rename = "@from")]
from: Option<String>,
#[serde(rename = "@id")]
id: Option<String>,
#[serde(rename = "@to")]
to: Option<String>,
#[serde(rename = "@version")]
version: Option<f32>,
#[serde(rename = "@xml:lang")]
lang: Option<String>,
#[serde(rename = "@xmlns")]
namespace: Option<String>,
#[serde(rename = "@xmlns:stream")]
stream_namespace: Option<String>,
}
#[derive(Deserialize, Debug)]
#[serde(rename = "stream:features")]
pub struct StreamFeatures {
#[serde(rename = "$value")]
pub features: Vec<StreamFeature>,
}
#[derive(Deserialize, PartialEq, Debug)]
pub enum StreamFeature {
#[serde(rename = "starttls")]
StartTls,
// TODO: other stream features
Sasl,
Bind,
}