switch to using peanuts for xml

This commit is contained in:
cel 🌸 2024-11-23 22:39:44 +00:00
parent 9f2546f6da
commit 40024d2dad
8 changed files with 184 additions and 141 deletions

View File

@ -11,16 +11,14 @@ async-recursion = "1.0.4"
async-trait = "0.1.68" async-trait = "0.1.68"
lazy_static = "1.4.0" lazy_static = "1.4.0"
nanoid = "0.4.0" nanoid = "0.4.0"
quick-xml = { git = "https://github.com/tafia/quick-xml.git", features = ["async-tokio", "serialize"] }
# TODO: remove unneeded features # TODO: remove unneeded features
rsasl = { version = "2", default_features = true, features = ["provider_base64", "plain", "config_builder"] } rsasl = { version = "2", default_features = true, features = ["provider_base64", "plain", "config_builder"] }
serde = "1.0.180"
serde_with = "3.4.0"
tokio = { version = "1.28", features = ["full"] } tokio = { version = "1.28", features = ["full"] }
tokio-native-tls = "0.3.1" tokio-native-tls = "0.3.1"
tracing = "0.1.40" tracing = "0.1.40"
trust-dns-resolver = "0.22.0" trust-dns-resolver = "0.22.0"
try_map = "0.3.1" try_map = "0.3.1"
peanuts = { version = "0.1.0", path = "../peanuts" }
[dev-dependencies] [dev-dependencies]
test-log = { version = "0.2", features = ["trace"] } test-log = { version = "0.2", features = ["trace"] }

View File

@ -8,8 +8,8 @@ use tokio_native_tls::native_tls::TlsConnector;
use tokio_native_tls::TlsStream; use tokio_native_tls::TlsStream;
use tracing::{debug, info, instrument, trace}; use tracing::{debug, info, instrument, trace};
use crate::Error;
use crate::Jabber; use crate::Jabber;
use crate::JabberError;
use crate::Result; use crate::Result;
pub type Tls = TlsStream<TcpStream>; pub type Tls = TlsStream<TcpStream>;
@ -75,7 +75,7 @@ impl Connection {
} }
} }
} }
Err(JabberError::Connection) Err(Error::Connection)
} }
#[instrument] #[instrument]
@ -154,19 +154,19 @@ impl Connection {
pub async fn connect_tls(socket_addr: SocketAddr, domain_name: &str) -> Result<Tls> { pub async fn connect_tls(socket_addr: SocketAddr, domain_name: &str) -> Result<Tls> {
let socket = TcpStream::connect(socket_addr) let socket = TcpStream::connect(socket_addr)
.await .await
.map_err(|_| JabberError::Connection)?; .map_err(|_| Error::Connection)?;
let connector = TlsConnector::new().map_err(|_| JabberError::Connection)?; let connector = TlsConnector::new().map_err(|_| Error::Connection)?;
tokio_native_tls::TlsConnector::from(connector) tokio_native_tls::TlsConnector::from(connector)
.connect(domain_name, socket) .connect(domain_name, socket)
.await .await
.map_err(|_| JabberError::Connection) .map_err(|_| Error::Connection)
} }
#[instrument] #[instrument]
pub async fn connect_unencrypted(socket_addr: SocketAddr) -> Result<Unencrypted> { pub async fn connect_unencrypted(socket_addr: SocketAddr) -> Result<Unencrypted> {
TcpStream::connect(socket_addr) TcpStream::connect(socket_addr)
.await .await
.map_err(|_| JabberError::Connection) .map_err(|_| Error::Connection)
} }
} }

View File

