implement starttls

This commit is contained in:
cel 🌸 2024-11-24 02:04:45 +00:00
parent 40024d2dad
commit 35f164cdb6
5 changed files with 290 additions and 48 deletions

View File

@ -27,8 +27,11 @@ impl Connection {
match self {
Connection::Encrypted(j) => Ok(j),
Connection::Unencrypted(mut j) => {
j.start_stream().await?;
info!("upgrading connection to tls");
Ok(j.starttls().await?)
j.get_features().await?;
let j = j.starttls().await?;
Ok(j)
}
}
}
@ -179,4 +182,14 @@ mod tests {
async fn connect() {
Connection::connect("blos.sm").await.unwrap();
}
#[test(tokio::test)]
async fn test_tls() {
Connection::connect("blos.sm")
.await
.unwrap()
.ensure_tls()
.await
.unwrap();
}
}

View File

@ -1,14 +1,18 @@
use std::str;
use std::sync::Arc;
use peanuts::element::{FromElement, IntoElement};
use peanuts::{Reader, Writer};
use rsasl::prelude::SASLConfig;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
use tokio_native_tls::native_tls::TlsConnector;
use tracing::{debug, info, trace};
use trust_dns_resolver::proto::rr::domain::IntoLabel;
use crate::connection::{Tls, Unencrypted};
use crate::error::Error;
use crate::stanza::stream::Stream;
use crate::stanza::starttls::{Proceed, StartTls};
use crate::stanza::stream::{Features, Stream};
use crate::stanza::XML_VERSION;
use crate::Result;
use crate::JID;
@ -62,7 +66,6 @@ where
// opening stream element
let server = self.server.clone().try_into()?;
let stream = Stream::new_client(None, server, None, "en".to_string());
// TODO: nicer function to serialize to xml writer
self.writer.write_start(&stream).await?;
// server to client
@ -72,57 +75,53 @@ where
// receive stream element and validate
let stream: Stream = self.reader.read_start().await?;
debug!("got stream: {:?}", stream);
if let Some(from) = stream.from {
self.server = from.to_string()
}
Ok(())
}
pub async fn get_features(&mut self) -> Result<Features> {
debug!("getting features");
let features: Features = self.reader.read().await?;
debug!("got features: {:?}", features);
Ok(features)
}
pub fn into_inner(self) -> S {
self.reader.into_inner().unsplit(self.writer.into_inner())
}
}
// pub async fn get_features(&mut self) -> Result<Vec<StreamFeature>> {
// Element::read(&mut self.reader).await?.try_into()
// }
impl Jabber<Unencrypted> {
pub async fn starttls(&mut self) -> Result<Jabber<Tls>> {
todo!()
pub async fn starttls(mut self) -> Result<Jabber<Tls>> {
self.writer
.write_full(&StartTls { required: false })
.await?;
let proceed: Proceed = self.reader.read().await?;
debug!("got proceed: {:?}", proceed);
let connector = TlsConnector::new().unwrap();
let stream = self.reader.into_inner().unsplit(self.writer.into_inner());
if let Ok(tlsstream) = tokio_native_tls::TlsConnector::from(connector)
.connect(&self.server, stream)
.await
{
let (read, write) = tokio::io::split(tlsstream);
let mut client = Jabber::new(
read,
write,
self.jid.to_owned(),
self.auth.to_owned(),
self.server.to_owned(),
);
client.start_stream().await?;
return Ok(client);
} else {
return Err(Error::Connection);
}
}
// let mut starttls_element = BytesStart::new("starttls");
// starttls_element.push_attribute(("xmlns", "urn:ietf:params:xml:ns:xmpp-tls"));
// self.writer
// .write_event_async(Event::Empty(starttls_element))
// .await
// .unwrap();
// let mut buf = Vec::new();
// match self.reader.read_event_into_async(&mut buf).await.unwrap() {
// Event::Empty(e) => match e.name() {
// QName(b"proceed") => {
// let connector = TlsConnector::new().unwrap();
// let stream = self
// .reader
// .into_inner()
// .into_inner()
// .unsplit(self.writer.into_inner());
// if let Ok(tlsstream) = tokio_native_tls::TlsConnector::from(connector)
// .connect(&self.jabber.server, stream)
// .await
// {
// let (read, write) = tokio::io::split(tlsstream);
// let reader = Reader::from_reader(BufReader::new(read));
// let writer = Writer::new(write);
// let mut client =
// super::encrypted::JabberClient::new(reader, writer, self.jabber);
// client.start_stream().await?;
// return Ok(client);
// }
// }
// QName(_) => return Err(JabberError::TlsNegotiation),
// },
// _ => return Err(JabberError::TlsNegotiation),
// }
// Err(JabberError::TlsNegotiation)
// }
}
impl std::fmt::Debug for Jabber<Tls> {

View File

@ -8,9 +8,6 @@ pub mod jabber;
pub mod jid;
pub mod stanza;
#[macro_use]
extern crate lazy_static;
pub use connection::Connection;
pub use error::Error;
pub use jabber::Jabber;

View File

@ -1 +1,163 @@
use std::collections::{HashMap, HashSet};
use peanuts::{
element::{Content, FromElement, IntoElement, Name, NamespaceDeclaration},
Element,
};
pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-tls";
#[derive(Debug)]
pub struct StartTls {
pub required: bool,
}
impl IntoElement for StartTls {
fn into_element(&self) -> peanuts::Element {
let content;
if self.required == true {
let element = Content::Element(Element {
name: Name {
namespace: Some(XMLNS.to_string()),
local_name: "required".to_string(),
},
namespace_declarations: HashSet::new(),
attributes: HashMap::new(),
content: Vec::new(),
});
content = vec![element];
} else {
content = Vec::new();
}
let mut namespace_declarations = HashSet::new();
namespace_declarations.insert(NamespaceDeclaration {
prefix: None,
namespace: XMLNS.to_string(),
});
Element {
name: Name {
namespace: Some(XMLNS.to_string()),
local_name: "starttls".to_string(),
},
namespace_declarations,
attributes: HashMap::new(),
content,
}
}
}
impl FromElement for StartTls {
fn from_element(element: peanuts::Element) -> peanuts::Result<Self> {
let Name {
namespace,
local_name,
} = element.name;
if namespace.as_deref() == Some(XMLNS) && &local_name == "starttls" {
let mut required = false;
if element.content.len() == 1 {
match element.content.first().unwrap() {
Content::Element(element) => {
let Name {
namespace,
local_name,
} = &element.name;
if namespace.as_deref() == Some(XMLNS) && local_name == "required" {
required = true
} else {
return Err(peanuts::Error::UnexpectedElement(element.name.clone()));
}
}
c => return Err(peanuts::Error::UnexpectedContent((*c).clone())),
}
} else {
return Err(peanuts::Error::UnexpectedNumberOfContents(
element.content.len(),
));
}
return Ok(StartTls { required });
} else {
return Err(peanuts::Error::IncorrectName(Name {
namespace,
local_name,
}));
}
}
}
#[derive(Debug)]
pub struct Proceed;
impl IntoElement for Proceed {
fn into_element(&self) -> Element {
let mut namespace_declarations = HashSet::new();
namespace_declarations.insert(NamespaceDeclaration {
prefix: None,
namespace: XMLNS.to_string(),
});
Element {
name: Name {
namespace: Some(XMLNS.to_string()),
local_name: "proceed".to_string(),
},
namespace_declarations,
attributes: HashMap::new(),
content: Vec::new(),
}
}
}
impl FromElement for Proceed {
fn from_element(element: Element) -> peanuts::Result<Self> {
let Name {
namespace,
local_name,
} = element.name;
if namespace.as_deref() == Some(XMLNS) && &local_name == "proceed" {
return Ok(Proceed);
} else {
return Err(peanuts::Error::IncorrectName(Name {
namespace,
local_name,
}));
}
}
}
pub struct Failure;
impl IntoElement for Failure {
fn into_element(&self) -> Element {
let mut namespace_declarations = HashSet::new();
namespace_declarations.insert(NamespaceDeclaration {
prefix: None,
namespace: XMLNS.to_string(),
});
Element {
name: Name {
namespace: Some(XMLNS.to_string()),
local_name: "failure".to_string(),
},
namespace_declarations,
attributes: HashMap::new(),
content: Vec::new(),
}
}
}
impl FromElement for Failure {
fn from_element(element: Element) -> peanuts::Result<Self> {
let Name {
namespace,
local_name,
} = element.name;
if namespace.as_deref() == Some(XMLNS) && &local_name == "failure" {
return Ok(Failure);
} else {
return Err(peanuts::Error::IncorrectName(Name {
namespace,
local_name,
}));
}
}
}

View File

@ -6,12 +6,15 @@ use peanuts::{element::Name, Element};
use crate::{Error, JID};
use super::starttls::StartTls;
pub const XMLNS: &str = "http://etherx.jabber.org/streams";
pub const XMLNS_CLIENT: &str = "jabber:client";
// MUST be qualified by stream namespace
// #[derive(XmlSerialize, XmlDeserialize)]
// #[peanuts(xmlns = XMLNS)]
#[derive(Debug)]
pub struct Stream {
pub from: Option<JID>,
to: Option<JID>,
@ -93,7 +96,7 @@ impl IntoElement for Stream {
attributes.insert(
Name {
namespace: None,
local_name: "version".to_string(),
local_name: "id".to_string(),
},
id.clone(),
);
@ -158,3 +161,71 @@ impl<'s> Stream {
}
}
}
#[derive(Debug)]
pub struct Features {
features: Vec<Feature>,
}
impl IntoElement for Features {
fn into_element(&self) -> Element {
let mut content = Vec::new();
for feature in &self.features {
match feature {
Feature::StartTls(start_tls) => {
content.push(Content::Element(start_tls.into_element()))
}
Feature::Sasl => {}
Feature::Bind => {}
Feature::Unknown => {}
}
}
Element {
name: Name {
namespace: Some(XMLNS.to_string()),
local_name: "features".to_string(),
},
namespace_declarations: HashSet::new(),
attributes: HashMap::new(),
content,
}
}
}
impl FromElement for Features {
fn from_element(element: Element) -> peanuts::Result<Self> {
let Name {
namespace,
local_name,
} = element.name;
if namespace.as_deref() == Some(XMLNS) && &local_name == "features" {
let mut features = Vec::new();
for feature in element.content {
match feature {
Content::Element(element) => {
if let Ok(start_tls) = FromElement::from_element(element) {
features.push(Feature::StartTls(start_tls))
} else {
features.push(Feature::Unknown)
}
}
c => return Err(peanuts::Error::UnexpectedContent(c.clone())),
}
}
return Ok(Self { features });
} else {
return Err(peanuts::Error::IncorrectName(Name {
namespace,
local_name,
}));
}
}
}
#[derive(Debug)]
pub enum Feature {
StartTls(StartTls),
Sasl,
Bind,
Unknown,
}