implement stream start

This commit is contained in:
cel 🌸 2023-10-28 21:06:42 +01:00
parent c16f299364
commit a1f3cf450b
4 changed files with 74 additions and 70 deletions

View File

@ -15,6 +15,7 @@ quick-xml = { git = "https://github.com/tafia/quick-xml.git", features = ["async
# TODO: remove unneeded features # TODO: remove unneeded features
rsasl = { version = "2", default_features = true, features = ["provider_base64", "plain", "config_builder"] } rsasl = { version = "2", default_features = true, features = ["provider_base64", "plain", "config_builder"] }
serde = "1.0.180" serde = "1.0.180"
serde_with = "3.4.0"
tokio = { version = "1.28", features = ["full"] } tokio = { version = "1.28", features = ["full"] }
tokio-native-tls = "0.3.1" tokio-native-tls = "0.3.1"
tracing = "0.1.40" tracing = "0.1.40"

View File

@ -15,16 +15,21 @@ use crate::Result;
pub type Tls = TlsStream<TcpStream>; pub type Tls = TlsStream<TcpStream>;
pub type Unencrypted = TcpStream; pub type Unencrypted = TcpStream;
#[derive(Debug)]
pub enum Connection { pub enum Connection {
Encrypted(Jabber<Tls>), Encrypted(Jabber<Tls>),
Unencrypted(Jabber<Unencrypted>), Unencrypted(Jabber<Unencrypted>),
} }
impl Connection { impl Connection {
#[instrument]
pub async fn ensure_tls(self) -> Result<Jabber<Tls>> { pub async fn ensure_tls(self) -> Result<Jabber<Tls>> {
match self { match self {
Connection::Encrypted(j) => Ok(j), Connection::Encrypted(j) => Ok(j),
Connection::Unencrypted(j) => Ok(j.starttls().await?), Connection::Unencrypted(mut j) => {
info!("upgrading connection to tls");
Ok(j.starttls().await?)
}
} }
} }
@ -36,7 +41,7 @@ impl Connection {
// } // }
#[instrument] #[instrument]
async fn connect(server: &str) -> Result<Self> { pub async fn connect(server: &str) -> Result<Self> {
info!("connecting to {}", server); info!("connecting to {}", server);
let sockets = Self::get_sockets(&server).await; let sockets = Self::get_sockets(&server).await;
debug!("discovered sockets: {:?}", sockets); debug!("discovered sockets: {:?}", sockets);

View File

@ -1,9 +1,11 @@
use std::str;
use std::sync::Arc; use std::sync::Arc;
use quick_xml::{events::Event, se::Serializer, NsReader, Writer}; use quick_xml::{events::Event, se::Serializer, NsReader, Writer};
use rsasl::prelude::SASLConfig; use rsasl::prelude::SASLConfig;
use serde::Serialize; use serde::Serialize;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
use tracing::{debug, info, trace};
use crate::connection::{Tls, Unencrypted}; use crate::connection::{Tls, Unencrypted};
use crate::error::JabberError; use crate::error::JabberError;
@ -17,7 +19,7 @@ where
S: AsyncRead + AsyncWrite + Unpin, S: AsyncRead + AsyncWrite + Unpin,
{ {
reader: NsReader<BufReader<ReadHalf<S>>>, reader: NsReader<BufReader<ReadHalf<S>>>,
writer: Writer<WriteHalf<S>>, writer: WriteHalf<S>,
jid: Option<JID>, jid: Option<JID>,
auth: Option<Arc<SASLConfig>>, auth: Option<Arc<SASLConfig>>,
server: String, server: String,
@ -35,7 +37,6 @@ where
server: String, server: String,
) -> Self { ) -> Self {
let reader = NsReader::from_reader(BufReader::new(reader)); let reader = NsReader::from_reader(BufReader::new(reader));
let writer = Writer::new(writer);
Self { Self {
reader, reader,
writer, writer,
@ -49,112 +50,71 @@ where
impl<S> Jabber<S> impl<S> Jabber<S>
where where
S: AsyncRead + AsyncWrite + Unpin, S: AsyncRead + AsyncWrite + Unpin,
Writer<tokio::io::WriteHalf<S>>: AsyncWriteExt,
Writer<tokio::io::WriteHalf<S>>: AsyncWrite,
{ {
// pub async fn negotiate(self) -> Result<Jabber<S>> {}
pub async fn start_stream(&mut self) -> Result<()> { pub async fn start_stream(&mut self) -> Result<()> {
// client to server // client to server
// declaration // declaration
self.writer.write_event_async(DECLARATION.clone()).await?; let mut xmlwriter = Writer::new(&mut self.writer);
xmlwriter.write_event_async(DECLARATION.clone()).await?;
// opening stream element // opening stream element
let server = &self.server.to_owned().try_into()?; let server = &self.server.to_owned().try_into()?;
let stream_element = Stream::new_client(None, server, None, "en"); let stream_element = Stream::new_client(None, server, None, "en");
// TODO: nicer function to serialize to xml writer // TODO: nicer function to serialize to xml writer
let mut buffer = String::new(); let mut buffer = String::new();
let ser = Serializer::new(&mut buffer); let ser = Serializer::with_root(&mut buffer, Some("stream:stream")).expect("stream name");
stream_element.serialize(ser).unwrap(); stream_element.serialize(ser).unwrap();
self.writer.write_all(buffer.as_bytes()); trace!("sent: {}", buffer);
self.writer.write_all(buffer.as_bytes()).await.unwrap();
// server to client // server to client
// may or may not send a declaration // may or may not send a declaration
let mut buf = Vec::new(); let mut buf = Vec::new();
let mut first_event = self.reader.read_resolved_event_into_async(&mut buf).await?; let mut first_event = self.reader.read_resolved_event_into_async(&mut buf).await?;
trace!("received: {:?}", first_event);
match first_event { match first_event {
(quick_xml::name::ResolveResult::Unbound, Event::Decl(e)) => { (quick_xml::name::ResolveResult::Unbound, Event::Decl(e)) => {
if let Ok(version) = e.version() { if let Ok(version) = e.version() {
if version.as_ref() == b"1.0" { if version.as_ref() == b"1.0" {
first_event = self.reader.read_resolved_event_into_async(&mut buf).await? first_event = self.reader.read_resolved_event_into_async(&mut buf).await?;
trace!("received: {:?}", first_event);
} else { } else {
// todo: error // todo: error
todo!() todo!()
} }
} else { } else {
first_event = self.reader.read_resolved_event_into_async(&mut buf).await? first_event = self.reader.read_resolved_event_into_async(&mut buf).await?;
trace!("received: {:?}", first_event);
} }
} }
_ => (), _ => (),
} }
// receive stream element and validate // receive stream element and validate
let stream_response: Stream;
match first_event { match first_event {
(quick_xml::name::ResolveResult::Bound(ns), Event::Start(e)) => { (quick_xml::name::ResolveResult::Bound(ns), Event::Start(e)) => {
if ns.0 == crate::stanza::stream::XMLNS.as_bytes() { if ns.0 == crate::stanza::stream::XMLNS.as_bytes() {
// stream_response = Stream::new( e.attributes().try_for_each(|attr| -> Result<()> {
// e.try_get_attribute("from")?.try_map(|attribute| { let attr = attr?;
// str::from_utf8(attribute.value.as_ref())? match attr.key.into_inner() {
// .try_into()? b"from" => {
// .as_ref() self.server = str::from_utf8(&attr.value)?.to_owned();
// })?, Ok(())
// e.try_get_attribute("to")?.try_map(|attribute| { }
// str::from_utf8(attribute.value.as_ref())? _ => Ok(()),
// .try_into()? }
// .as_ref() });
// })?,
// e.try_get_attribute("id")?.try_map(|attribute| {
// str::from_utf8(attribute.value.as_ref())?
// .try_into()?
// .as_ref()
// })?,
// e.try_get_attribute("version")?.try_map(|attribute| {
// str::from_utf8(attribute.value.as_ref())?
// .try_into()?
// .as_ref()
// })?,
// e.try_get_attribute("lang")?.try_map(|attribute| {
// str::from_utf8(attribute.value.as_ref())?
// .try_into()?
// .as_ref()
// })?,
// );
return Ok(()); return Ok(());
} else { } else {
return Err(JabberError::BadStream); return Err(JabberError::BadStream);
} }
} }
// TODO: errors for incorrect namespace // TODO: errors for incorrect namespace
(quick_xml::name::ResolveResult::Unbound, Event::Decl(_)) => todo!(), _ => Err(JabberError::BadStream),
(quick_xml::name::ResolveResult::Unknown(_), Event::Start(_)) => todo!(),
(quick_xml::name::ResolveResult::Unknown(_), Event::End(_)) => todo!(),
(quick_xml::name::ResolveResult::Unknown(_), Event::Empty(_)) => todo!(),
(quick_xml::name::ResolveResult::Unknown(_), Event::Text(_)) => todo!(),
(quick_xml::name::ResolveResult::Unknown(_), Event::CData(_)) => todo!(),
(quick_xml::name::ResolveResult::Unknown(_), Event::Comment(_)) => todo!(),
(quick_xml::name::ResolveResult::Unknown(_), Event::Decl(_)) => todo!(),
(quick_xml::name::ResolveResult::Unknown(_), Event::PI(_)) => todo!(),
(quick_xml::name::ResolveResult::Unknown(_), Event::DocType(_)) => todo!(),
(quick_xml::name::ResolveResult::Unknown(_), Event::Eof) => todo!(),
(quick_xml::name::ResolveResult::Unbound, Event::Start(_)) => todo!(),
(quick_xml::name::ResolveResult::Unbound, Event::End(_)) => todo!(),
(quick_xml::name::ResolveResult::Unbound, Event::Empty(_)) => todo!(),
(quick_xml::name::ResolveResult::Unbound, Event::Text(_)) => todo!(),
(quick_xml::name::ResolveResult::Unbound, Event::CData(_)) => todo!(),
(quick_xml::name::ResolveResult::Unbound, Event::Comment(_)) => todo!(),
(quick_xml::name::ResolveResult::Unbound, Event::PI(_)) => todo!(),
(quick_xml::name::ResolveResult::Unbound, Event::DocType(_)) => todo!(),
(quick_xml::name::ResolveResult::Unbound, Event::Eof) => todo!(),
(quick_xml::name::ResolveResult::Bound(_), Event::End(_)) => todo!(),
(quick_xml::name::ResolveResult::Bound(_), Event::Empty(_)) => todo!(),
(quick_xml::name::ResolveResult::Bound(_), Event::Text(_)) => todo!(),
(quick_xml::name::ResolveResult::Bound(_), Event::CData(_)) => todo!(),
(quick_xml::name::ResolveResult::Bound(_), Event::Comment(_)) => todo!(),
(quick_xml::name::ResolveResult::Bound(_), Event::Decl(_)) => todo!(),
(quick_xml::name::ResolveResult::Bound(_), Event::PI(_)) => todo!(),
(quick_xml::name::ResolveResult::Bound(_), Event::DocType(_)) => todo!(),
(quick_xml::name::ResolveResult::Bound(_), Event::Eof) => todo!(),
} }
} }
} }
@ -164,7 +124,7 @@ where
// } // }
impl Jabber<Unencrypted> { impl Jabber<Unencrypted> {
pub async fn starttls(mut self) -> Result<Jabber<Tls>> { pub async fn starttls(&mut self) -> Result<Jabber<Tls>> {
todo!() todo!()
} }
// let mut starttls_element = BytesStart::new("starttls"); // let mut starttls_element = BytesStart::new("starttls");
@ -203,3 +163,41 @@ impl Jabber<Unencrypted> {
// Err(JabberError::TlsNegotiation) // Err(JabberError::TlsNegotiation)
// } // }
} }
impl std::fmt::Debug for Jabber<Tls> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Jabber")
.field("connection", &"tls")
.field("jid", &self.jid)
.field("auth", &self.auth)
.field("server", &self.server)
.finish()
}
}
impl std::fmt::Debug for Jabber<Unencrypted> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Jabber")
.field("connection", &"unencrypted")
.field("jid", &self.jid)
.field("auth", &self.auth)
.field("server", &self.server)
.finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::connection::Connection;
use test_log::test;
#[test(tokio::test)]
async fn start_stream() {
let connection = Connection::connect("blos.sm").await.unwrap();
match connection {
Connection::Encrypted(mut c) => c.start_stream().await.unwrap(),
Connection::Unencrypted(mut c) => c.start_stream().await.unwrap(),
}
}
}

View File

@ -1,5 +1,5 @@
#![allow(unused_must_use)] #![allow(unused_must_use)]
#![feature(let_chains)] // #![feature(let_chains)]
// TODO: logging (dropped errors) // TODO: logging (dropped errors)
pub mod connection; pub mod connection;