diff --git a/src/endable.rs b/src/endable.rs new file mode 100644 index 0000000..6006080 --- /dev/null +++ b/src/endable.rs @@ -0,0 +1,33 @@ +use crate::Error; + +#[derive(Debug)] +pub struct Endable { + inner: T, + ended: bool, +} + +impl Endable { + pub fn new(inner: T) -> Self { + Self { + inner, + ended: false, + } + } + + pub fn end(&mut self) { + self.ended = true; + } + + pub fn into_inner(self) -> T { + self.inner + } + + #[inline(always)] + pub fn try_as_mut(&mut self) -> Result<&mut T, Error> { + if self.ended { + Err(Error::RootElementEnded) + } else { + Ok(&mut self.inner) + } + } +} diff --git a/src/error.rs b/src/error.rs index fff59d4..cf01895 100644 --- a/src/error.rs +++ b/src/error.rs @@ -24,6 +24,7 @@ pub enum DeserializeError { } #[derive(Debug)] +// TODO: thiserror pub enum Error { ReadError(std::io::Error), Utf8Error(Utf8Error), @@ -41,6 +42,8 @@ pub enum Error { IncorrectName(Name), DeserializeError(String), Deserialize(DeserializeError), + /// root element end tag already processed + RootElementEnded, } impl From for Error { diff --git a/src/lib.rs b/src/lib.rs index 30e6051..26a3f78 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ pub mod declaration; pub mod element; +mod endable; mod error; pub mod reader; mod writer; diff --git a/src/reader.rs b/src/reader.rs index 074ab99..24cc098 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -30,6 +30,7 @@ pub struct Reader { // to have names reference namespaces could depth: Vec, namespace_declarations: Vec>, + root_ended: bool, } impl Reader { @@ -49,6 +50,7 @@ impl Reader { depth: Vec::new(), // TODO: make sure reserved namespaces are never overwritten namespace_declarations: vec![default_declarations], + root_ended: false, } } @@ -66,6 +68,9 @@ where } pub async fn read_prolog<'s>(&'s mut self) -> Result> { + if self.root_ended { + return Err(Error::RootElementEnded); + } loop { let input = str::from_utf8(self.buffer.data())?; match xml::Prolog::parse(input) { @@ -114,6 +119,9 @@ where } pub async fn read_start_tag<'s>(&'s mut self) -> Result { + if self.root_ended { + return Err(Error::RootElementEnded); + } loop { let input = str::from_utf8(self.buffer.data())?; match xml::STag::parse(input) { @@ -140,6 +148,9 @@ where } pub async fn read_end_tag<'s>(&'s mut self) -> Result<()> { + if self.root_ended { + return Err(Error::RootElementEnded); + } loop { let input = str::from_utf8(self.buffer.data())?; match xml::ETag::parse(input) { @@ -150,6 +161,9 @@ where &mut self.namespace_declarations, e, )?; + if self.depth.is_empty() { + self.root_ended = true + } self.buffer.consume(len); return Ok(()); } @@ -166,6 +180,9 @@ where } pub async fn read_element<'s>(&'s mut self) -> Result { + if self.root_ended { + return Err(Error::RootElementEnded); + } loop { let input = str::from_utf8(self.buffer.data())?; match xml::Element::parse(input) { @@ -173,6 +190,9 @@ where let len = self.buffer.available_data() - rest.as_bytes().len(); let element = Reader::::element_from_xml(&mut self.namespace_declarations, e)?; + if self.depth.is_empty() { + self.root_ended = true + } self.buffer.consume(len); return Ok(element); } @@ -189,6 +209,9 @@ where } pub async fn read_content<'s>(&'s mut self) -> Result { + if self.root_ended { + return Err(Error::RootElementEnded); + } let mut last_char = false; let mut text = String::new(); loop { @@ -217,6 +240,9 @@ where &mut self.namespace_declarations, element, )?; + if self.depth.is_empty() { + self.root_ended = true + } self.buffer.consume(len); return Ok(Content::Element(element)); } @@ -279,6 +305,9 @@ where &mut self.namespace_declarations, element, )?; + if self.depth.is_empty() { + self.root_ended = true + } self.buffer.consume(len); return Ok(Content::Element(element)); } @@ -722,7 +751,7 @@ impl Reader { } #[cfg(test)] -pub(crate) mod test { +pub mod test { use tokio::io::AsyncRead; use super::Reader; diff --git a/src/writer.rs b/src/writer.rs index 013d37b..f622bcf 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -13,6 +13,7 @@ use tokio::io::{AsyncWrite, AsyncWriteExt}; use crate::{ declaration::{Declaration, VersionInfo}, element::{escape_str, Content, Element, IntoContent, IntoElement, Name, NamespaceDeclaration}, + endable::Endable, error::Error, xml::{self, composers::Composer, parsers_complete::Parser, ETag, XMLDecl}, Result, XMLNS_NS, XML_NS, @@ -21,7 +22,7 @@ use crate::{ // pub struct Writer { #[derive(Debug)] pub struct Writer { - inner: W, + inner: Endable, depth: Vec, namespace_declarations: Vec>, } @@ -38,19 +39,20 @@ impl Writer { namespace: XMLNS_NS.to_string(), }); Self { - inner: writer, + inner: Endable::new(writer), depth: Vec::new(), namespace_declarations: vec![default_declarations], } } pub fn into_inner(self) -> W { - self.inner + self.inner.into_inner() } } impl Writer { pub async fn write_declaration(&mut self, version: VersionInfo) -> Result<()> { + let writer = self.inner.try_as_mut()?; let declaration = Declaration::version(version); let version_info; match declaration.version_info { @@ -64,7 +66,7 @@ impl Writer { encoding_decl: None, sd_decl: None, }; - declaration.write(&mut self.inner).await?; + declaration.write(writer).await?; Ok(()) } @@ -105,6 +107,7 @@ impl Writer { } pub async fn write_empty(&mut self, element: &Element) -> Result<()> { + let writer = self.inner.try_as_mut()?; let mut namespace_declarations_stack: Vec<_> = self .namespace_declarations .iter() @@ -204,12 +207,17 @@ impl Writer { let tag = xml::EmptyElemTag { name, attributes }; - tag.write(&mut self.inner).await?; + tag.write(writer).await?; + + if self.depth.is_empty() { + self.inner.end(); + } Ok(()) } pub async fn write_element_start(&mut self, element: &Element) -> Result<()> { + let writer = self.inner.try_as_mut()?; let mut namespace_declarations_stack: Vec<_> = self .namespace_declarations .iter() @@ -309,7 +317,7 @@ impl Writer { let s_tag = xml::STag { name, attributes }; - s_tag.write(&mut self.inner).await?; + s_tag.write(writer).await?; self.depth.push(element.name.clone()); self.namespace_declarations @@ -320,7 +328,12 @@ impl Writer { pub async fn write_content(&mut self, content: &Content) -> Result<()> { match content { Content::Element(element) => self.write_element(element).await?, - Content::Text(text) => self.inner.write_all(escape_str(text).as_bytes()).await?, + Content::Text(text) => { + self.inner + .try_as_mut()? + .write_all(escape_str(text).as_bytes()) + .await? + } // TODO: comments and PI Content::PI => {} Content::Comment(_) => {} @@ -329,6 +342,7 @@ impl Writer { } pub async fn write_end(&mut self) -> Result<()> { + let writer = self.inner.try_as_mut()?; if let Some(name) = &self.depth.pop() { let e_tag; let namespace_declarations_stack: Vec<_> = @@ -359,8 +373,12 @@ impl Writer { )?), }; } - e_tag.write(&mut self.inner).await?; + e_tag.write(writer).await?; self.namespace_declarations.pop(); + + if self.depth.is_empty() { + self.inner.end(); + } Ok(()) } else { return Err(Error::NotInElement("".to_string()));