implement resource binding

This commit is contained in:
cel 🌸 2023-08-02 00:56:38 +01:00
parent 322b2a3b46
commit cd7bb95c0a
9 changed files with 387 additions and 54 deletions

View File

@ -9,6 +9,7 @@ edition = "2021"
[dependencies] [dependencies]
async-recursion = "1.0.4" async-recursion = "1.0.4"
async-trait = "0.1.68" async-trait = "0.1.68"
nanoid = "0.4.0"
quick-xml = { git = "https://github.com/tafia/quick-xml.git", features = ["async-tokio"] } quick-xml = { git = "https://github.com/tafia/quick-xml.git", features = ["async-tokio"] }
# TODO: remove unneeded features # TODO: remove unneeded features
rsasl = { version = "2", default_features = true, features = ["provider_base64", "plain", "config_builder"] } rsasl = { version = "2", default_features = true, features = ["provider_base64", "plain", "config_builder"] }

View File

@ -11,19 +11,22 @@ use tokio::net::TcpStream;
use tokio_native_tls::TlsStream; use tokio_native_tls::TlsStream;
use crate::stanza::{ use crate::stanza::{
sasl::{Auth, Response}, bind::Bind,
stream::{Stream, StreamFeature}, iq::IQ,
};
use crate::stanza::{
sasl::{Challenge, Success}, sasl::{Challenge, Success},
Element, Element,
}; };
use crate::stanza::{
sasl::{Auth, Response},
stream::{Stream, StreamFeature},
};
use crate::Jabber; use crate::Jabber;
use crate::JabberError;
use crate::Result; use crate::Result;
pub struct JabberClient<'j> { pub struct JabberClient<'j> {
reader: Reader<BufReader<ReadHalf<TlsStream<TcpStream>>>>, pub reader: Reader<BufReader<ReadHalf<TlsStream<TcpStream>>>>,
writer: Writer<WriteHalf<TlsStream<TcpStream>>>, pub writer: Writer<WriteHalf<TlsStream<TcpStream>>>,
jabber: &'j mut Jabber<'j>, jabber: &'j mut Jabber<'j>,
} }
@ -64,15 +67,19 @@ impl<'j> JabberClient<'j> {
pub async fn negotiate(&mut self) -> Result<()> { pub async fn negotiate(&mut self) -> Result<()> {
loop { loop {
println!("loop"); println!("negotiate loop");
let features = self.get_features().await?; let features = self.get_features().await?;
println!("features: {:?}", features); println!("features: {:?}", features);
match &features[0] { match &features[0] {
StreamFeature::Sasl(sasl) => { StreamFeature::Sasl(sasl) => {
println!("sasl?"); println!("sasl?");
self.sasl(&sasl).await?; self.sasl(&sasl).await?;
} }
StreamFeature::Bind => todo!(), StreamFeature::Bind => {
self.bind().await?;
return Ok(());
}
x => println!("{:?}", x), x => println!("{:?}", x),
} }
} }
@ -165,4 +172,36 @@ impl<'j> JabberClient<'j> {
self.start_stream().await?; self.start_stream().await?;
Ok(()) Ok(())
} }
pub async fn bind(&mut self) -> Result<()> {
match &self.jabber.jid.resourcepart {
Some(resource) => {
println!("setting resource");
let bind = Bind {
resource: Some(resource.clone()),
jid: None,
};
let result: Bind = IQ::set(self, None, None, bind).await?.try_into()?;
if let Some(jid) = result.jid {
println!("{}", jid);
self.jabber.jid = jid;
return Ok(());
}
}
None => {
println!("not setting resource");
let bind = Bind {
resource: None,
jid: None,
};
let result: Bind = IQ::set(self, None, None, bind).await?.try_into()?;
if let Some(jid) = result.jid {
println!("{}", jid);
self.jabber.jid = jid;
return Ok(());
}
}
}
Err(JabberError::BindError)
}
} }

View File

