Compare commits

...

5 Commits

Author SHA1 Message Date
cel 🌸 6a5e39c60a
implement starttls 2023-06-19 19:23:54 +01:00
cel 🌸 abc3ffa736
refactor jabber client 2023-06-16 17:13:01 +01:00
cel 🌸 bcacf42dec
implement client socket resolution 2023-06-16 14:49:20 +01:00
cel 🌸 e9c742f4a9
make JID struct etc. public 2023-06-16 14:48:19 +01:00
cel 🌸 9cdf4953fe
simplify domainpart 2023-06-13 00:53:11 +01:00
10 changed files with 464 additions and 63 deletions

View File

@ -1,5 +1,5 @@
[package]
name = "lâmpada"
name = "lampada"
authors = ["cel <cel@blos.sm>"]
version = "0.0.1"
edition = "2021"
@ -7,3 +7,9 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
async-trait = "0.1.68"
quick-xml = { version = "0.29.0", features = ["async-tokio", "serialize"] }
serde = { version = "1.0.164", features = ["derive"] }
tokio = { version = "1.28", features = ["full"] }
tokio-native-tls = "0.3.1"
trust-dns-resolver = "0.22.0"

59
src/client/encrypted.rs Normal file
View File

@ -0,0 +1,59 @@
use quick_xml::{
events::{BytesDecl, BytesStart, Event},
Reader, Writer,
};
use tokio::io::{BufReader, ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio_native_tls::TlsStream;
use crate::Jabber;
use crate::Result;
pub struct JabberClient<'j> {
reader: Reader<BufReader<ReadHalf<TlsStream<TcpStream>>>>,
writer: Writer<WriteHalf<TlsStream<TcpStream>>>,
jabber: &'j mut Jabber<'j>,
}
impl<'j> JabberClient<'j> {
pub fn new(
reader: Reader<BufReader<ReadHalf<TlsStream<TcpStream>>>>,
writer: Writer<WriteHalf<TlsStream<TcpStream>>>,
jabber: &'j mut Jabber<'j>,
) -> Self {
Self {
reader,
writer,
jabber,
}
}
pub async fn start_stream(&mut self) -> Result<()> {
let declaration = BytesDecl::new("1.0", None, None);
let mut stream_element = BytesStart::new("stream:stream");
stream_element.push_attribute(("from".as_bytes(), self.jabber.jid.to_string().as_bytes()));
stream_element.push_attribute(("to".as_bytes(), self.jabber.server.as_bytes()));
stream_element.push_attribute(("version", "1.0"));
stream_element.push_attribute(("xml:lang", "en"));
stream_element.push_attribute(("xmlns", "jabber:client"));
stream_element.push_attribute(("xmlns:stream", "http://etherx.jabber.org/streams"));
self.writer
.write_event_async(Event::Decl(declaration))
.await;
self.writer
.write_event_async(Event::Start(stream_element))
.await
.unwrap();
let mut buf = Vec::new();
loop {
match self.reader.read_event_into_async(&mut buf).await.unwrap() {
Event::Start(e) => {
println!("{:?}", e);
break;
}
e => println!("decl: {:?}", e),
};
}
Ok(())
}
}

40
src/client/mod.rs Normal file
View File

@ -0,0 +1,40 @@
pub mod encrypted;
pub mod unencrypted;
// use async_trait::async_trait;
use crate::stanza::stream::StreamFeature;
use crate::JabberError;
use crate::Result;
pub enum JabberClientType<'j> {
Encrypted(encrypted::JabberClient<'j>),
Unencrypted(unencrypted::JabberClient<'j>),
}
impl<'j> JabberClientType<'j> {
pub async fn ensure_tls(self) -> Result<encrypted::JabberClient<'j>> {
match self {
Self::Encrypted(mut c) => {
c.start_stream();
Ok(c)
}
Self::Unencrypted(mut c) => {
c.start_stream().await?;
let features = c.get_features().await?;
if features.contains(&StreamFeature::StartTls) {
Ok(c.starttls().await?)
} else {
Err(JabberError::StartTlsUnavailable)
}
}
}
}
}
// TODO: jabber client trait over both client types
// #[async_trait]
// pub trait JabberTrait {
// async fn start_stream(&mut self) -> Result<()>;
// async fn get_features(&self) -> Result<Vec<StreamFeatures>>;
// }

135
src/client/unencrypted.rs Normal file
View File

