move some stream init calls to a common module
This commit is contained in:
parent
e7cf44efe1
commit
1ff3cbc281
|
@ -0,0 +1,106 @@
|
||||||
|
use log::info;
|
||||||
|
use quick_xml::{
|
||||||
|
events::{attributes::Attributes, BytesDecl, BytesEnd, BytesStart, Event},
|
||||||
|
Writer,
|
||||||
|
};
|
||||||
|
use rustls_pemfile::Item;
|
||||||
|
use tokio::io::AsyncWrite;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
error::StreamError,
|
||||||
|
feature::Feature,
|
||||||
|
tag::{self, Tag},
|
||||||
|
};
|
||||||
|
|
||||||
|
type Result<T> = std::result::Result<T, StreamError>;
|
||||||
|
|
||||||
|
pub async fn error<W: AsyncWrite + Unpin>(writer: W, err: StreamError) -> Result<()> {
|
||||||
|
let mut writer = Writer::new(writer);
|
||||||
|
let err = err.to_string();
|
||||||
|
writer
|
||||||
|
.write_event_async(Event::Start(BytesStart::new(tag::ERROR_ELEMENT)))
|
||||||
|
.await?;
|
||||||
|
writer
|
||||||
|
.write_event_async(Event::Start(
|
||||||
|
BytesStart::new(&err)
|
||||||
|
.with_attributes(vec![("xmlns", "urn:ietf:params:xml:ns:xmpp-streams")]),
|
||||||
|
))
|
||||||
|
.await?;
|
||||||
|
writer
|
||||||
|
.write_event_async(Event::End(BytesEnd::new(&err)))
|
||||||
|
.await?;
|
||||||
|
writer
|
||||||
|
.write_event_async(Event::End(BytesEnd::new(tag::ERROR_ELEMENT)))
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn write_stream_header<W: AsyncWrite + Unpin>(writer: W, req: StreamAttrs) -> Result<()> {
|
||||||
|
let mut writer = Writer::new(writer);
|
||||||
|
writer
|
||||||
|
.write_event_async(Event::Decl(BytesDecl::new("1.0", Some("utf-8"), None)))
|
||||||
|
.await?;
|
||||||
|
writer
|
||||||
|
.write_event_async(Event::Start(
|
||||||
|
BytesStart::new("stream:stream").with_attributes(vec![
|
||||||
|
("from", req.from.as_str()),
|
||||||
|
("to", req.to.as_str()),
|
||||||
|
("xmlns:stream", "http://etherx.jabber.org/streams"),
|
||||||
|
("xml:lang", "en"),
|
||||||
|
("version", "1.0"),
|
||||||
|
]),
|
||||||
|
))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct StreamAttrs {
|
||||||
|
pub from: String,
|
||||||
|
pub to: String,
|
||||||
|
pub namespace: XMLNamespace,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<Attributes<'_>> for StreamAttrs {
|
||||||
|
type Error = StreamError;
|
||||||
|
|
||||||
|
fn try_from(value: Attributes<'_>) -> std::result::Result<Self, Self::Error> {
|
||||||
|
let mut from: Option<String> = None;
|
||||||
|
let mut to: Option<String> = None;
|
||||||
|
let mut ns: Option<XMLNamespace> = None;
|
||||||
|
for v in value {
|
||||||
|
let v = v?;
|
||||||
|
match v.key.local_name().into_inner() {
|
||||||
|
b"from" => {
|
||||||
|
from = Some(String::from_utf8(v.value.to_vec())?);
|
||||||
|
}
|
||||||
|
b"to" => {
|
||||||
|
to = Some(String::from_utf8(v.value.to_vec())?);
|
||||||
|
}
|
||||||
|
b"xmlns" => match v.value.to_vec().as_slice() {
|
||||||
|
b"jabber:client" => {
|
||||||
|
ns = Some(XMLNamespace::JabberClient);
|
||||||
|
}
|
||||||
|
_ => return Err(StreamError::InvalidNamespace),
|
||||||
|
},
|
||||||
|
other => {
|
||||||
|
info!(
|
||||||
|
"ignoring key {}",
|
||||||
|
String::from_utf8(other.to_vec()).unwrap_or_default()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(StreamAttrs {
|
||||||
|
from: from.ok_or(StreamError::InvalidFrom)?,
|
||||||
|
to: to.ok_or(StreamError::HostUnknown)?,
|
||||||
|
namespace: ns.ok_or(StreamError::BadNamespacePrefix)?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum XMLNamespace {
|
||||||
|
JabberClient,
|
||||||
|
}
|
|
@ -2,6 +2,7 @@ use std::{process, sync::Arc};
|
||||||
|
|
||||||
use log::{error, info};
|
use log::{error, info};
|
||||||
|
|
||||||
|
mod common;
|
||||||
mod config;
|
mod config;
|
||||||
mod error;
|
mod error;
|
||||||
mod feature;
|
mod feature;
|
||||||
|
|
|
@ -10,9 +10,7 @@ use tokio_rustls::{rustls, server::TlsStream, TlsAcceptor};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
error::StreamError,
|
error::StreamError,
|
||||||
streamstart::{StartTLSResult, StreamStart},
|
|
||||||
tag::{self, Tag},
|
tag::{self, Tag},
|
||||||
tls::stream::{self, TLSStream},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub enum Step {
|
pub enum Step {
|
||||||
|
|
|
@ -2,19 +2,20 @@ use std::{net::SocketAddr, sync::Arc};
|
||||||
|
|
||||||
use log::{error, info};
|
use log::{error, info};
|
||||||
use quick_xml::{
|
use quick_xml::{
|
||||||
events::{attributes::Attributes, BytesDecl, BytesEnd, BytesStart, Event},
|
events::{BytesEnd, BytesStart, Event},
|
||||||
Reader, Writer,
|
Reader, Writer,
|
||||||
};
|
};
|
||||||
use tokio::{
|
use tokio::{
|
||||||
io::{AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf},
|
io::{AsyncWriteExt, BufReader, ReadHalf, WriteHalf},
|
||||||
net::TcpStream,
|
net::TcpStream,
|
||||||
};
|
};
|
||||||
use tokio_rustls::rustls;
|
use tokio_rustls::rustls;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
common::{self, StreamAttrs, XMLNamespace},
|
||||||
error::StreamError,
|
error::StreamError,
|
||||||
feature::Feature,
|
feature::Feature,
|
||||||
negotiator::{self, Step, TcpConnOrTLS},
|
negotiator::{self, TcpConnOrTLS},
|
||||||
tag::{self, Tag},
|
tag::{self, Tag},
|
||||||
tls::stream,
|
tls::stream,
|
||||||
};
|
};
|
||||||
|
@ -60,7 +61,7 @@ impl StreamStart {
|
||||||
match self.negotiate_stream().await {
|
match self.negotiate_stream().await {
|
||||||
StartTLSResult::Success(tls_stream) => tls_stream.start_stream().await,
|
StartTLSResult::Success(tls_stream) => tls_stream.start_stream().await,
|
||||||
StartTLSResult::Failure(mut conn, err) => {
|
StartTLSResult::Failure(mut conn, err) => {
|
||||||
if let Err(err2) = error(conn.writer.get_mut(), err).await {
|
if let Err(err2) = common::error(conn.writer.get_mut(), err).await {
|
||||||
error!("error writing error: {err2}");
|
error!("error writing error: {err2}");
|
||||||
return;
|
return;
|
||||||
} else {
|
} else {
|
||||||
|
@ -107,12 +108,14 @@ impl StreamStart {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
info!("starting negotiation with: {attrs:?}");
|
info!("starting negotiation with: {attrs:?}");
|
||||||
if let Err(err) = self
|
if let Err(err) = common::write_stream_header(
|
||||||
.write_stream_header(StreamAttrs {
|
self.writer.get_mut(),
|
||||||
|
StreamAttrs {
|
||||||
from: attrs.to.clone(),
|
from: attrs.to.clone(),
|
||||||
to: attrs.from,
|
to: attrs.from,
|
||||||
namespace: XMLNamespace::JabberClient,
|
namespace: XMLNamespace::JabberClient,
|
||||||
})
|
},
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
return StartTLSResult::Failure(self, err);
|
return StartTLSResult::Failure(self, err);
|
||||||
|
@ -161,26 +164,6 @@ impl StreamStart {
|
||||||
_ => continue,
|
_ => continue,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
StartTLSResult::Failure(self, StreamError::InternalServerError)
|
|
||||||
}
|
|
||||||
async fn write_stream_header(&mut self, req: StreamAttrs) -> Result<()> {
|
|
||||||
self.writer
|
|
||||||
.write_event_async(Event::Decl(BytesDecl::new("1.0", Some("utf-8"), None)))
|
|
||||||
.await?;
|
|
||||||
self.writer
|
|
||||||
.write_event_async(Event::Start(
|
|
||||||
BytesStart::new("stream:stream").with_attributes(vec![
|
|
||||||
("from", req.from.as_str()),
|
|
||||||
("to", req.to.as_str()),
|
|
||||||
("xmlns:stream", "http://etherx.jabber.org/streams"),
|
|
||||||
("xml:lang", "en"),
|
|
||||||
("version", "1.0"),
|
|
||||||
]),
|
|
||||||
))
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send_features(&mut self) -> Result<()> {
|
async fn send_features(&mut self) -> Result<()> {
|
||||||
|
@ -199,7 +182,7 @@ impl StreamStart {
|
||||||
|
|
||||||
pub fn spawn(
|
pub fn spawn(
|
||||||
hostname: String,
|
hostname: String,
|
||||||
(mut stream, _): (TcpStream, SocketAddr),
|
(stream, _): (TcpStream, SocketAddr),
|
||||||
tls_config: Arc<rustls::ServerConfig>,
|
tls_config: Arc<rustls::ServerConfig>,
|
||||||
) {
|
) {
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
|
@ -208,74 +191,3 @@ pub fn spawn(
|
||||||
.await;
|
.await;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn error<W: AsyncWrite + Unpin>(writer: W, err: StreamError) -> Result<()> {
|
|
||||||
let mut writer = Writer::new(writer);
|
|
||||||
let err = err.to_string();
|
|
||||||
writer
|
|
||||||
.write_event_async(Event::Start(BytesStart::new(tag::ERROR_ELEMENT)))
|
|
||||||
.await?;
|
|
||||||
writer
|
|
||||||
.write_event_async(Event::Start(
|
|
||||||
BytesStart::new(&err)
|
|
||||||
.with_attributes(vec![("xmlns", "urn:ietf:params:xml:ns:xmpp-streams")]),
|
|
||||||
))
|
|
||||||
.await?;
|
|
||||||
writer
|
|
||||||
.write_event_async(Event::End(BytesEnd::new(&err)))
|
|
||||||
.await?;
|
|
||||||
writer
|
|
||||||
.write_event_async(Event::End(BytesEnd::new(tag::ERROR_ELEMENT)))
|
|
||||||
.await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
struct StreamAttrs {
|
|
||||||
from: String,
|
|
||||||
to: String,
|
|
||||||
namespace: XMLNamespace,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TryFrom<Attributes<'_>> for StreamAttrs {
|
|
||||||
type Error = StreamError;
|
|
||||||
|
|
||||||
fn try_from(value: Attributes<'_>) -> std::result::Result<Self, Self::Error> {
|
|
||||||
let mut from: Option<String> = None;
|
|
||||||
let mut to: Option<String> = None;
|
|
||||||
let mut ns: Option<XMLNamespace> = None;
|
|
||||||
for v in value {
|
|
||||||
let v = v?;
|
|
||||||
match v.key.local_name().into_inner() {
|
|
||||||
b"from" => {
|
|
||||||
from = Some(String::from_utf8(v.value.to_vec())?);
|
|
||||||
}
|
|
||||||
b"to" => {
|
|
||||||
to = Some(String::from_utf8(v.value.to_vec())?);
|
|
||||||
}
|
|
||||||
b"xmlns" => match v.value.to_vec().as_slice() {
|
|
||||||
b"jabber:client" => {
|
|
||||||
ns = Some(XMLNamespace::JabberClient);
|
|
||||||
}
|
|
||||||
_ => return Err(StreamError::InvalidNamespace),
|
|
||||||
},
|
|
||||||
other => {
|
|
||||||
info!(
|
|
||||||
"ignoring key {}",
|
|
||||||
String::from_utf8(other.to_vec()).unwrap_or_default()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(StreamAttrs {
|
|
||||||
from: from.ok_or(StreamError::InvalidFrom)?,
|
|
||||||
to: to.ok_or(StreamError::HostUnknown)?,
|
|
||||||
namespace: ns.ok_or(StreamError::BadNamespacePrefix)?,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
||||||
enum XMLNamespace {
|
|
||||||
JabberClient,
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in New Issue