@ -17,8 +17,14 @@ pub enum JabberError {
Utf8Decode, Utf8Decode,
NoFeatures, NoFeatures,
UnknownNamespace, UnknownNamespace,
UnknownAttribute,
NoID,
NoType,
IDMismatch,
BindError,
ParseError, ParseError,
UnexpectedEnd, UnexpectedEnd,
UnexpectedElement,
XML(quick_xml::Error), XML(quick_xml::Error),
SASL(SASLError), SASL(SASLError),
Element(ElementError<'static>), Element(ElementError<'static>),

View File

@ -1,4 +1,5 @@
#![allow(unused_must_use)] #![allow(unused_must_use)]
#![feature(let_chains)]
// TODO: logging (dropped errors) // TODO: logging (dropped errors)
pub mod client; pub mod client;
@ -44,7 +45,10 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn login() { async fn login() {
Jabber::new(JID::from_str("test@blos.sm").unwrap(), "slayed".to_owned()) Jabber::new(
JID::from_str("test@blos.sm/clown").unwrap(),
"slayed".to_owned(),
)
.unwrap() .unwrap()
.login() .login()
.await .await

111
src/stanza/bind.rs Normal file
View File

@ -0,0 +1,111 @@
use quick_xml::{
events::{BytesStart, BytesText, Event},
name::QName,
Reader,
};
use super::{Element, IntoElement};
use crate::{JabberError, JID};
const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-bind";
pub struct Bind {
pub resource: Option<String>,
pub jid: Option<JID>,
}
impl<'e> IntoElement<'e> for Bind {
fn event(&self) -> quick_xml::events::Event<'static> {
let mut bind_event = BytesStart::new("bind");
bind_event.push_attribute(("xmlns", XMLNS));
if self.resource.is_none() && self.jid.is_none() {
return Event::Empty(bind_event);
} else {
return Event::Start(bind_event);
}
}
fn children(&self) -> Option<Vec<Element<'static>>> {
if let Some(resource) = &self.resource {
let resource_event: BytesStart<'static> = BytesStart::new("resource");
let resource_child: BytesText<'static> = BytesText::new(resource).into_owned();
let resource_child: Element<'static> = Element {
event: Event::Text(resource_child),
children: None,
};
let resource_element: Element<'static> = Element {
event: Event::Start(resource_event),
children: Some(vec![resource_child]),
};
return Some(vec![resource_element]);
} else if let Some(jid) = &self.jid {
let jid_event = BytesStart::new("jid");
let jid_child = BytesText::new(&jid.to_string()).into_owned();
let jid_child = Element {
event: Event::Text(jid_child),
children: None,
};
let jid_element = Element {
event: Event::Start(jid_event),
children: Some(vec![jid_child]),
};
return Some(vec![jid_element]);
}
None
}
}
impl TryFrom<Element<'static>> for Bind {
type Error = JabberError;
fn try_from(element: Element<'static>) -> Result<Self, Self::Error> {
if let Event::Start(start) = &element.event {
let buf: Vec<u8> = Vec::new();
let reader = Reader::from_reader(buf);
if start.name() == QName(b"bind")
&& start.try_get_attribute("xmlns")?.is_some_and(|attribute| {
attribute.decode_and_unescape_value(&reader).unwrap() == XMLNS
})
{
let child: Element<'static> = element.child()?.clone();
if let Event::Start(start) = &child.event {
match start.name() {
QName(b"resource") => {
let resource_text = child.child()?;
if let Event::Text(text) = &resource_text.event {
return Ok(Self {
resource: Some(text.unescape()?.into_owned()),
jid: None,
});
}
}
QName(b"jid") => {
let jid_text = child.child()?;
if let Event::Text(text) = &jid_text.event {
return Ok(Self {
jid: Some(text.unescape()?.into_owned().try_into()?),
resource: None,
});
}
}
_ => return Err(JabberError::UnexpectedElement),
}
}
}
} else if let Event::Empty(start) = &element.event {
let buf: Vec<u8> = Vec::new();
let reader = Reader::from_reader(buf);
if start.name() == QName(b"bind")
&& start.try_get_attribute("xmlns")?.is_some_and(|attribute| {
attribute.decode_and_unescape_value(&reader).unwrap() == XMLNS
})
{
return Ok(Bind {
resource: None,
jid: None,
});
}
}
Err(JabberError::UnexpectedElement)
}
}

171
src/stanza/iq.rs Normal file
View File