@ -0,0 +1,135 @@
use std::str;
use quick_xml::{
de::Deserializer,
events::{BytesDecl, BytesStart, Event},
name::QName,
Reader, Writer,
};
use serde::Deserialize;
use tokio::io::{BufReader, ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio_native_tls::native_tls::TlsConnector;
use crate::Result;
use crate::{error::JabberError, stanza::stream::StreamFeature};
use crate::{stanza::stream::StreamFeatures, Jabber};
pub struct JabberClient<'j> {
reader: Reader<BufReader<ReadHalf<TcpStream>>>,
writer: Writer<WriteHalf<TcpStream>>,
jabber: &'j mut Jabber<'j>,
}
impl<'j> JabberClient<'j> {
pub fn new(
reader: Reader<BufReader<ReadHalf<TcpStream>>>,
writer: Writer<WriteHalf<TcpStream>>,
jabber: &'j mut Jabber<'j>,
) -> Self {
Self {
reader,
writer,
jabber,
}
}
pub async fn start_stream(&mut self) -> Result<()> {
let declaration = BytesDecl::new("1.0", None, None);
let mut stream_element = BytesStart::new("stream:stream");
stream_element.push_attribute(("from".as_bytes(), self.jabber.jid.to_string().as_bytes()));
stream_element.push_attribute(("to".as_bytes(), self.jabber.server.as_bytes()));
stream_element.push_attribute(("version", "1.0"));
stream_element.push_attribute(("xml:lang", "en"));
stream_element.push_attribute(("xmlns", "jabber:client"));
stream_element.push_attribute(("xmlns:stream", "http://etherx.jabber.org/streams"));
self.writer
.write_event_async(Event::Decl(declaration))
.await;
self.writer
.write_event_async(Event::Start(stream_element))
.await
.unwrap();
let mut buf = Vec::new();
loop {
match self.reader.read_event_into_async(&mut buf).await.unwrap() {
Event::Start(e) => {
println!("{:?}", e);
break;
}
Event::Decl(e) => println!("decl: {:?}", e),
_ => return Err(JabberError::BadStream),
}
}
Ok(())
}
pub async fn get_features(&mut self) -> Result<Vec<StreamFeature>> {
let mut buf = Vec::new();
let mut txt = Vec::new();
let mut loop_end = false;
while !loop_end {
match self.reader.read_event_into_async(&mut buf).await.unwrap() {
Event::End(e) => {
if e.name() == QName(b"stream:features") {
loop_end = true;
}
}
_ => (),
}
txt.push(b'<');
txt = txt
.into_iter()
.chain(buf.to_owned())
.chain(vec![b'>'])
.collect();
buf.clear();
}
println!("{:?}", txt);
let decoded = str::from_utf8(&txt).unwrap();
println!("decoded: {:?}", decoded);
let mut deserializer = Deserializer::from_str(decoded);
// let mut deserializer = Deserializer::from_str(txt);
let features = StreamFeatures::deserialize(&mut deserializer).unwrap();
println!("{:?}", features);
Ok(features.features)
}
pub async fn starttls(mut self) -> Result<super::encrypted::JabberClient<'j>> {
let mut starttls_element = BytesStart::new("starttls");
starttls_element.push_attribute(("xmlns", "urn:ietf:params:xml:ns:xmpp-tls"));
self.writer
.write_event_async(Event::Empty(starttls_element))
.await
.unwrap();
let mut buf = Vec::new();
match self.reader.read_event_into_async(&mut buf).await.unwrap() {
Event::Empty(e) => match e.name() {
QName(b"proceed") => {
let connector = TlsConnector::new().unwrap();
let stream = self
.reader
.into_inner()
.into_inner()
.unsplit(self.writer.into_inner());
if let Ok(tlsstream) = tokio_native_tls::TlsConnector::from(connector)
.connect(&self.jabber.server, stream)
.await
{
let (read, write) = tokio::io::split(tlsstream);
let reader = Reader::from_reader(BufReader::new(read));
let writer = Writer::new(write);
return Ok(super::encrypted::JabberClient::new(
reader,
writer,
self.jabber,
));
}
}
QName(_) => return Err(JabberError::TlsNegotiation),
},
_ => return Err(JabberError::TlsNegotiation),
}
Err(JabberError::TlsNegotiation)
}
}

7
src/error.rs Normal file
View File

