diff --git a/Cargo.lock b/Cargo.lock index 3c32109..824459a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1351,6 +1351,7 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "base64", "config_struct", "desec", "enum-display", @@ -1358,6 +1359,7 @@ dependencies = [ "log", "pretty_env_logger", "quick-xml", + "rand", "rcgen", "rsdns", "rustls-pemfile", diff --git a/salut/Cargo.toml b/salut/Cargo.toml index 9c874cd..98f7611 100644 --- a/salut/Cargo.toml +++ b/salut/Cargo.toml @@ -22,3 +22,5 @@ desec = { path = "../desec" } rcgen = "0.11.1" rsdns = { version = "0.15.0", features = ["net-tokio"] } rustls-pemfile = "1.0.3" +rand = "0.8.5" +base64 = "0.21.2" diff --git a/salut/src/common.rs b/salut/src/common.rs index 0da669f..67ef970 100644 --- a/salut/src/common.rs +++ b/salut/src/common.rs @@ -1,16 +1,12 @@ -use log::info; +use base64::Engine; +use log::{debug, info}; use quick_xml::{ events::{attributes::Attributes, BytesDecl, BytesEnd, BytesStart, Event}, Writer, }; -use rustls_pemfile::Item; use tokio::io::AsyncWrite; -use crate::{ - error::StreamError, - feature::Feature, - tag::{self, Tag}, -}; +use crate::{error::StreamError, tag}; type Result = std::result::Result; @@ -35,20 +31,33 @@ pub async fn error(writer: W, err: StreamError) -> Result Ok(()) } -pub async fn write_stream_header(writer: W, req: StreamAttrs) -> Result<()> { +pub async fn write_stream_header( + writer: W, + req: StreamHeader, +) -> Result<()> { let mut writer = Writer::new(writer); writer .write_event_async(Event::Decl(BytesDecl::new("1.0", Some("utf-8"), None))) .await?; + let id = req.id.unwrap_or_default(); + let (from, to) = (req.from.unwrap_or_default(), req.to.unwrap_or_default()); + + let mut attrs = vec![ + ("id", id.as_str()), + ("xmlns:stream", "http://etherx.jabber.org/streams"), + ("xmlns", "jabber:client"), + ("xml:lang", "en"), + ("version", "1.0"), + ]; + if !to.is_empty() { + attrs.push(("to", to.as_str())); + } + if !from.is_empty() { + attrs.push(("from", from.as_str())); + } writer .write_event_async(Event::Start( - BytesStart::new("stream:stream").with_attributes(vec![ - ("from", req.from.as_str()), - ("to", req.to.as_str()), - ("xmlns:stream", "http://etherx.jabber.org/streams"), - ("xml:lang", "en"), - ("version", "1.0"), - ]), + BytesStart::new("stream:stream").with_attributes(attrs), )) .await?; @@ -56,13 +65,14 @@ pub async fn write_stream_header(writer: W, req: StreamAt } #[derive(Debug, Clone)] -pub struct StreamAttrs { - pub from: String, - pub to: String, +pub struct StreamHeader { + pub from: Option, + pub to: Option, + pub id: Option, pub namespace: XMLNamespace, } -impl TryFrom> for StreamAttrs { +impl TryFrom> for StreamHeader { type Error = StreamError; fn try_from(value: Attributes<'_>) -> std::result::Result { @@ -86,15 +96,18 @@ impl TryFrom> for StreamAttrs { }, other => { info!( - "ignoring key {}", - String::from_utf8(other.to_vec()).unwrap_or_default() + "ignoring key {} with value << {} >>", + String::from_utf8(other.to_vec()).unwrap_or_default(), + String::from_utf8(v.value.to_vec()).unwrap_or_default(), ); } } } - Ok(StreamAttrs { - from: from.ok_or(StreamError::InvalidFrom)?, - to: to.ok_or(StreamError::HostUnknown)?, + debug!("from: {from:?}; to: {to:?}"); + Ok(StreamHeader { + from: from, + to: Some(to.ok_or(StreamError::HostUnknown)?), + id: None, namespace: ns.ok_or(StreamError::BadNamespacePrefix)?, }) } @@ -104,3 +117,11 @@ impl TryFrom> for StreamAttrs { pub enum XMLNamespace { JabberClient, } + +pub fn gen_id() -> String { + let buf: [u8; 16] = rand::random(); + let mut output = String::new(); + base64::engine::general_purpose::STANDARD_NO_PAD.encode_string(buf, &mut output); + + output +} diff --git a/salut/src/streamstart.rs b/salut/src/streamstart.rs index b946841..0fda76f 100644 --- a/salut/src/streamstart.rs +++ b/salut/src/streamstart.rs @@ -1,6 +1,6 @@ use std::{net::SocketAddr, sync::Arc}; -use log::{error, info}; +use log::{debug, error, info}; use quick_xml::{ events::{BytesEnd, BytesStart, Event}, Reader, Writer, @@ -12,7 +12,7 @@ use tokio::{ use tokio_rustls::rustls; use crate::{ - common::{self, StreamAttrs, XMLNamespace}, + common::{self, StreamHeader, XMLNamespace}, error::StreamError, feature::Feature, negotiator::{self, TcpConnOrTLS}, @@ -33,6 +33,7 @@ type Result = std::result::Result; const FEATURES: &'static [Feature] = &[Feature::start_tls(true)]; pub struct StreamStart { + id: String, reader: Reader>>, writer: Writer>, buffer: Vec, @@ -53,6 +54,7 @@ impl StreamStart { writer, hostname, tls_config, + id: common::gen_id(), buffer: vec![], } } @@ -90,7 +92,7 @@ impl StreamStart { match event { Event::Start(start) => { if start.name().as_ref() == tag::STREAM_ELEMENT_NAME { - let attrs: StreamAttrs = match start.attributes().try_into() { + let attrs: StreamHeader = match start.attributes().try_into() { Ok(a) => a, Err(err) => return StartTLSResult::Failure(self, err.into()), }; @@ -110,9 +112,10 @@ impl StreamStart { info!("starting negotiation with: {attrs:?}"); if let Err(err) = common::write_stream_header( self.writer.get_mut(), - StreamAttrs { - from: attrs.to.clone(), + StreamHeader { + from: Some(self.hostname.clone()), to: attrs.from, + id: Some(self.id.clone()), namespace: XMLNamespace::JabberClient, }, ) @@ -120,9 +123,10 @@ impl StreamStart { { return StartTLSResult::Failure(self, err); }; - if attrs.to != self.hostname { + if attrs.to.unwrap_or_default() != self.hostname { return StartTLSResult::Failure(self, StreamError::HostUnknown); } + debug!("sending features"); if let Err(err) = self.send_features().await { return StartTLSResult::Failure(self, err); }; diff --git a/salut/src/tls/stream.rs b/salut/src/tls/stream.rs index 7bd627b..87acbba 100644 --- a/salut/src/tls/stream.rs +++ b/salut/src/tls/stream.rs @@ -1,15 +1,28 @@ -use quick_xml::{Reader, Writer}; +use log::{error, info}; +use quick_xml::{ + events::{BytesEnd, BytesStart, Event}, + Reader, Writer, +}; use tokio::{ - io::{BufReader, ReadHalf, WriteHalf}, + io::{AsyncWriteExt, BufReader, ReadHalf, WriteHalf}, net::TcpStream, }; use tokio_rustls::server::TlsStream; +use crate::{ + common::{self, StreamHeader}, + error::StreamError, + feature::Feature, + tag::{self, Tag}, +}; +type Result = std::result::Result; pub struct TLSStream { + id: String, reader: Reader>>>, writer: Writer>>, buffer: Vec, hostname: String, + features: Vec, } impl TLSStream { @@ -24,9 +37,89 @@ impl TLSStream { reader, writer, hostname, + id: common::gen_id(), buffer: vec![], + features: vec![], } } - pub async fn start_stream(self) {} + pub async fn start_stream(mut self) { + if let Err(err) = self.handle_stream().await { + if let Err(err2) = common::error(self.writer.get_mut(), err).await { + error!("error writing error: {err2}"); + return; + } else { + info!("wrote error {err}") + } + + if let Err(e) = self.writer.get_mut().write_all(b"").await { + error!("writing end to stream: {e}") + } + if let Err(e) = self.writer.get_mut().shutdown().await { + error!("shutting down stream: {e}") + } + } + } + + async fn handle_stream(&mut self) -> Result<()> { + let header = self.get_stream_header().await?; + if !header + .to + .unwrap_or_default() + .eq_ignore_ascii_case(&self.hostname) + { + return Err(StreamError::HostUnknown); + } + common::write_stream_header( + self.writer.get_mut(), + StreamHeader { + from: Some(self.hostname.clone()), + to: header.from, + id: Some(self.id.clone()), + namespace: common::XMLNamespace::JabberClient, + }, + ) + .await?; + self.send_features().await?; + self.negotiate_features().await?; + + todo!() + } + + async fn negotiate_features(&mut self) -> Result<()> { + let available_features: Vec<&Feature> = (&self.features).into_iter().collect(); + while available_features.len() > 0 { + todo!() + } + Ok(()) + } + + async fn send_features(&mut self) -> Result<()> { + self.writer + .write_event_async(Event::Start(BytesStart::new(tag::FEATURE))) + .await?; + for feature in &self.features { + feature.write_tag(self.writer.get_mut()).await?; + } + self.writer + .write_event_async(Event::End(BytesEnd::new(tag::FEATURE))) + .await?; + Ok(()) + } + + async fn get_stream_header(&mut self) -> Result { + loop { + match self.reader.read_event_into_async(&mut self.buffer).await? { + Event::Start(start) => { + if start.name().as_ref() == tag::STREAM_ELEMENT_NAME { + break Ok(start.attributes().try_into()?); + } else { + break Err(StreamError::BadFormat); + } + } + Event::Decl(_) => continue, + _ => break Err(StreamError::BadFormat), + } + } + } }