move some stream init calls to a common module

This commit is contained in:
emilis 2023-07-01 09:30:25 +01:00
parent e7cf44efe1
commit 1ff3cbc281
4 changed files with 119 additions and 102 deletions

106
salut/src/common.rs Normal file
View File

@ -0,0 +1,106 @@
use log::info;
use quick_xml::{
events::{attributes::Attributes, BytesDecl, BytesEnd, BytesStart, Event},
Writer,
};
use rustls_pemfile::Item;
use tokio::io::AsyncWrite;
use crate::{
error::StreamError,
feature::Feature,
tag::{self, Tag},
};
type Result<T> = std::result::Result<T, StreamError>;
pub async fn error<W: AsyncWrite + Unpin>(writer: W, err: StreamError) -> Result<()> {
let mut writer = Writer::new(writer);
let err = err.to_string();
writer
.write_event_async(Event::Start(BytesStart::new(tag::ERROR_ELEMENT)))
.await?;
writer
.write_event_async(Event::Start(
BytesStart::new(&err)
.with_attributes(vec![("xmlns", "urn:ietf:params:xml:ns:xmpp-streams")]),
))
.await?;
writer
.write_event_async(Event::End(BytesEnd::new(&err)))
.await?;
writer
.write_event_async(Event::End(BytesEnd::new(tag::ERROR_ELEMENT)))
.await?;
Ok(())
}
pub async fn write_stream_header<W: AsyncWrite + Unpin>(writer: W, req: StreamAttrs) -> Result<()> {
let mut writer = Writer::new(writer);
writer
.write_event_async(Event::Decl(BytesDecl::new("1.0", Some("utf-8"), None)))
.await?;
writer
.write_event_async(Event::Start(
BytesStart::new("stream:stream").with_attributes(vec![
("from", req.from.as_str()),
("to", req.to.as_str()),
("xmlns:stream", "http://etherx.jabber.org/streams"),
("xml:lang", "en"),
("version", "1.0"),
]),
))
.await?;
Ok(())
}
#[derive(Debug, Clone)]
pub struct StreamAttrs {
pub from: String,
pub to: String,
pub namespace: XMLNamespace,
}
impl TryFrom<Attributes<'_>> for StreamAttrs {
type Error = StreamError;
fn try_from(value: Attributes<'_>) -> std::result::Result<Self, Self::Error> {
let mut from: Option<String> = None;
let mut to: Option<String> = None;
let mut ns: Option<XMLNamespace> = None;
for v in value {
let v = v?;
match v.key.local_name().into_inner() {
b"from" => {
from = Some(String::from_utf8(v.value.to_vec())?);
}
b"to" => {
to = Some(String::from_utf8(v.value.to_vec())?);
}
b"xmlns" => match v.value.to_vec().as_slice() {
b"jabber:client" => {
ns = Some(XMLNamespace::JabberClient);
}
_ => return Err(StreamError::InvalidNamespace),
},
other => {
info!(
"ignoring key {}",
String::from_utf8(other.to_vec()).unwrap_or_default()
);
}
}
}
Ok(StreamAttrs {
from: from.ok_or(StreamError::InvalidFrom)?,
to: to.ok_or(StreamError::HostUnknown)?,
namespace: ns.ok_or(StreamError::BadNamespacePrefix)?,
})
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum XMLNamespace {
JabberClient,
}

View File

@ -2,6 +2,7 @@ use std::{process, sync::Arc};
use log::{error, info};
mod common;
mod config;
mod error;
mod feature;

View File

@ -10,9 +10,7 @@ use tokio_rustls::{rustls, server::TlsStream, TlsAcceptor};
use crate::{
error::StreamError,
streamstart::{StartTLSResult, StreamStart},
tag::{self, Tag},
tls::stream::{self, TLSStream},
};
pub enum Step {

View File

@ -2,19 +2,20 @@ use std::{net::SocketAddr, sync::Arc};
use log::{error, info};
use quick_xml::{
events::{attributes::Attributes, BytesDecl, BytesEnd, BytesStart, Event},
events::{BytesEnd, BytesStart, Event},
Reader, Writer,
};
use tokio::{
io::{AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf},
io::{AsyncWriteExt, BufReader, ReadHalf, WriteHalf},
net::TcpStream,
};
use tokio_rustls::rustls;
use crate::{
common::{self, StreamAttrs, XMLNamespace},
error::StreamError,
feature::Feature,
negotiator::{self, Step, TcpConnOrTLS},
negotiator::{self, TcpConnOrTLS},
tag::{self, Tag},
tls::stream,
};
@ -60,7 +61,7 @@ impl StreamStart {
match self.negotiate_stream().await {
StartTLSResult::Success(tls_stream) => tls_stream.start_stream().await,
StartTLSResult::Failure(mut conn, err) => {
if let Err(err2) = error(conn.writer.get_mut(), err).await {
if let Err(err2) = common::error(conn.writer.get_mut(), err).await {
error!("error writing error: {err2}");
return;
} else {
@ -107,13 +108,15 @@ impl StreamStart {
}
};
info!("starting negotiation with: {attrs:?}");
if let Err(err) = self
.write_stream_header(StreamAttrs {
if let Err(err) = common::write_stream_header(
self.writer.get_mut(),
StreamAttrs {
from: attrs.to.clone(),
to: attrs.from,
namespace: XMLNamespace::JabberClient,
})
.await
},
)
.await
{
return StartTLSResult::Failure(self, err);
};
@ -161,26 +164,6 @@ impl StreamStart {
_ => continue,
}
}
StartTLSResult::Failure(self, StreamError::InternalServerError)
}
async fn write_stream_header(&mut self, req: StreamAttrs) -> Result<()> {
self.writer
.write_event_async(Event::Decl(BytesDecl::new("1.0", Some("utf-8"), None)))
.await?;
self.writer
.write_event_async(Event::Start(
BytesStart::new("stream:stream").with_attributes(vec![
("from", req.from.as_str()),
("to", req.to.as_str()),
("xmlns:stream", "http://etherx.jabber.org/streams"),
("xml:lang", "en"),
("version", "1.0"),
]),
))
.await?;
Ok(())
}
async fn send_features(&mut self) -> Result<()> {
@ -199,7 +182,7 @@ impl StreamStart {
pub fn spawn(
hostname: String,
(mut stream, _): (TcpStream, SocketAddr),
(stream, _): (TcpStream, SocketAddr),
tls_config: Arc<rustls::ServerConfig>,
) {
tokio::spawn(async move {
@ -208,74 +191,3 @@ pub fn spawn(
.await;
});
}
async fn error<W: AsyncWrite + Unpin>(writer: W, err: StreamError) -> Result<()> {
let mut writer = Writer::new(writer);
let err = err.to_string();
writer
.write_event_async(Event::Start(BytesStart::new(tag::ERROR_ELEMENT)))
.await?;
writer
.write_event_async(Event::Start(
BytesStart::new(&err)
.with_attributes(vec![("xmlns", "urn:ietf:params:xml:ns:xmpp-streams")]),
))
.await?;
writer
.write_event_async(Event::End(BytesEnd::new(&err)))
.await?;
writer
.write_event_async(Event::End(BytesEnd::new(tag::ERROR_ELEMENT)))
.await?;
Ok(())
}
#[derive(Debug, Clone)]
struct StreamAttrs {
from: String,
to: String,
namespace: XMLNamespace,
}
impl TryFrom<Attributes<'_>> for StreamAttrs {
type Error = StreamError;
fn try_from(value: Attributes<'_>) -> std::result::Result<Self, Self::Error> {
let mut from: Option<String> = None;
let mut to: Option<String> = None;
let mut ns: Option<XMLNamespace> = None;
for v in value {
let v = v?;
match v.key.local_name().into_inner() {
b"from" => {
from = Some(String::from_utf8(v.value.to_vec())?);
}
b"to" => {
to = Some(String::from_utf8(v.value.to_vec())?);
}
b"xmlns" => match v.value.to_vec().as_slice() {
b"jabber:client" => {
ns = Some(XMLNamespace::JabberClient);
}
_ => return Err(StreamError::InvalidNamespace),
},
other => {
info!(
"ignoring key {}",
String::from_utf8(other.to_vec()).unwrap_or_default()
);
}
}
}
Ok(StreamAttrs {
from: from.ok_or(StreamError::InvalidFrom)?,
to: to.ok_or(StreamError::HostUnknown)?,
namespace: ns.ok_or(StreamError::BadNamespacePrefix)?,
})
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum XMLNamespace {
JabberClient,
}