@ -0,0 +1,171 @@
use nanoid::nanoid;
use quick_xml::{
events::{BytesStart, Event},
name::QName,
Reader, Writer,
};
use crate::{JabberClient, JabberError, JID};
use super::{Element, IntoElement};
use crate::Result;
#[derive(Debug)]
pub struct IQ {
to: Option<JID>,
from: Option<JID>,
id: String,
r#type: IQType,
lang: Option<String>,
child: Element<'static>,
}
#[derive(Debug)]
enum IQType {
Get,
Set,
Result,
Error,
}
impl IQ {
pub async fn set<'j, R: IntoElement<'static>>(
client: &mut JabberClient<'j>,
to: Option<JID>,
from: Option<JID>,
element: R,
) -> Result<Element<'static>> {
let id = nanoid!();
let iq = IQ {
to,
from,
id: id.clone(),
r#type: IQType::Set,
lang: None,
child: Element::from(element),
};
println!("{:?}", iq);
let iq = Element::from(iq);
println!("{:?}", iq);
iq.write(&mut client.writer).await?;
let result = Element::read(&mut client.reader).await?;
let iq = IQ::try_from(result)?;
if iq.id == id {
return Ok(iq.child);
}
Err(JabberError::IDMismatch)
}
}
impl<'e> IntoElement<'e> for IQ {
fn event(&self) -> quick_xml::events::Event<'e> {
let mut start = BytesStart::new("iq");
if let Some(to) = &self.to {
start.push_attribute(("to", to.to_string().as_str()));
}
if let Some(from) = &self.from {
start.push_attribute(("from", from.to_string().as_str()));
}
start.push_attribute(("id", self.id.as_str()));
match self.r#type {
IQType::Get => start.push_attribute(("type", "get")),
IQType::Set => start.push_attribute(("type", "set")),
IQType::Result => start.push_attribute(("type", "result")),
IQType::Error => start.push_attribute(("type", "error")),
}
if let Some(lang) = &self.lang {
start.push_attribute(("from", lang.to_string().as_str()));
}
quick_xml::events::Event::Start(start)
}
fn children(&self) -> Option<Vec<Element<'e>>> {
Some(vec![self.child.clone()])
}
}
impl TryFrom<Element<'static>> for IQ {
type Error = JabberError;
fn try_from(element: Element<'static>) -> std::result::Result<Self, Self::Error> {
if let Event::Start(start) = &element.event {
if start.name() == QName(b"iq") {
let mut to: Option<JID> = None;
let mut from: Option<JID> = None;
let mut id = None;
let mut r#type = None;
let mut lang = None;
start
.attributes()
.into_iter()
.try_for_each(|attribute| -> Result<()> {
if let Ok(attribute) = attribute {
let buf: Vec<u8> = Vec::new();
let reader = Reader::from_reader(buf);
match attribute.key {
QName(b"to") => {
to = Some(
attribute
.decode_and_unescape_value(&reader)
.or(Err(JabberError::Utf8Decode))?
.into_owned()
.try_into()?,
)
}
QName(b"from") => {
from = Some(
attribute
.decode_and_unescape_value(&reader)
.or(Err(JabberError::Utf8Decode))?
.into_owned()
.try_into()?,
)
}
QName(b"id") => {
id = Some(
attribute
.decode_and_unescape_value(&reader)
.or(Err(JabberError::Utf8Decode))?
.into_owned(),
)
}
QName(b"type") => {
let value = attribute
.decode_and_unescape_value(&reader)
.or(Err(JabberError::Utf8Decode))?;
match value.as_ref() {
"get" => r#type = Some(IQType::Get),
"set" => r#type = Some(IQType::Set),
"result" => r#type = Some(IQType::Result),
"error" => r#type = Some(IQType::Error),
_ => return Err(JabberError::ParseError),
}
}
QName(b"lang") => {
lang = Some(
attribute
.decode_and_unescape_value(&reader)
.or(Err(JabberError::Utf8Decode))?
.into_owned(),
)
}
_ => return Err(JabberError::UnknownAttribute),
}
}
Ok(())
})?;
let iq = IQ {
to,
from,
id: id.ok_or(JabberError::NoID)?,
r#type: r#type.ok_or(JabberError::NoType)?,
lang,
child: element.child()?.to_owned(),
};
return Ok(iq);
}
}
Err(JabberError::ParseError)
}
}

View File

