diff --git a/salut/src/common.rs b/salut/src/common.rs new file mode 100644 index 0000000..0da669f --- /dev/null +++ b/salut/src/common.rs @@ -0,0 +1,106 @@ +use log::info; +use quick_xml::{ + events::{attributes::Attributes, BytesDecl, BytesEnd, BytesStart, Event}, + Writer, +}; +use rustls_pemfile::Item; +use tokio::io::AsyncWrite; + +use crate::{ + error::StreamError, + feature::Feature, + tag::{self, Tag}, +}; + +type Result = std::result::Result; + +pub async fn error(writer: W, err: StreamError) -> Result<()> { + let mut writer = Writer::new(writer); + let err = err.to_string(); + writer + .write_event_async(Event::Start(BytesStart::new(tag::ERROR_ELEMENT))) + .await?; + writer + .write_event_async(Event::Start( + BytesStart::new(&err) + .with_attributes(vec![("xmlns", "urn:ietf:params:xml:ns:xmpp-streams")]), + )) + .await?; + writer + .write_event_async(Event::End(BytesEnd::new(&err))) + .await?; + writer + .write_event_async(Event::End(BytesEnd::new(tag::ERROR_ELEMENT))) + .await?; + Ok(()) +} + +pub async fn write_stream_header(writer: W, req: StreamAttrs) -> Result<()> { + let mut writer = Writer::new(writer); + writer + .write_event_async(Event::Decl(BytesDecl::new("1.0", Some("utf-8"), None))) + .await?; + writer + .write_event_async(Event::Start( + BytesStart::new("stream:stream").with_attributes(vec![ + ("from", req.from.as_str()), + ("to", req.to.as_str()), + ("xmlns:stream", "http://etherx.jabber.org/streams"), + ("xml:lang", "en"), + ("version", "1.0"), + ]), + )) + .await?; + + Ok(()) +} + +#[derive(Debug, Clone)] +pub struct StreamAttrs { + pub from: String, + pub to: String, + pub namespace: XMLNamespace, +} + +impl TryFrom> for StreamAttrs { + type Error = StreamError; + + fn try_from(value: Attributes<'_>) -> std::result::Result { + let mut from: Option = None; + let mut to: Option = None; + let mut ns: Option = None; + for v in value { + let v = v?; + match v.key.local_name().into_inner() { + b"from" => { + from = Some(String::from_utf8(v.value.to_vec())?); + } + b"to" => { + to = Some(String::from_utf8(v.value.to_vec())?); + } + b"xmlns" => match v.value.to_vec().as_slice() { + b"jabber:client" => { + ns = Some(XMLNamespace::JabberClient); + } + _ => return Err(StreamError::InvalidNamespace), + }, + other => { + info!( + "ignoring key {}", + String::from_utf8(other.to_vec()).unwrap_or_default() + ); + } + } + } + Ok(StreamAttrs { + from: from.ok_or(StreamError::InvalidFrom)?, + to: to.ok_or(StreamError::HostUnknown)?, + namespace: ns.ok_or(StreamError::BadNamespacePrefix)?, + }) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum XMLNamespace { + JabberClient, +} diff --git a/salut/src/main.rs b/salut/src/main.rs index b349f12..d0e9ac4 100644 --- a/salut/src/main.rs +++ b/salut/src/main.rs @@ -2,6 +2,7 @@ use std::{process, sync::Arc}; use log::{error, info}; +mod common; mod config; mod error; mod feature; diff --git a/salut/src/negotiator.rs b/salut/src/negotiator.rs index 1ed7ad9..6059822 100644 --- a/salut/src/negotiator.rs +++ b/salut/src/negotiator.rs @@ -10,9 +10,7 @@ use tokio_rustls::{rustls, server::TlsStream, TlsAcceptor}; use crate::{ error::StreamError, - streamstart::{StartTLSResult, StreamStart}, tag::{self, Tag}, - tls::stream::{self, TLSStream}, }; pub enum Step { diff --git a/salut/src/streamstart.rs b/salut/src/streamstart.rs index 2702832..b946841 100644 --- a/salut/src/streamstart.rs +++ b/salut/src/streamstart.rs @@ -2,19 +2,20 @@ use std::{net::SocketAddr, sync::Arc}; use log::{error, info}; use quick_xml::{ - events::{attributes::Attributes, BytesDecl, BytesEnd, BytesStart, Event}, + events::{BytesEnd, BytesStart, Event}, Reader, Writer, }; use tokio::{ - io::{AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf}, + io::{AsyncWriteExt, BufReader, ReadHalf, WriteHalf}, net::TcpStream, }; use tokio_rustls::rustls; use crate::{ + common::{self, StreamAttrs, XMLNamespace}, error::StreamError, feature::Feature, - negotiator::{self, Step, TcpConnOrTLS}, + negotiator::{self, TcpConnOrTLS}, tag::{self, Tag}, tls::stream, }; @@ -60,7 +61,7 @@ impl StreamStart { match self.negotiate_stream().await { StartTLSResult::Success(tls_stream) => tls_stream.start_stream().await, StartTLSResult::Failure(mut conn, err) => { - if let Err(err2) = error(conn.writer.get_mut(), err).await { + if let Err(err2) = common::error(conn.writer.get_mut(), err).await { error!("error writing error: {err2}"); return; } else { @@ -107,13 +108,15 @@ impl StreamStart { } }; info!("starting negotiation with: {attrs:?}"); - if let Err(err) = self - .write_stream_header(StreamAttrs { + if let Err(err) = common::write_stream_header( + self.writer.get_mut(), + StreamAttrs { from: attrs.to.clone(), to: attrs.from, namespace: XMLNamespace::JabberClient, - }) - .await + }, + ) + .await { return StartTLSResult::Failure(self, err); }; @@ -161,26 +164,6 @@ impl StreamStart { _ => continue, } } - - StartTLSResult::Failure(self, StreamError::InternalServerError) - } - async fn write_stream_header(&mut self, req: StreamAttrs) -> Result<()> { - self.writer - .write_event_async(Event::Decl(BytesDecl::new("1.0", Some("utf-8"), None))) - .await?; - self.writer - .write_event_async(Event::Start( - BytesStart::new("stream:stream").with_attributes(vec![ - ("from", req.from.as_str()), - ("to", req.to.as_str()), - ("xmlns:stream", "http://etherx.jabber.org/streams"), - ("xml:lang", "en"), - ("version", "1.0"), - ]), - )) - .await?; - - Ok(()) } async fn send_features(&mut self) -> Result<()> { @@ -199,7 +182,7 @@ impl StreamStart { pub fn spawn( hostname: String, - (mut stream, _): (TcpStream, SocketAddr), + (stream, _): (TcpStream, SocketAddr), tls_config: Arc, ) { tokio::spawn(async move { @@ -208,74 +191,3 @@ pub fn spawn( .await; }); } - -async fn error(writer: W, err: StreamError) -> Result<()> { - let mut writer = Writer::new(writer); - let err = err.to_string(); - writer - .write_event_async(Event::Start(BytesStart::new(tag::ERROR_ELEMENT))) - .await?; - writer - .write_event_async(Event::Start( - BytesStart::new(&err) - .with_attributes(vec![("xmlns", "urn:ietf:params:xml:ns:xmpp-streams")]), - )) - .await?; - writer - .write_event_async(Event::End(BytesEnd::new(&err))) - .await?; - writer - .write_event_async(Event::End(BytesEnd::new(tag::ERROR_ELEMENT))) - .await?; - Ok(()) -} - -#[derive(Debug, Clone)] -struct StreamAttrs { - from: String, - to: String, - namespace: XMLNamespace, -} - -impl TryFrom> for StreamAttrs { - type Error = StreamError; - - fn try_from(value: Attributes<'_>) -> std::result::Result { - let mut from: Option = None; - let mut to: Option = None; - let mut ns: Option = None; - for v in value { - let v = v?; - match v.key.local_name().into_inner() { - b"from" => { - from = Some(String::from_utf8(v.value.to_vec())?); - } - b"to" => { - to = Some(String::from_utf8(v.value.to_vec())?); - } - b"xmlns" => match v.value.to_vec().as_slice() { - b"jabber:client" => { - ns = Some(XMLNamespace::JabberClient); - } - _ => return Err(StreamError::InvalidNamespace), - }, - other => { - info!( - "ignoring key {}", - String::from_utf8(other.to_vec()).unwrap_or_default() - ); - } - } - } - Ok(StreamAttrs { - from: from.ok_or(StreamError::InvalidFrom)?, - to: to.ok_or(StreamError::HostUnknown)?, - namespace: ns.ok_or(StreamError::BadNamespacePrefix)?, - }) - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum XMLNamespace { - JabberClient, -}