@ -0,0 +1,7 @@
#[derive(Debug)]
pub enum JabberError {
ConnectionError,
BadStream,
StartTlsUnavailable,
TlsNegotiation,
}

131
src/jabber.rs Normal file
View File

@ -0,0 +1,131 @@
use std::marker::PhantomData;
use std::net::{IpAddr, SocketAddr};
use std::str::FromStr;
use quick_xml::{Reader, Writer};
use tokio::io::BufReader;
use tokio::net::TcpStream;
use tokio_native_tls::native_tls::TlsConnector;
use crate::client;
use crate::client::JabberClientType;
use crate::jid::JID;
use crate::{JabberError, Result};
pub struct Jabber<'j> {
pub jid: JID,
pub password: String,
pub server: String,
_marker: PhantomData<&'j ()>,
}
impl<'j> Jabber<'j> {
pub fn new(jid: JID, password: String) -> Self {
let server = jid.domainpart.clone();
Self {
jid,
password,
server,
_marker: PhantomData,
}
}
async fn get_sockets(&self) -> Vec<(SocketAddr, bool)> {
let mut socket_addrs = Vec::new();
// if it's a socket/ip then just return that
// socket
if let Ok(socket_addr) = SocketAddr::from_str(&self.jid.domainpart) {
match socket_addr.port() {
5223 => socket_addrs.push((socket_addr, true)),
_ => socket_addrs.push((socket_addr, false)),
}
return socket_addrs;
}
// ip
if let Ok(ip) = IpAddr::from_str(&self.jid.domainpart) {
socket_addrs.push((SocketAddr::new(ip, 5222), false));
socket_addrs.push((SocketAddr::new(ip, 5223), true));
return socket_addrs;
}
// otherwise resolve
if let Ok(resolver) = trust_dns_resolver::AsyncResolver::tokio_from_system_conf() {
if let Ok(lookup) = resolver
.srv_lookup(format!("_xmpp-client._tcp.{}", self.jid.domainpart))
.await
{
for srv in lookup {
resolver
.lookup_ip(srv.target().to_owned())
.await
.map(|ips| {
for ip in ips {
socket_addrs.push((SocketAddr::new(ip, srv.port()), false))
}
});
}
}
if let Ok(lookup) = resolver
.srv_lookup(format!("_xmpps-client._tcp.{}", self.jid.domainpart))
.await
{
for srv in lookup {
resolver
.lookup_ip(srv.target().to_owned())
.await
.map(|ips| {
for ip in ips {
socket_addrs.push((SocketAddr::new(ip, srv.port()), true))
}
});
}
}
// in case cannot connect through SRV records
resolver.lookup_ip(&self.jid.domainpart).await.map(|ips| {
for ip in ips {
socket_addrs.push((SocketAddr::new(ip, 5222), false));
socket_addrs.push((SocketAddr::new(ip, 5223), true));
}
});
}
socket_addrs
}
pub async fn connect(&'j mut self) -> Result<JabberClientType> {
for (socket_addr, is_tls) in self.get_sockets().await {
println!("trying {}", socket_addr);
match is_tls {
true => {
let socket = TcpStream::connect(socket_addr).await.unwrap();
let connector = TlsConnector::new().unwrap();
if let Ok(stream) = tokio_native_tls::TlsConnector::from(connector)
.connect(&self.server, socket)
.await
{
let (read, write) = tokio::io::split(stream);
let reader = Reader::from_reader(BufReader::new(read));
let writer = Writer::new(write);
return Ok(JabberClientType::Encrypted(
client::encrypted::JabberClient::new(reader, writer, self),
));
}
}
false => {
if let Ok(stream) = TcpStream::connect(socket_addr).await {
let (read, write) = tokio::io::split(stream);
let reader = Reader::from_reader(BufReader::new(read));
let writer = Writer::new(write);
return Ok(JabberClientType::Unencrypted(
client::unencrypted::JabberClient::new(reader, writer, self),
));
}
}
}
}
Err(JabberError::ConnectionError)
}
}

View File

