From 595d165479b8b12e456f39205d8433b822b07487 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?cel=20=F0=9F=8C=B8?= Date: Fri, 6 Dec 2024 06:31:20 +0000 Subject: [PATCH] implement sink and stream properly UNFOLD UNFOLD --- README.md | 2 + jabber/Cargo.toml | 3 + jabber/src/client.rs | 24 ++-- jabber/src/error.rs | 1 + jabber/src/jabber_stream.rs | 33 +++-- jabber/src/jabber_stream/bound_stream.rs | 153 +++++++++++++++++++++++ 6 files changed, 181 insertions(+), 35 deletions(-) create mode 100644 jabber/src/jabber_stream/bound_stream.rs diff --git a/README.md b/README.md index 44632dd..63094ae 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,7 @@ ## TODO: +- [ ] how to know if stanza has been sent - [ ] error states for all negotiation parts - [ ] better errors - [x] rename structs @@ -74,6 +75,7 @@ need more research: - [ ] message editing - [ ] xep-0308: last message correction (should not be used for older than last message according to spec) - [ ] chat read markers + - [ ] xep-0490: message displayed synchronization - [ ] xep-0333: displayed markers - [ ] message styling - [ ] xep-0393: message styling diff --git a/jabber/Cargo.toml b/jabber/Cargo.toml index 4753e59..68dddd9 100644 --- a/jabber/Cargo.toml +++ b/jabber/Cargo.toml @@ -22,6 +22,9 @@ stanza = { version = "0.1.0", path = "../stanza" } peanuts = { version = "0.1.0", path = "../../peanuts" } jid = { version = "0.1.0", path = "../jid" } futures = "0.3.31" +take_mut = "0.2.2" +pin-project-lite = "0.2.15" +pin-project = "1.1.7" [dev-dependencies] test-log = { version = "0.2", features = ["trace"] } diff --git a/jabber/src/client.rs b/jabber/src/client.rs index c8b0b73..c6cab07 100644 --- a/jabber/src/client.rs +++ b/jabber/src/client.rs @@ -56,6 +56,14 @@ impl JabberClient { } } + pub(crate) fn inner(self) -> Result> { + match self.connection { + ConnectionState::Disconnected => return Err(Error::Disconnected), + ConnectionState::Connecting(_connecting) => return Err(Error::Connecting), + ConnectionState::Connected(jabber_stream) => return Ok(jabber_stream), + } + } + pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> { match &mut self.connection { ConnectionState::Disconnected => return Err(Error::Disconnected), @@ -67,22 +75,6 @@ impl JabberClient { } } -impl Stream for JabberClient { - type Item = Result; - - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let mut client = pin!(self); - match &mut client.connection { - ConnectionState::Disconnected => Poll::Pending, - ConnectionState::Connecting(_connecting) => Poll::Pending, - ConnectionState::Connected(jabber_stream) => jabber_stream.poll_next_unpin(cx), - } - } -} - pub enum ConnectionState { Disconnected, Connecting(Connecting), diff --git a/jabber/src/error.rs b/jabber/src/error.rs index aad033c..6671fe6 100644 --- a/jabber/src/error.rs +++ b/jabber/src/error.rs @@ -16,6 +16,7 @@ pub enum Error { Unsupported, NoLocalpart, AlreadyConnecting, + StreamClosed, UnexpectedElement(peanuts::Element), XML(peanuts::Error), Deserialization(peanuts::DeserializeError), diff --git a/jabber/src/jabber_stream.rs b/jabber/src/jabber_stream.rs index dd0dcbf..d981f8f 100644 --- a/jabber/src/jabber_stream.rs +++ b/jabber/src/jabber_stream.rs @@ -2,7 +2,7 @@ use std::pin::pin; use std::str::{self, FromStr}; use std::sync::Arc; -use futures::StreamExt; +use futures::{sink, stream, StreamExt}; use jid::JID; use peanuts::element::{FromContent, IntoElement}; use peanuts::{Reader, Writer}; @@ -22,28 +22,14 @@ use crate::connection::{Tls, Unencrypted}; use crate::error::Error; use crate::Result; +pub mod bound_stream; + // open stream (streams started) pub struct JabberStream { reader: Reader>, writer: Writer>, } -impl futures::Stream for JabberStream { - type Item = Result; - - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - pin!(self).reader.poll_next_unpin(cx).map(|content| { - content.map(|content| -> Result { - let stanza = content.map(|content| Stanza::from_content(content))?; - Ok(stanza?) - }) - }) - } -} - impl JabberStream where S: AsyncRead + AsyncWrite + Unpin + Send + std::fmt::Debug, @@ -327,7 +313,8 @@ mod tests { use std::time::Duration; use super::*; - use crate::connection::Connection; + use crate::{client::JabberClient, connection::Connection}; + use futures::sink; use test_log::test; use tokio::time::sleep; @@ -378,7 +365,15 @@ mod tests { } #[tokio::test] - async fn negotiate() { + async fn sink() { + let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap(); + client.connect().await.unwrap(); + let stream = client.inner().unwrap(); + let sink = sink::unfold(stream, |mut stream, stanza: Stanza| async move { + stream.writer.write(&stanza).await?; + Ok::, Error>(stream) + }); + todo!() // let _jabber = Connection::connect_user("test@blos.sm", "slayed".to_string()) // .await // .unwrap() diff --git a/jabber/src/jabber_stream/bound_stream.rs b/jabber/src/jabber_stream/bound_stream.rs new file mode 100644 index 0000000..ca93421 --- /dev/null +++ b/jabber/src/jabber_stream/bound_stream.rs @@ -0,0 +1,153 @@ +use std::pin::pin; +use std::pin::Pin; +use std::sync::Arc; + +use futures::{sink, stream, Sink, Stream}; +use peanuts::{Reader, Writer}; +use pin_project::pin_project; +use stanza::client::Stanza; +use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf}; +use tokio::sync::Mutex; + +use crate::Error; + +use super::JabberStream; + +#[pin_project] +pub struct BoundJabberStream +where + R: Stream, + W: Sink, + S: AsyncWrite + AsyncRead + Unpin + Send, +{ + reader: Arc>>>>, + writer: Arc>>>>, + stream: R, + sink: W, +} + +impl BoundJabberStream +where + R: Stream, + W: Sink, + S: AsyncWrite + AsyncRead + Unpin + Send, +{ + // TODO: look into biased mutex, to close stream ASAP + pub async fn close_stream(self) -> Result, Error> { + if let Some(reader) = self.reader.lock().await.take() { + if let Some(writer) = self.writer.lock().await.take() { + // TODO: writer + return Ok(JabberStream { reader, writer }); + } + } + return Err(Error::StreamClosed); + } +} + +pub trait JabberStreamTrait: AsyncWrite + AsyncRead + Unpin + Send {} + +impl Sink for BoundJabberStream +where + R: Stream, + W: Sink + Unpin, + S: AsyncWrite + AsyncRead + Unpin + Send, +{ + type Error = >::Error; + + fn poll_ready( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = self.project(); + pin!(this.sink).poll_ready(cx) + } + + fn start_send(self: std::pin::Pin<&mut Self>, item: Stanza) -> Result<(), Self::Error> { + let this = self.project(); + pin!(this.sink).start_send(item) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = self.project(); + pin!(this.sink).poll_flush(cx) + } + + fn poll_close( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = self.project(); + pin!(this.sink).poll_close(cx) + } +} + +impl Stream for BoundJabberStream +where + R: Stream + Unpin, + W: Sink, + S: AsyncWrite + AsyncRead + Unpin + Send, +{ + type Item = ::Item; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = self.project(); + pin!(this.stream).poll_next(cx) + } +} + +impl JabberStream +where + S: AsyncWrite + AsyncRead + Unpin + Send, +{ + pub fn to_bound_jabber(self) -> BoundJabberStream, S> { + let reader = Arc::new(Mutex::new(Some(self.reader))); + let writer = Arc::new(Mutex::new(Some(self.writer))); + let sink = sink::unfold(writer.clone(), |writer, s: Stanza| async move { + write(writer, s).await + }); + let stream = stream::unfold(reader.clone(), |reader| async { read(reader).await }); + BoundJabberStream { + sink, + stream, + writer, + reader, + } + } +} + +pub async fn write( + writer: Arc>>>>, + stanza: Stanza, +) -> Result>>>>, Error> { + { + if let Some(writer) = writer.lock().await.as_mut() { + writer.write(&stanza).await?; + } else { + return Err(Error::StreamClosed); + } + } + Ok(writer) +} + +pub async fn read( + reader: Arc>>>>, +) -> Option<( + Result, + Arc>>>>, +)> { + let stanza: Result; + { + if let Some(reader) = reader.lock().await.as_mut() { + stanza = reader.read().await.map_err(|e| e.into()); + } else { + stanza = Err(Error::StreamClosed) + }; + } + Some((stanza, reader)) +}