@ -1,12 +1,11 @@
use std::str::Utf8Error; use std::str::Utf8Error;
use quick_xml::events::attributes::AttrError;
use rsasl::mechname::MechanismNameError; use rsasl::mechname::MechanismNameError;
use crate::jid::ParseError; use crate::jid::ParseError;
#[derive(Debug)] #[derive(Debug)]
pub enum JabberError { pub enum Error {
Connection, Connection,
BadStream, BadStream,
StartTlsUnavailable, StartTlsUnavailable,
@ -23,7 +22,7 @@ pub enum JabberError {
UnexpectedEnd, UnexpectedEnd,
UnexpectedElement, UnexpectedElement,
UnexpectedText, UnexpectedText,
XML(quick_xml::Error), XML(peanuts::Error),
SASL(SASLError), SASL(SASLError),
JID(ParseError), JID(ParseError),
} }
@ -36,43 +35,37 @@ pub enum SASLError {
NoSuccess, NoSuccess,
} }
impl From<rsasl::prelude::SASLError> for JabberError { impl From<rsasl::prelude::SASLError> for Error {
fn from(e: rsasl::prelude::SASLError) -> Self { fn from(e: rsasl::prelude::SASLError) -> Self {
Self::SASL(SASLError::SASL(e)) Self::SASL(SASLError::SASL(e))
} }
} }
impl From<MechanismNameError> for JabberError { impl From<MechanismNameError> for Error {
fn from(e: MechanismNameError) -> Self { fn from(e: MechanismNameError) -> Self {
Self::SASL(SASLError::MechanismName(e)) Self::SASL(SASLError::MechanismName(e))
} }
} }
impl From<SASLError> for JabberError { impl From<SASLError> for Error {
fn from(e: SASLError) -> Self { fn from(e: SASLError) -> Self {
Self::SASL(e) Self::SASL(e)
} }
} }
impl From<Utf8Error> for JabberError { impl From<Utf8Error> for Error {
fn from(_e: Utf8Error) -> Self { fn from(_e: Utf8Error) -> Self {
Self::Utf8Decode Self::Utf8Decode
} }
} }
impl From<quick_xml::Error> for JabberError { impl From<peanuts::Error> for Error {
fn from(e: quick_xml::Error) -> Self { fn from(e: peanuts::Error) -> Self {
Self::XML(e) Self::XML(e)
} }
} }
impl From<AttrError> for JabberError { impl From<ParseError> for Error {
fn from(e: AttrError) -> Self {
Self::XML(e.into())
}
}
impl From<ParseError> for JabberError {
fn from(e: ParseError) -> Self { fn from(e: ParseError) -> Self {
Self::JID(e) Self::JID(e)
} }

View File

@ -1,16 +1,15 @@
use std::str; use std::str;
use std::sync::Arc; use std::sync::Arc;
use quick_xml::{events::Event, se::Serializer, NsReader, Writer}; use peanuts::{Reader, Writer};
use rsasl::prelude::SASLConfig; use rsasl::prelude::SASLConfig;
use serde::Serialize;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
use tracing::{debug, info, trace}; use tracing::{debug, info, trace};
use crate::connection::{Tls, Unencrypted}; use crate::connection::{Tls, Unencrypted};
use crate::error::JabberError; use crate::error::Error;
use crate::stanza::stream::Stream; use crate::stanza::stream::Stream;
use crate::stanza::DECLARATION; use crate::stanza::XML_VERSION;
use crate::Result; use crate::Result;
use crate::JID; use crate::JID;
@ -18,8 +17,8 @@ pub struct Jabber<S>
where where
S: AsyncRead + AsyncWrite + Unpin, S: AsyncRead + AsyncWrite + Unpin,
{ {
reader: NsReader<BufReader<ReadHalf<S>>>, reader: Reader<ReadHalf<S>>,
writer: WriteHalf<S>, writer: Writer<WriteHalf<S>>,
jid: Option<JID>, jid: Option<JID>,
auth: Option<Arc<SASLConfig>>, auth: Option<Arc<SASLConfig>>,
server: String, server: String,
@ -36,7 +35,8 @@ where
auth: Option<Arc<SASLConfig>>, auth: Option<Arc<SASLConfig>>,
server: String, server: String,
) -> Self { ) -> Self {
let reader = NsReader::from_reader(BufReader::new(reader)); let reader = Reader::new(reader);
let writer = Writer::new(writer);
Self { Self {
reader, reader,
writer, writer,
@ -49,7 +49,7 @@ where
impl<S> Jabber<S> impl<S> Jabber<S>
where where
S: AsyncRead + AsyncWrite + Unpin, S: AsyncRead + AsyncWrite + Unpin + Send,
{ {
// pub async fn negotiate(self) -> Result<Jabber<S>> {} // pub async fn negotiate(self) -> Result<Jabber<S>> {}
@ -57,65 +57,26 @@ where
// client to server // client to server
// declaration // declaration
let mut xmlwriter = Writer::new(&mut self.writer); self.writer.write_declaration(XML_VERSION).await?;
xmlwriter.write_event_async(DECLARATION.clone()).await?;
// opening stream element // opening stream element
let server = &self.server.to_owned().try_into()?; let server = self.server.clone().try_into()?;
let stream_element = Stream::new_client(None, server, None, "en"); let stream = Stream::new_client(None, server, None, "en".to_string());
// TODO: nicer function to serialize to xml writer // TODO: nicer function to serialize to xml writer
let mut buffer = String::new(); self.writer.write_start(&stream).await?;
let ser = Serializer::with_root(&mut buffer, Some("stream:stream")).expect("stream name");
stream_element.serialize(ser).unwrap();
trace!("sent: {}", buffer);
self.writer.write_all(buffer.as_bytes()).await.unwrap();
// server to client // server to client
// may or may not send a declaration // may or may not send a declaration
let mut buf = Vec::new(); let decl = self.reader.read_prolog().await?;
let mut first_event = self.reader.read_resolved_event_into_async(&mut buf).await?;
trace!("received: {:?}", first_event);
match first_event {
(quick_xml::name::ResolveResult::Unbound, Event::Decl(e)) => {
if let Ok(version) = e.version() {
if version.as_ref() == b"1.0" {
first_event = self.reader.read_resolved_event_into_async(&mut buf).await?;
trace!("received: {:?}", first_event);
} else {
// todo: error
todo!()
}
} else {
first_event = self.reader.read_resolved_event_into_async(&mut buf).await?;
trace!("received: {:?}", first_event);
}
}
_ => (),
}
// receive stream element and validate // receive stream element and validate
match first_event { let stream: Stream = self.reader.read_start().await?;
(quick_xml::name::ResolveResult::Bound(ns), Event::Start(e)) => { if let Some(from) = stream.from {
if ns.0 == crate::stanza::stream::XMLNS.as_bytes() { self.server = from.to_string()
e.attributes().try_for_each(|attr| -> Result<()> {
let attr = attr?;
match attr.key.into_inner() {
b"from" => {
self.server = str::from_utf8(&attr.value)?.to_owned();
Ok(())
}
_ => Ok(()),
}
});
return Ok(());
} else {
return Err(JabberError::BadStream);
}
}
// TODO: errors for incorrect namespace
_ => Err(JabberError::BadStream),
} }
Ok(())
} }
} }

View File

@ -1,7 +1,5 @@
use std::str::FromStr; use std::str::FromStr;
use serde::Serialize;
#[derive(PartialEq, Debug, Clone)] #[derive(PartialEq, Debug, Clone)]
pub struct JID { pub struct JID {
// TODO: validate localpart (length, char] // TODO: validate localpart (length, char]
@ -10,15 +8,6 @@ pub struct JID {
pub resourcepart: Option<String>, pub resourcepart: Option<String>,
} }
impl Serialize for JID {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(&self.to_string())
}
}
pub enum JIDError { pub enum JIDError {
NoResourcePart, NoResourcePart,
ParseError(ParseError), ParseError(ParseError),
@ -27,7 +16,16 @@ pub enum JIDError {
#[derive(Debug)] #[derive(Debug)]
pub enum ParseError { pub enum ParseError {
Empty, Empty,
Malformed, Malformed(String),
}
impl From<ParseError> for peanuts::Error {
fn from(e: ParseError) -> Self {
match e {
ParseError::Empty => peanuts::Error::DeserializeError("".to_string()),
ParseError::Malformed(e) => peanuts::Error::DeserializeError(e),
}
}
} }
impl JID { impl JID {
@ -76,7 +74,7 @@ impl FromStr for JID {
split[0].to_string(), split[0].to_string(),
Some(split[1].to_string()), Some(split[1].to_string()),
)), )),
_ => Err(ParseError::Malformed), _ => Err(ParseError::Malformed(s.to_string())),
} }
} }
2 => { 2 => {
@ -92,10 +90,10 @@ impl FromStr for JID {
split2[0].to_string(), split2[0].to_string(),
Some(split2[1].to_string()), Some(split2[1].to_string()),
)), )),
_ => Err(ParseError::Malformed), _ => Err(ParseError::Malformed(s.to_string())),
} }
} }
_ => Err(ParseError::Malformed), _ => Err(ParseError::Malformed(s.to_string())),
} }
} }
} }