@ -1,5 +1,7 @@
// use quick_xml::events::BytesDecl; // use quick_xml::events::BytesDecl;
pub mod bind;
pub mod iq;
pub mod sasl; pub mod sasl;
pub mod stream; pub mod stream;
@ -128,11 +130,10 @@ impl<'e> Element<'e> {
e => Err(ElementError::NotAStart(e.into_owned()).into()), e => Err(ElementError::NotAStart(e.into_owned()).into()),
} }
} }
}
/// if there is only one child in the vec of children, will return that element /// if there is only one child in the vec of children, will return that element
pub fn child<'p, 'e>(element: &'p Element<'e>) -> Result<&'p Element<'e>, ElementError<'static>> { pub fn child<'p>(&'p self) -> Result<&'p Element<'e>, ElementError<'static>> {
if let Some(children) = &element.children { if let Some(children) = &self.children {
if children.len() == 1 { if children.len() == 1 {
return Ok(&children[0]); return Ok(&children[0]);
} else { } else {
@ -143,14 +144,27 @@ pub fn child<'p, 'e>(element: &'p Element<'e>) -> Result<&'p Element<'e>, Elemen
} }
/// returns reference to children /// returns reference to children
pub fn children<'p, 'e>( pub fn children<'p>(&'p self) -> Result<&'p Vec<Element<'e>>, ElementError<'e>> {
element: &'p Element<'e>, if let Some(children) = &self.children {
) -> Result<&'p Vec<Element<'e>>, ElementError<'e>> {
if let Some(children) = &element.children {
return Ok(children); return Ok(children);
} }
Err(ElementError::NoChildren) Err(ElementError::NoChildren)
} }
}
pub trait IntoElement<'e> {
fn event(&self) -> Event<'e>;
fn children(&self) -> Option<Vec<Element<'e>>>;
}
impl<'e, T: IntoElement<'e>> From<T> for Element<'e> {
fn from(value: T) -> Self {
Element {
event: value.event(),
children: value.children(),
}
}
}
#[derive(Debug)] #[derive(Debug)]
pub enum ElementError<'e> { pub enum ElementError<'e> {

View File

@ -7,6 +7,7 @@ use crate::error::SASLError;
use crate::JabberError; use crate::JabberError;
use super::Element; use super::Element;
use super::IntoElement;
const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-sasl"; const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
@ -16,7 +17,7 @@ pub struct Auth<'e> {
pub sasl_data: &'e str, pub sasl_data: &'e str,
} }
impl<'e> Auth<'e> { impl<'e> IntoElement<'e> for Auth<'e> {
fn event(&self) -> Event<'e> { fn event(&self) -> Event<'e> {
let mut start = BytesStart::new("auth"); let mut start = BytesStart::new("auth");
start.push_attribute(("xmlns", XMLNS)); start.push_attribute(("xmlns", XMLNS));
@ -34,15 +35,6 @@ impl<'e> Auth<'e> {
} }
} }
impl<'e> Into<Element<'e>> for Auth<'e> {
fn into(self) -> Element<'e> {
Element {
event: self.event(),
children: self.children(),
}
}
}
#[derive(Debug)] #[derive(Debug)]
pub struct Challenge { pub struct Challenge {
pub sasl_data: Vec<u8>, pub sasl_data: Vec<u8>,
@ -54,7 +46,7 @@ impl<'e> TryFrom<&Element<'e>> for Challenge {
fn try_from(element: &Element<'e>) -> Result<Challenge, Self::Error> { fn try_from(element: &Element<'e>) -> Result<Challenge, Self::Error> {
if let Event::Start(start) = &element.event { if let Event::Start(start) = &element.event {
if start.name() == QName(b"challenge") { if start.name() == QName(b"challenge") {
let sasl_data: &Element<'_> = super::child(element)?; let sasl_data: &Element<'_> = element.child()?;
if let Event::Text(sasl_data) = &sasl_data.event { if let Event::Text(sasl_data) = &sasl_data.event {
let s = sasl_data.clone(); let s = sasl_data.clone();
let s = s.into_inner(); let s = s.into_inner();
@ -101,7 +93,7 @@ pub struct Response<'e> {
pub sasl_data: &'e str, pub sasl_data: &'e str,
} }
impl<'e> Response<'e> { impl<'e> IntoElement<'e> for Response<'e> {
fn event(&self) -> Event<'e> { fn event(&self) -> Event<'e> {
let mut start = BytesStart::new("response"); let mut start = BytesStart::new("response");
start.push_attribute(("xmlns", XMLNS)); start.push_attribute(("xmlns", XMLNS));
@ -118,15 +110,6 @@ impl<'e> Response<'e> {
} }
} }
impl<'e> Into<Element<'e>> for Response<'e> {
fn into(self) -> Element<'e> {
Element {
event: self.event(),
children: self.children(),
}
}
}
#[derive(Debug)] #[derive(Debug)]
pub struct Success { pub struct Success {
pub sasl_data: Option<Vec<u8>>, pub sasl_data: Option<Vec<u8>>,
@ -139,7 +122,7 @@ impl<'e> TryFrom<&Element<'e>> for Success {
match &element.event { match &element.event {
Event::Start(start) => { Event::Start(start) => {
if start.name() == QName(b"success") { if start.name() == QName(b"success") {
match super::child(element) { match element.child() {
Ok(sasl_data) => { Ok(sasl_data) => {
if let Event::Text(sasl_data) = &sasl_data.event { if let Event::Text(sasl_data) = &sasl_data.event {
return Ok(Success { return Ok(Success {

View File

@ -175,7 +175,11 @@ impl<'e> TryFrom<Element<'e>> for Vec<StreamFeature> {
} }
features.push(StreamFeature::Sasl(mechanisms)) features.push(StreamFeature::Sasl(mechanisms))
} }
_ => {} _ => features.push(StreamFeature::Unknown),
},
Event::Empty(e) => match e.name() {
QName(b"bind") => features.push(StreamFeature::Bind),
_ => features.push(StreamFeature::Unknown),
}, },
_ => features.push(StreamFeature::Unknown), _ => features.push(StreamFeature::Unknown),
} }