diff --git a/src/connection.rs b/src/connection.rs index 89f382f..2b70747 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -27,8 +27,11 @@ impl Connection { match self { Connection::Encrypted(j) => Ok(j), Connection::Unencrypted(mut j) => { + j.start_stream().await?; info!("upgrading connection to tls"); - Ok(j.starttls().await?) + j.get_features().await?; + let j = j.starttls().await?; + Ok(j) } } } @@ -179,4 +182,14 @@ mod tests { async fn connect() { Connection::connect("blos.sm").await.unwrap(); } + + #[test(tokio::test)] + async fn test_tls() { + Connection::connect("blos.sm") + .await + .unwrap() + .ensure_tls() + .await + .unwrap(); + } } diff --git a/src/jabber.rs b/src/jabber.rs index afe840b..87a2b44 100644 --- a/src/jabber.rs +++ b/src/jabber.rs @@ -1,14 +1,18 @@ use std::str; use std::sync::Arc; +use peanuts::element::{FromElement, IntoElement}; use peanuts::{Reader, Writer}; use rsasl::prelude::SASLConfig; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf}; +use tokio_native_tls::native_tls::TlsConnector; use tracing::{debug, info, trace}; +use trust_dns_resolver::proto::rr::domain::IntoLabel; use crate::connection::{Tls, Unencrypted}; use crate::error::Error; -use crate::stanza::stream::Stream; +use crate::stanza::starttls::{Proceed, StartTls}; +use crate::stanza::stream::{Features, Stream}; use crate::stanza::XML_VERSION; use crate::Result; use crate::JID; @@ -62,7 +66,6 @@ where // opening stream element let server = self.server.clone().try_into()?; let stream = Stream::new_client(None, server, None, "en".to_string()); - // TODO: nicer function to serialize to xml writer self.writer.write_start(&stream).await?; // server to client @@ -72,57 +75,53 @@ where // receive stream element and validate let stream: Stream = self.reader.read_start().await?; + debug!("got stream: {:?}", stream); if let Some(from) = stream.from { self.server = from.to_string() } Ok(()) } + + pub async fn get_features(&mut self) -> Result { + debug!("getting features"); + let features: Features = self.reader.read().await?; + debug!("got features: {:?}", features); + Ok(features) + } + + pub fn into_inner(self) -> S { + self.reader.into_inner().unsplit(self.writer.into_inner()) + } } -// pub async fn get_features(&mut self) -> Result> { -// Element::read(&mut self.reader).await?.try_into() -// } - impl Jabber { - pub async fn starttls(&mut self) -> Result> { - todo!() + pub async fn starttls(mut self) -> Result> { + self.writer + .write_full(&StartTls { required: false }) + .await?; + let proceed: Proceed = self.reader.read().await?; + debug!("got proceed: {:?}", proceed); + let connector = TlsConnector::new().unwrap(); + let stream = self.reader.into_inner().unsplit(self.writer.into_inner()); + if let Ok(tlsstream) = tokio_native_tls::TlsConnector::from(connector) + .connect(&self.server, stream) + .await + { + let (read, write) = tokio::io::split(tlsstream); + let mut client = Jabber::new( + read, + write, + self.jid.to_owned(), + self.auth.to_owned(), + self.server.to_owned(), + ); + client.start_stream().await?; + return Ok(client); + } else { + return Err(Error::Connection); + } } - // let mut starttls_element = BytesStart::new("starttls"); - // starttls_element.push_attribute(("xmlns", "urn:ietf:params:xml:ns:xmpp-tls")); - // self.writer - // .write_event_async(Event::Empty(starttls_element)) - // .await - // .unwrap(); - // let mut buf = Vec::new(); - // match self.reader.read_event_into_async(&mut buf).await.unwrap() { - // Event::Empty(e) => match e.name() { - // QName(b"proceed") => { - // let connector = TlsConnector::new().unwrap(); - // let stream = self - // .reader - // .into_inner() - // .into_inner() - // .unsplit(self.writer.into_inner()); - // if let Ok(tlsstream) = tokio_native_tls::TlsConnector::from(connector) - // .connect(&self.jabber.server, stream) - // .await - // { - // let (read, write) = tokio::io::split(tlsstream); - // let reader = Reader::from_reader(BufReader::new(read)); - // let writer = Writer::new(write); - // let mut client = - // super::encrypted::JabberClient::new(reader, writer, self.jabber); - // client.start_stream().await?; - // return Ok(client); - // } - // } - // QName(_) => return Err(JabberError::TlsNegotiation), - // }, - // _ => return Err(JabberError::TlsNegotiation), - // } - // Err(JabberError::TlsNegotiation) - // } } impl std::fmt::Debug for Jabber { diff --git a/src/lib.rs b/src/lib.rs index 306b0fd..88b91a6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,9 +8,6 @@ pub mod jabber; pub mod jid; pub mod stanza; -#[macro_use] -extern crate lazy_static; - pub use connection::Connection; pub use error::Error; pub use jabber::Jabber; diff --git a/src/stanza/starttls.rs b/src/stanza/starttls.rs index 8b13789..874ae66 100644 --- a/src/stanza/starttls.rs +++ b/src/stanza/starttls.rs @@ -1 +1,163 @@ +use std::collections::{HashMap, HashSet}; +use peanuts::{ + element::{Content, FromElement, IntoElement, Name, NamespaceDeclaration}, + Element, +}; + +pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-tls"; + +#[derive(Debug)] +pub struct StartTls { + pub required: bool, +} + +impl IntoElement for StartTls { + fn into_element(&self) -> peanuts::Element { + let content; + if self.required == true { + let element = Content::Element(Element { + name: Name { + namespace: Some(XMLNS.to_string()), + local_name: "required".to_string(), + }, + namespace_declarations: HashSet::new(), + attributes: HashMap::new(), + content: Vec::new(), + }); + content = vec![element]; + } else { + content = Vec::new(); + } + let mut namespace_declarations = HashSet::new(); + namespace_declarations.insert(NamespaceDeclaration { + prefix: None, + namespace: XMLNS.to_string(), + }); + Element { + name: Name { + namespace: Some(XMLNS.to_string()), + local_name: "starttls".to_string(), + }, + namespace_declarations, + attributes: HashMap::new(), + content, + } + } +} + +impl FromElement for StartTls { + fn from_element(element: peanuts::Element) -> peanuts::Result { + let Name { + namespace, + local_name, + } = element.name; + if namespace.as_deref() == Some(XMLNS) && &local_name == "starttls" { + let mut required = false; + if element.content.len() == 1 { + match element.content.first().unwrap() { + Content::Element(element) => { + let Name { + namespace, + local_name, + } = &element.name; + + if namespace.as_deref() == Some(XMLNS) && local_name == "required" { + required = true + } else { + return Err(peanuts::Error::UnexpectedElement(element.name.clone())); + } + } + c => return Err(peanuts::Error::UnexpectedContent((*c).clone())), + } + } else { + return Err(peanuts::Error::UnexpectedNumberOfContents( + element.content.len(), + )); + } + return Ok(StartTls { required }); + } else { + return Err(peanuts::Error::IncorrectName(Name { + namespace, + local_name, + })); + } + } +} + +#[derive(Debug)] +pub struct Proceed; + +impl IntoElement for Proceed { + fn into_element(&self) -> Element { + let mut namespace_declarations = HashSet::new(); + namespace_declarations.insert(NamespaceDeclaration { + prefix: None, + namespace: XMLNS.to_string(), + }); + Element { + name: Name { + namespace: Some(XMLNS.to_string()), + local_name: "proceed".to_string(), + }, + namespace_declarations, + attributes: HashMap::new(), + content: Vec::new(), + } + } +} + +impl FromElement for Proceed { + fn from_element(element: Element) -> peanuts::Result { + let Name { + namespace, + local_name, + } = element.name; + if namespace.as_deref() == Some(XMLNS) && &local_name == "proceed" { + return Ok(Proceed); + } else { + return Err(peanuts::Error::IncorrectName(Name { + namespace, + local_name, + })); + } + } +} + +pub struct Failure; + +impl IntoElement for Failure { + fn into_element(&self) -> Element { + let mut namespace_declarations = HashSet::new(); + namespace_declarations.insert(NamespaceDeclaration { + prefix: None, + namespace: XMLNS.to_string(), + }); + Element { + name: Name { + namespace: Some(XMLNS.to_string()), + local_name: "failure".to_string(), + }, + namespace_declarations, + attributes: HashMap::new(), + content: Vec::new(), + } + } +} + +impl FromElement for Failure { + fn from_element(element: Element) -> peanuts::Result { + let Name { + namespace, + local_name, + } = element.name; + if namespace.as_deref() == Some(XMLNS) && &local_name == "failure" { + return Ok(Failure); + } else { + return Err(peanuts::Error::IncorrectName(Name { + namespace, + local_name, + })); + } + } +} diff --git a/src/stanza/stream.rs b/src/stanza/stream.rs index ac4badc..4516682 100644 --- a/src/stanza/stream.rs +++ b/src/stanza/stream.rs @@ -6,12 +6,15 @@ use peanuts::{element::Name, Element}; use crate::{Error, JID}; +use super::starttls::StartTls; + pub const XMLNS: &str = "http://etherx.jabber.org/streams"; pub const XMLNS_CLIENT: &str = "jabber:client"; // MUST be qualified by stream namespace // #[derive(XmlSerialize, XmlDeserialize)] // #[peanuts(xmlns = XMLNS)] +#[derive(Debug)] pub struct Stream { pub from: Option, to: Option, @@ -93,7 +96,7 @@ impl IntoElement for Stream { attributes.insert( Name { namespace: None, - local_name: "version".to_string(), + local_name: "id".to_string(), }, id.clone(), ); @@ -158,3 +161,71 @@ impl<'s> Stream { } } } + +#[derive(Debug)] +pub struct Features { + features: Vec, +} + +impl IntoElement for Features { + fn into_element(&self) -> Element { + let mut content = Vec::new(); + for feature in &self.features { + match feature { + Feature::StartTls(start_tls) => { + content.push(Content::Element(start_tls.into_element())) + } + Feature::Sasl => {} + Feature::Bind => {} + Feature::Unknown => {} + } + } + Element { + name: Name { + namespace: Some(XMLNS.to_string()), + local_name: "features".to_string(), + }, + namespace_declarations: HashSet::new(), + attributes: HashMap::new(), + content, + } + } +} + +impl FromElement for Features { + fn from_element(element: Element) -> peanuts::Result { + let Name { + namespace, + local_name, + } = element.name; + if namespace.as_deref() == Some(XMLNS) && &local_name == "features" { + let mut features = Vec::new(); + for feature in element.content { + match feature { + Content::Element(element) => { + if let Ok(start_tls) = FromElement::from_element(element) { + features.push(Feature::StartTls(start_tls)) + } else { + features.push(Feature::Unknown) + } + } + c => return Err(peanuts::Error::UnexpectedContent(c.clone())), + } + } + return Ok(Self { features }); + } else { + return Err(peanuts::Error::IncorrectName(Name { + namespace, + local_name, + })); + } + } +} + +#[derive(Debug)] +pub enum Feature { + StartTls(StartTls), + Sasl, + Bind, + Unknown, +}