View File

@ -12,11 +12,11 @@ pub mod stanza;
extern crate lazy_static; extern crate lazy_static;
pub use connection::Connection; pub use connection::Connection;
pub use error::JabberError; pub use error::Error;
pub use jabber::Jabber; pub use jabber::Jabber;
pub use jid::JID; pub use jid::JID;
pub type Result<T> = std::result::Result<T, JabberError>; pub type Result<T> = std::result::Result<T, Error>;
pub async fn login<J: TryInto<JID>, P: AsRef<str>>(jid: J, password: P) -> Result<Connection> { pub async fn login<J: TryInto<JID>, P: AsRef<str>>(jid: J, password: P) -> Result<Connection> {
todo!() todo!()

View File

@ -1,3 +1,5 @@
use peanuts::declaration::VersionInfo;
pub mod bind; pub mod bind;
pub mod iq; pub mod iq;
pub mod message; pub mod message;
@ -6,8 +8,4 @@ pub mod sasl;
pub mod starttls; pub mod starttls;
pub mod stream; pub mod stream;
use quick_xml::events::{BytesDecl, Event}; pub static XML_VERSION: VersionInfo = VersionInfo::One;
lazy_static! {
pub static ref DECLARATION: Event<'static> = Event::Decl(BytesDecl::new("1.0", None, None));
}

View File

@ -1,37 +1,141 @@
use serde::Serialize; use std::collections::{HashMap, HashSet};
use crate::JID; use peanuts::element::{Content, FromElement, IntoElement, NamespaceDeclaration};
use peanuts::XML_NS;
use peanuts::{element::Name, Element};
pub static XMLNS: &str = "http://etherx.jabber.org/streams"; use crate::{Error, JID};
pub static XMLNS_CLIENT: &str = "jabber:client";
pub const XMLNS: &str = "http://etherx.jabber.org/streams";
pub const XMLNS_CLIENT: &str = "jabber:client";
// MUST be qualified by stream namespace // MUST be qualified by stream namespace
#[derive(Serialize)] // #[derive(XmlSerialize, XmlDeserialize)]
pub struct Stream<'s> { // #[peanuts(xmlns = XMLNS)]
#[serde(rename = "@from")] pub struct Stream {
from: Option<&'s JID>, pub from: Option<JID>,
#[serde(rename = "@to")] to: Option<JID>,
to: Option<&'s JID>, id: Option<String>,
#[serde(rename = "@id")] version: Option<String>,
id: Option<&'s str>,
#[serde(rename = "@version")]
version: Option<&'s str>,
// TODO: lang enum // TODO: lang enum
#[serde(rename = "@lang")] lang: Option<String>,
lang: Option<&'s str>, // #[peanuts(content)]
#[serde(rename = "@xmlns")] // content: Message,
xmlns: &'s str,
#[serde(rename = "@xmlns:stream")]
xmlns_stream: &'s str,
} }
impl<'s> Stream<'s> { impl FromElement for Stream {
fn from_element(element: Element) -> peanuts::Result<Self> {
let Name {
namespace,
local_name,
} = element.name;
if namespace.as_deref() == Some(XMLNS) && &local_name == "stream" {
let (mut from, mut to, mut id, mut version, mut lang) = (None, None, None, None, None);
for (name, value) in element.attributes {
match (name.namespace.as_deref(), name.local_name.as_str()) {
(None, "from") => from = Some(value.try_into()?),
(None, "to") => to = Some(value.try_into()?),
(None, "id") => id = Some(value),
(None, "version") => version = Some(value),
(Some(XML_NS), "lang") => lang = Some(value),
_ => return Err(peanuts::Error::UnexpectedAttribute(name)),
}
}
return Ok(Stream {
from,
to,
id,
version,
lang,
});
} else {
return Err(peanuts::Error::IncorrectName(Name {
namespace,
local_name,
}));
}
}
}
impl IntoElement for Stream {
fn into_element(&self) -> Element {
let mut namespace_declarations = HashSet::new();
namespace_declarations.insert(NamespaceDeclaration {
prefix: Some("stream".to_string()),
namespace: XMLNS.to_string(),
});
namespace_declarations.insert(NamespaceDeclaration {
prefix: None,
// TODO: don't default to client
namespace: XMLNS_CLIENT.to_string(),
});
let mut attributes = HashMap::new();
self.from.as_ref().map(|from| {
attributes.insert(
Name {
namespace: None,
local_name: "from".to_string(),
},
from.to_string(),
);
});
self.to.as_ref().map(|to| {
attributes.insert(
Name {
namespace: None,
local_name: "to".to_string(),
},
to.to_string(),
);
});
self.id.as_ref().map(|id| {
attributes.insert(
Name {
namespace: None,
local_name: "version".to_string(),
},
id.clone(),
);
});
self.version.as_ref().map(|version| {
attributes.insert(
Name {
namespace: None,
local_name: "version".to_string(),
},
version.clone(),
);
});
self.lang.as_ref().map(|lang| {
attributes.insert(
Name {
namespace: Some(XML_NS.to_string()),
local_name: "lang".to_string(),
},
lang.to_string(),
);
});
Element {
name: Name {
namespace: Some(XMLNS.to_string()),
local_name: "stream".to_string(),
},
namespace_declarations,
attributes,
content: Vec::new(),
}
}
}
impl<'s> Stream {
pub fn new( pub fn new(
from: Option<&'s JID>, from: Option<JID>,
to: Option<&'s JID>, to: Option<JID>,
id: Option<&'s str>, id: Option<String>,
version: Option<&'s str>, version: Option<String>,
lang: Option<&'s str>, lang: Option<String>,
) -> Self { ) -> Self {
Self { Self {
from, from,
@ -39,27 +143,18 @@ impl<'s> Stream<'s> {
id, id,
version, version,
lang, lang,
xmlns: XMLNS_CLIENT,
xmlns_stream: XMLNS,
} }
} }
/// For initial stream headers, the initiating entity SHOULD include the 'xml:lang' attribute. /// For initial stream headers, the initiating entity SHOULD include the 'xml:lang' attribute.
/// For privacy, it is better to not set `from` when sending a client stanza over an unencrypted connection. /// For privacy, it is better to not set `from` when sending a client stanza over an unencrypted connection.
pub fn new_client( pub fn new_client(from: Option<JID>, to: JID, id: Option<String>, lang: String) -> Self {
from: Option<&'s JID>,
to: &'s JID,
id: Option<&'s str>,
lang: &'s str,
) -> Self {
Self { Self {
from, from,
to: Some(to), to: Some(to),
id, id,
version: Some("1.0"), version: Some("1.0".to_string()),
lang: Some(lang), lang: Some(lang),
xmlns: XMLNS_CLIENT,
xmlns_stream: XMLNS,
} }
} }
} }