return an error when attempting to read/write more than one root element in a document

This commit is contained in:
cel 🌸 2025-01-12 16:46:14 +00:00
parent 4f0691de7d
commit bbb1452905
5 changed files with 93 additions and 9 deletions

33
src/endable.rs Normal file
View File

@ -0,0 +1,33 @@
use crate::Error;
#[derive(Debug)]
pub struct Endable<T> {
inner: T,
ended: bool,
}
impl<T> Endable<T> {
pub fn new(inner: T) -> Self {
Self {
inner,
ended: false,
}
}
pub fn end(&mut self) {
self.ended = true;
}
pub fn into_inner(self) -> T {
self.inner
}
#[inline(always)]
pub fn try_as_mut(&mut self) -> Result<&mut T, Error> {
if self.ended {
Err(Error::RootElementEnded)
} else {
Ok(&mut self.inner)
}
}
}

View File

@ -24,6 +24,7 @@ pub enum DeserializeError {
} }
#[derive(Debug)] #[derive(Debug)]
// TODO: thiserror
pub enum Error { pub enum Error {
ReadError(std::io::Error), ReadError(std::io::Error),
Utf8Error(Utf8Error), Utf8Error(Utf8Error),
@ -41,6 +42,8 @@ pub enum Error {
IncorrectName(Name), IncorrectName(Name),
DeserializeError(String), DeserializeError(String),
Deserialize(DeserializeError), Deserialize(DeserializeError),
/// root element end tag already processed
RootElementEnded,
} }
impl From<DeserializeError> for Error { impl From<DeserializeError> for Error {

View File

@ -1,5 +1,6 @@
pub mod declaration; pub mod declaration;
pub mod element; pub mod element;
mod endable;
mod error; mod error;
pub mod reader; pub mod reader;
mod writer; mod writer;

View File

@ -30,6 +30,7 @@ pub struct Reader<R> {
// to have names reference namespaces could // to have names reference namespaces could
depth: Vec<Name>, depth: Vec<Name>,
namespace_declarations: Vec<HashSet<NamespaceDeclaration>>, namespace_declarations: Vec<HashSet<NamespaceDeclaration>>,
root_ended: bool,
} }
impl<R> Reader<R> { impl<R> Reader<R> {
@ -49,6 +50,7 @@ impl<R> Reader<R> {
depth: Vec::new(), depth: Vec::new(),
// TODO: make sure reserved namespaces are never overwritten // TODO: make sure reserved namespaces are never overwritten
namespace_declarations: vec![default_declarations], namespace_declarations: vec![default_declarations],
root_ended: false,
} }
} }
@ -66,6 +68,9 @@ where
} }
pub async fn read_prolog<'s>(&'s mut self) -> Result<Option<Declaration>> { pub async fn read_prolog<'s>(&'s mut self) -> Result<Option<Declaration>> {
if self.root_ended {
return Err(Error::RootElementEnded);
}
loop { loop {
let input = str::from_utf8(self.buffer.data())?; let input = str::from_utf8(self.buffer.data())?;
match xml::Prolog::parse(input) { match xml::Prolog::parse(input) {
@ -114,6 +119,9 @@ where
} }
pub async fn read_start_tag<'s>(&'s mut self) -> Result<Element> { pub async fn read_start_tag<'s>(&'s mut self) -> Result<Element> {
if self.root_ended {
return Err(Error::RootElementEnded);
}
loop { loop {
let input = str::from_utf8(self.buffer.data())?; let input = str::from_utf8(self.buffer.data())?;
match xml::STag::parse(input) { match xml::STag::parse(input) {
@ -140,6 +148,9 @@ where
} }
pub async fn read_end_tag<'s>(&'s mut self) -> Result<()> { pub async fn read_end_tag<'s>(&'s mut self) -> Result<()> {
if self.root_ended {
return Err(Error::RootElementEnded);
}
loop { loop {
let input = str::from_utf8(self.buffer.data())?; let input = str::from_utf8(self.buffer.data())?;
match xml::ETag::parse(input) { match xml::ETag::parse(input) {
@ -150,6 +161,9 @@ where
&mut self.namespace_declarations, &mut self.namespace_declarations,
e, e,
)?; )?;
if self.depth.is_empty() {
self.root_ended = true
}
self.buffer.consume(len); self.buffer.consume(len);
return Ok(()); return Ok(());
} }
@ -166,6 +180,9 @@ where
} }
pub async fn read_element<'s>(&'s mut self) -> Result<Element> { pub async fn read_element<'s>(&'s mut self) -> Result<Element> {
if self.root_ended {
return Err(Error::RootElementEnded);
}
loop { loop {
let input = str::from_utf8(self.buffer.data())?; let input = str::from_utf8(self.buffer.data())?;
match xml::Element::parse(input) { match xml::Element::parse(input) {
@ -173,6 +190,9 @@ where
let len = self.buffer.available_data() - rest.as_bytes().len(); let len = self.buffer.available_data() - rest.as_bytes().len();
let element = let element =
Reader::<R>::element_from_xml(&mut self.namespace_declarations, e)?; Reader::<R>::element_from_xml(&mut self.namespace_declarations, e)?;
if self.depth.is_empty() {
self.root_ended = true
}
self.buffer.consume(len); self.buffer.consume(len);
return Ok(element); return Ok(element);
} }
@ -189,6 +209,9 @@ where
} }
pub async fn read_content<'s>(&'s mut self) -> Result<Content> { pub async fn read_content<'s>(&'s mut self) -> Result<Content> {
if self.root_ended {
return Err(Error::RootElementEnded);
}
let mut last_char = false; let mut last_char = false;
let mut text = String::new(); let mut text = String::new();
loop { loop {
@ -217,6 +240,9 @@ where
&mut self.namespace_declarations, &mut self.namespace_declarations,
element, element,
)?; )?;
if self.depth.is_empty() {
self.root_ended = true
}
self.buffer.consume(len); self.buffer.consume(len);
return Ok(Content::Element(element)); return Ok(Content::Element(element));
} }
@ -279,6 +305,9 @@ where
&mut self.namespace_declarations, &mut self.namespace_declarations,
element, element,
)?; )?;
if self.depth.is_empty() {
self.root_ended = true
}
self.buffer.consume(len); self.buffer.consume(len);
return Ok(Content::Element(element)); return Ok(Content::Element(element));
} }
@ -722,7 +751,7 @@ impl<R> Reader<R> {
} }
#[cfg(test)] #[cfg(test)]
pub(crate) mod test { pub mod test {
use tokio::io::AsyncRead; use tokio::io::AsyncRead;
use super::Reader; use super::Reader;

View File

@ -13,6 +13,7 @@ use tokio::io::{AsyncWrite, AsyncWriteExt};
use crate::{ use crate::{
declaration::{Declaration, VersionInfo}, declaration::{Declaration, VersionInfo},
element::{escape_str, Content, Element, IntoContent, IntoElement, Name, NamespaceDeclaration}, element::{escape_str, Content, Element, IntoContent, IntoElement, Name, NamespaceDeclaration},
endable::Endable,
error::Error, error::Error,
xml::{self, composers::Composer, parsers_complete::Parser, ETag, XMLDecl}, xml::{self, composers::Composer, parsers_complete::Parser, ETag, XMLDecl},
Result, XMLNS_NS, XML_NS, Result, XMLNS_NS, XML_NS,
@ -21,7 +22,7 @@ use crate::{
// pub struct Writer<W, C = Composer> { // pub struct Writer<W, C = Composer> {
#[derive(Debug)] #[derive(Debug)]
pub struct Writer<W> { pub struct Writer<W> {
inner: W, inner: Endable<W>,
depth: Vec<Name>, depth: Vec<Name>,
namespace_declarations: Vec<HashSet<NamespaceDeclaration>>, namespace_declarations: Vec<HashSet<NamespaceDeclaration>>,
} }
@ -38,19 +39,20 @@ impl<W> Writer<W> {
namespace: XMLNS_NS.to_string(), namespace: XMLNS_NS.to_string(),
}); });
Self { Self {
inner: writer, inner: Endable::new(writer),
depth: Vec::new(), depth: Vec::new(),
namespace_declarations: vec![default_declarations], namespace_declarations: vec![default_declarations],
} }
} }
pub fn into_inner(self) -> W { pub fn into_inner(self) -> W {
self.inner self.inner.into_inner()
} }
} }
impl<W: AsyncWrite + Unpin + Send> Writer<W> { impl<W: AsyncWrite + Unpin + Send> Writer<W> {
pub async fn write_declaration(&mut self, version: VersionInfo) -> Result<()> { pub async fn write_declaration(&mut self, version: VersionInfo) -> Result<()> {
let writer = self.inner.try_as_mut()?;
let declaration = Declaration::version(version); let declaration = Declaration::version(version);
let version_info; let version_info;
match declaration.version_info { match declaration.version_info {
@ -64,7 +66,7 @@ impl<W: AsyncWrite + Unpin + Send> Writer<W> {
encoding_decl: None, encoding_decl: None,
sd_decl: None, sd_decl: None,
}; };
declaration.write(&mut self.inner).await?; declaration.write(writer).await?;
Ok(()) Ok(())
} }
@ -105,6 +107,7 @@ impl<W: AsyncWrite + Unpin + Send> Writer<W> {
} }
pub async fn write_empty(&mut self, element: &Element) -> Result<()> { pub async fn write_empty(&mut self, element: &Element) -> Result<()> {
let writer = self.inner.try_as_mut()?;
let mut namespace_declarations_stack: Vec<_> = self let mut namespace_declarations_stack: Vec<_> = self
.namespace_declarations .namespace_declarations
.iter() .iter()
@ -204,12 +207,17 @@ impl<W: AsyncWrite + Unpin + Send> Writer<W> {
let tag = xml::EmptyElemTag { name, attributes }; let tag = xml::EmptyElemTag { name, attributes };
tag.write(&mut self.inner).await?; tag.write(writer).await?;
if self.depth.is_empty() {
self.inner.end();
}
Ok(()) Ok(())
} }
pub async fn write_element_start(&mut self, element: &Element) -> Result<()> { pub async fn write_element_start(&mut self, element: &Element) -> Result<()> {
let writer = self.inner.try_as_mut()?;
let mut namespace_declarations_stack: Vec<_> = self let mut namespace_declarations_stack: Vec<_> = self
.namespace_declarations .namespace_declarations
.iter() .iter()
@ -309,7 +317,7 @@ impl<W: AsyncWrite + Unpin + Send> Writer<W> {
let s_tag = xml::STag { name, attributes }; let s_tag = xml::STag { name, attributes };
s_tag.write(&mut self.inner).await?; s_tag.write(writer).await?;
self.depth.push(element.name.clone()); self.depth.push(element.name.clone());
self.namespace_declarations self.namespace_declarations
@ -320,7 +328,12 @@ impl<W: AsyncWrite + Unpin + Send> Writer<W> {
pub async fn write_content(&mut self, content: &Content) -> Result<()> { pub async fn write_content(&mut self, content: &Content) -> Result<()> {
match content { match content {
Content::Element(element) => self.write_element(element).await?, Content::Element(element) => self.write_element(element).await?,
Content::Text(text) => self.inner.write_all(escape_str(text).as_bytes()).await?, Content::Text(text) => {
self.inner
.try_as_mut()?
.write_all(escape_str(text).as_bytes())
.await?
}
// TODO: comments and PI // TODO: comments and PI
Content::PI => {} Content::PI => {}
Content::Comment(_) => {} Content::Comment(_) => {}
@ -329,6 +342,7 @@ impl<W: AsyncWrite + Unpin + Send> Writer<W> {
} }
pub async fn write_end(&mut self) -> Result<()> { pub async fn write_end(&mut self) -> Result<()> {
let writer = self.inner.try_as_mut()?;
if let Some(name) = &self.depth.pop() { if let Some(name) = &self.depth.pop() {
let e_tag; let e_tag;
let namespace_declarations_stack: Vec<_> = let namespace_declarations_stack: Vec<_> =
@ -359,8 +373,12 @@ impl<W: AsyncWrite + Unpin + Send> Writer<W> {
)?), )?),
}; };
} }
e_tag.write(&mut self.inner).await?; e_tag.write(writer).await?;
self.namespace_declarations.pop(); self.namespace_declarations.pop();
if self.depth.is_empty() {
self.inner.end();
}
Ok(()) Ok(())
} else { } else {
return Err(Error::NotInElement("".to_string())); return Err(Error::NotInElement("".to_string()));