@ -1,68 +1,31 @@
use std::{
net::{Ipv4Addr, Ipv6Addr},
str::FromStr,
};
use std::str::FromStr;
#[derive(PartialEq, Debug)]
struct JID {
pub struct JID {
// TODO: validate localpart (length, char]
localpart: Option<String>,
domainpart: Domainpart,
resourcepart: Option<String>,
}
#[derive(PartialEq, Debug)]
enum Domainpart {
IPLiteral(Ipv6Addr),
IPv4Address(Ipv4Addr),
// TODO: domain name type, not string
IFQDN(String),
}
impl FromStr for Domainpart {
type Err = DomainpartParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.parse::<Ipv6Addr>() {
Ok(ip) => Ok(Domainpart::IPLiteral(ip)),
Err(_) => match s.parse::<Ipv4Addr>() {
Ok(ip) => Ok(Domainpart::IPv4Address(ip)),
Err(_) => Ok(Domainpart::IFQDN(s.to_owned())),
},
}
}
}
impl TryFrom<String> for Domainpart {
type Error = DomainpartParseError;
fn try_from(value: String) -> Result<Self, Self::Error> {
value.parse()
}
pub localpart: Option<String>,
pub domainpart: String,
pub resourcepart: Option<String>,
}
#[derive(Debug)]
enum DomainpartParseError {}
#[derive(Debug)]
enum JIDParseError {
pub enum JIDParseError {
Empty,
Domainpart(DomainpartParseError),
Malformed,
}
impl JID {
fn new(localpart: Option<String>, domainpart: String, resourcepart: Option<String>) -> Self {
pub fn new(
localpart: Option<String>,
domainpart: String,
resourcepart: Option<String>,
) -> Self {
Self {
localpart,
domainpart: domainpart.parse().unwrap(),
resourcepart,
}
}
fn validate(&self) -> bool {
todo!()
}
}
impl FromStr for JID {
@ -119,11 +82,7 @@ impl std::fmt::Display for JID {
f,
"{}{}{}",
self.localpart.clone().map(|l| l + "@").unwrap_or_default(),
match &self.domainpart {
Domainpart::IPLiteral(addr) => addr.to_string(),
Domainpart::IPv4Address(addr) => addr.to_string(),
Domainpart::IFQDN(domain) => domain.to_owned(),
},
self.domainpart,
self.resourcepart
.clone()
.map(|r| "/".to_owned() + &r)

View File

@ -1,16 +1,43 @@
mod jid;
#![allow(unused_must_use)]
pub fn add(left: usize, right: usize) -> usize {
left + right
}
// TODO: logging (dropped errors)
pub mod client;
pub mod error;
pub mod jabber;
pub mod jid;
pub mod stanza;
pub use client::encrypted::JabberClient;
pub use error::JabberError;
pub use jabber::Jabber;
pub use jid::JID;
pub type Result<T> = std::result::Result<T, JabberError>;
#[cfg(test)]
mod tests {
use super::*;
use std::str::FromStr;
#[test]
fn it_works() {
let result = add(2, 2);
assert_eq!(result, 4);
use crate::Jabber;
use crate::JID;
// #[tokio::test]
// async fn get_sockets() {
// let jabber = Jabber::new(JID::from_str("cel@blos.sm").unwrap(), "password".to_owned());
// println!("{:?}", jabber.get_sockets().await)
// }
#[tokio::test]
async fn connect() {
Jabber::new(JID::from_str("cel@blos.sm").unwrap(), "password".to_owned())
.connect()
.await
.unwrap()
.ensure_tls()
.await
.unwrap()
.start_stream()
.await
.unwrap();
}
}

1
src/stanza/mod.rs Normal file
View File

@ -0,0 +1 @@
pub mod stream;

36
src/stanza/stream.rs Normal file
View File

@ -0,0 +1,36 @@
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize)]
#[serde(rename = "stream:stream")]
struct Stream {
#[serde(rename = "@from")]
from: Option<String>,
#[serde(rename = "@id")]
id: Option<String>,
#[serde(rename = "@to")]
to: Option<String>,
#[serde(rename = "@version")]
version: Option<f32>,
#[serde(rename = "@xml:lang")]
lang: Option<String>,
#[serde(rename = "@xmlns")]
namespace: Option<String>,
#[serde(rename = "@xmlns:stream")]
stream_namespace: Option<String>,
}
#[derive(Deserialize, Debug)]
#[serde(rename = "stream:features")]
pub struct StreamFeatures {
#[serde(rename = "$value")]
pub features: Vec<StreamFeature>,
}
#[derive(Deserialize, PartialEq, Debug)]
pub enum StreamFeature {
#[serde(rename = "starttls")]
StartTls,
// TODO: other stream features
Sasl,
Bind,
}