diff --git a/src/stanza/starttls.rs b/src/stanza/starttls.rs index ee35bf5..33721ab 100644 --- a/src/stanza/starttls.rs +++ b/src/stanza/starttls.rs @@ -7,76 +7,48 @@ use peanuts::{ pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-tls"; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct StartTls { pub required: bool, } impl IntoElement for StartTls { - fn into_element(&self) -> peanuts::Element { - let content; - if self.required == true { - let element = Content::Element(Element { - name: Name { - namespace: Some(XMLNS.to_string()), - local_name: "required".to_string(), - }, - namespace_declaration_overrides: HashSet::new(), - attributes: HashMap::new(), - content: Vec::new(), - }); - content = vec![element]; - } else { - content = Vec::new(); - } - Element { - name: Name { - namespace: Some(XMLNS.to_string()), - local_name: "starttls".to_string(), - }, - namespace_declaration_overrides: HashSet::new(), - attributes: HashMap::new(), - content, + fn builder(&self) -> peanuts::element::ElementBuilder { + let mut builder = Element::builder("starttls", Some(XMLNS)); + + if self.required { + builder = builder.push_child(Element::builder("required", Some(XMLNS))) } + + builder } } impl FromElement for StartTls { - fn from_element(element: peanuts::Element) -> peanuts::Result { - let Name { - namespace, - local_name, - } = element.name; - if namespace.as_deref() == Some(XMLNS) && &local_name == "starttls" { - let mut required = false; - if element.content.len() == 1 { - match element.content.first().unwrap() { - Content::Element(element) => { - let Name { - namespace, - local_name, - } = &element.name; + fn from_element( + mut element: peanuts::Element, + ) -> std::result::Result { + element.check_name("starttls")?; + element.check_namespace(XMLNS)?; - if namespace.as_deref() == Some(XMLNS) && local_name == "required" { - required = true - } else { - return Err(peanuts::Error::UnexpectedElement(element.name.clone())); - } - } - c => return Err(peanuts::Error::UnexpectedContent((*c).clone())), - } - } else { - return Err(peanuts::Error::UnexpectedNumberOfContents( - element.content.len(), - )); - } - return Ok(StartTls { required }); - } else { - return Err(peanuts::Error::IncorrectName(Name { - namespace, - local_name, - })); + let mut required = false; + if let Some(_) = element.child_opt::()? { + required = true; } + + Ok(StartTls { required }) + } +} + +#[derive(Debug)] +pub struct Required; + +impl FromElement for Required { + fn from_element(element: Element) -> peanuts::element::DeserializeResult { + element.check_name("required")?; + element.check_namespace(XMLNS)?; + + Ok(Required) } } @@ -84,65 +56,33 @@ impl FromElement for StartTls { pub struct Proceed; impl IntoElement for Proceed { - fn into_element(&self) -> Element { - Element { - name: Name { - namespace: Some(XMLNS.to_string()), - local_name: "proceed".to_string(), - }, - namespace_declaration_overrides: HashSet::new(), - attributes: HashMap::new(), - content: Vec::new(), - } + fn builder(&self) -> peanuts::element::ElementBuilder { + Element::builder("proceed", Some(XMLNS)) } } impl FromElement for Proceed { - fn from_element(element: Element) -> peanuts::Result { - let Name { - namespace, - local_name, - } = element.name; - if namespace.as_deref() == Some(XMLNS) && &local_name == "proceed" { - return Ok(Proceed); - } else { - return Err(peanuts::Error::IncorrectName(Name { - namespace, - local_name, - })); - } + fn from_element(element: Element) -> peanuts::element::DeserializeResult { + element.check_name("proceed")?; + element.check_namespace(XMLNS)?; + + Ok(Proceed) } } pub struct Failure; impl IntoElement for Failure { - fn into_element(&self) -> Element { - Element { - name: Name { - namespace: Some(XMLNS.to_string()), - local_name: "failure".to_string(), - }, - namespace_declaration_overrides: HashSet::new(), - attributes: HashMap::new(), - content: Vec::new(), - } + fn builder(&self) -> peanuts::element::ElementBuilder { + Element::builder("failure", Some(XMLNS)) } } impl FromElement for Failure { - fn from_element(element: Element) -> peanuts::Result { - let Name { - namespace, - local_name, - } = element.name; - if namespace.as_deref() == Some(XMLNS) && &local_name == "failure" { - return Ok(Failure); - } else { - return Err(peanuts::Error::IncorrectName(Name { - namespace, - local_name, - })); - } + fn from_element(element: Element) -> peanuts::element::DeserializeResult { + element.check_name("failure")?; + element.check_namespace(XMLNS)?; + + Ok(Failure) } } diff --git a/src/stanza/stream.rs b/src/stanza/stream.rs index a5201dc..40f6ba0 100644 --- a/src/stanza/stream.rs +++ b/src/stanza/stream.rs @@ -1,12 +1,12 @@ use std::collections::{HashMap, HashSet}; -use peanuts::element::{Content, FromElement, IntoElement, NamespaceDeclaration}; +use peanuts::element::{Content, ElementBuilder, FromElement, IntoElement, NamespaceDeclaration}; use peanuts::XML_NS; use peanuts::{element::Name, Element}; use crate::{Error, JID}; -use super::starttls::StartTls; +use super::starttls::{self, StartTls}; pub const XMLNS: &str = "http://etherx.jabber.org/streams"; pub const XMLNS_CLIENT: &str = "jabber:client"; @@ -27,108 +27,36 @@ pub struct Stream { } impl FromElement for Stream { - fn from_element(element: Element) -> peanuts::Result { - let Name { - namespace, - local_name, - } = element.name; - if namespace.as_deref() == Some(XMLNS) && &local_name == "stream" { - let (mut from, mut to, mut id, mut version, mut lang) = (None, None, None, None, None); - for (name, value) in element.attributes { - match (name.namespace.as_deref(), name.local_name.as_str()) { - (None, "from") => from = Some(value.try_into()?), - (None, "to") => to = Some(value.try_into()?), - (None, "id") => id = Some(value), - (None, "version") => version = Some(value), - (Some(XML_NS), "lang") => lang = Some(value), - _ => return Err(peanuts::Error::UnexpectedAttribute(name)), - } - } - return Ok(Stream { - from, - to, - id, - version, - lang, - }); - } else { - return Err(peanuts::Error::IncorrectName(Name { - namespace, - local_name, - })); - } + fn from_element(mut element: Element) -> std::result::Result { + element.check_namespace(XMLNS)?; + element.check_name("stream")?; + + let from = element.attribute_opt("from")?; + let to = element.attribute_opt("to")?; + let id = element.attribute_opt("id")?; + let version = element.attribute_opt("version")?; + let lang = element.attribute_opt_namespaced("lang", peanuts::XML_NS)?; + + Ok(Stream { + from, + to, + id, + version, + lang, + }) } } impl IntoElement for Stream { - fn into_element(&self) -> Element { - let mut namespace_declaration_overrides = HashSet::new(); - namespace_declaration_overrides.insert(NamespaceDeclaration { - prefix: Some("stream".to_string()), - namespace: XMLNS.to_string(), - }); - namespace_declaration_overrides.insert(NamespaceDeclaration { - prefix: None, - // TODO: don't default to client - namespace: XMLNS_CLIENT.to_string(), - }); - - let mut attributes = HashMap::new(); - self.from.as_ref().map(|from| { - attributes.insert( - Name { - namespace: None, - local_name: "from".to_string(), - }, - from.to_string(), - ); - }); - self.to.as_ref().map(|to| { - attributes.insert( - Name { - namespace: None, - local_name: "to".to_string(), - }, - to.to_string(), - ); - }); - self.id.as_ref().map(|id| { - attributes.insert( - Name { - namespace: None, - local_name: "id".to_string(), - }, - id.clone(), - ); - }); - self.version.as_ref().map(|version| { - attributes.insert( - Name { - namespace: None, - local_name: "version".to_string(), - }, - version.clone(), - ); - }); - self.lang.as_ref().map(|lang| { - attributes.insert( - Name { - namespace: Some(XML_NS.to_string()), - local_name: "lang".to_string(), - }, - lang.to_string(), - ); - }); - - Element { - name: Name { - namespace: Some(XMLNS.to_string()), - local_name: "stream".to_string(), - }, - namespace_declaration_overrides, - attributes, - content: Vec::new(), - } + fn builder(&self) -> ElementBuilder { + Element::builder("stream", Some(XMLNS.to_string())) + .push_namespace_declaration_override(Some("stream"), XMLNS) + .push_namespace_declaration_override(None::<&str>, XMLNS_CLIENT) + .push_attribute_opt("to", self.to.clone()) + .push_attribute_opt("from", self.from.clone()) + .push_attribute_opt("id", self.id.clone()) + .push_attribute_opt("version", self.version.clone()) + .push_attribute_opt_namespaced(peanuts::XML_NS, "to", self.lang.clone()) } } @@ -168,64 +96,70 @@ pub struct Features { } impl IntoElement for Features { - fn into_element(&self) -> Element { - let mut content = Vec::new(); - for feature in &self.features { - match feature { - Feature::StartTls(start_tls) => { - content.push(Content::Element(start_tls.into_element())) - } - Feature::Sasl => {} - Feature::Bind => {} - Feature::Unknown => {} - } - } - Element { - name: Name { - namespace: Some(XMLNS.to_string()), - local_name: "features".to_string(), - }, - namespace_declaration_overrides: HashSet::new(), - attributes: HashMap::new(), - content, - } + fn builder(&self) -> ElementBuilder { + Element::builder("features", Some(XMLNS)).push_children(self.features.clone()) + // let mut content = Vec::new(); + // for feature in &self.features { + // match feature { + // Feature::StartTls(start_tls) => { + // content.push(Content::Element(start_tls.into_element())) + // } + // Feature::Sasl => {} + // Feature::Bind => {} + // Feature::Unknown => {} + // } + // } + // Element { + // name: Name { + // namespace: Some(XMLNS.to_string()), + // local_name: "features".to_string(), + // }, + // namespace_declaration_overrides: HashSet::new(), + // attributes: HashMap::new(), + // content, + // } } } impl FromElement for Features { - fn from_element(element: Element) -> peanuts::Result { - let Name { - namespace, - local_name, - } = element.name; - if namespace.as_deref() == Some(XMLNS) && &local_name == "features" { - let mut features = Vec::new(); - for feature in element.content { - match feature { - Content::Element(element) => { - if let Ok(start_tls) = FromElement::from_element(element) { - features.push(Feature::StartTls(start_tls)) - } else { - features.push(Feature::Unknown) - } - } - c => return Err(peanuts::Error::UnexpectedContent(c.clone())), - } - } - return Ok(Self { features }); - } else { - return Err(peanuts::Error::IncorrectName(Name { - namespace, - local_name, - })); - } + fn from_element( + mut element: Element, + ) -> std::result::Result { + element.check_namespace(XMLNS)?; + element.check_name("features")?; + + let features = element.children()?; + + Ok(Features { features }) } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum Feature { StartTls(StartTls), Sasl, Bind, Unknown, } + +impl IntoElement for Feature { + fn builder(&self) -> ElementBuilder { + match self { + Feature::StartTls(start_tls) => start_tls.builder(), + Feature::Sasl => todo!(), + Feature::Bind => todo!(), + Feature::Unknown => todo!(), + } + } +} + +impl FromElement for Feature { + fn from_element(element: Element) -> peanuts::element::DeserializeResult { + match element.identify() { + (Some(starttls::XMLNS), "starttls") => { + Ok(Feature::StartTls(StartTls::from_element(element)?)) + } + _ => Ok(Feature::Unknown), + } + } +}