use std::pin::pin; use std::pin::Pin; use std::sync::Arc; use futures::{sink, stream, Sink, Stream}; use peanuts::{Reader, Writer}; use pin_project::pin_project; use stanza::client::Stanza; use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf}; use tokio::sync::Mutex; use crate::Error; use super::JabberStream; #[pin_project] pub struct BoundJabberStream where R: Stream, W: Sink, S: AsyncWrite + AsyncRead + Unpin + Send, { reader: Arc>>>>, writer: Arc>>>>, stream: R, sink: W, } impl BoundJabberStream where R: Stream, W: Sink, S: AsyncWrite + AsyncRead + Unpin + Send, { // TODO: look into biased mutex, to close stream ASAP pub async fn close_stream(self) -> Result, Error> { if let Some(reader) = self.reader.lock().await.take() { if let Some(writer) = self.writer.lock().await.take() { // TODO: writer return Ok(JabberStream { reader, writer }); } } return Err(Error::StreamClosed); } } pub trait JabberStreamTrait: AsyncWrite + AsyncRead + Unpin + Send {} impl Sink for BoundJabberStream where R: Stream, W: Sink + Unpin, S: AsyncWrite + AsyncRead + Unpin + Send, { type Error = >::Error; fn poll_ready( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { let this = self.project(); pin!(this.sink).poll_ready(cx) } fn start_send(self: std::pin::Pin<&mut Self>, item: Stanza) -> Result<(), Self::Error> { let this = self.project(); pin!(this.sink).start_send(item) } fn poll_flush( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { let this = self.project(); pin!(this.sink).poll_flush(cx) } fn poll_close( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { let this = self.project(); pin!(this.sink).poll_close(cx) } } impl Stream for BoundJabberStream where R: Stream + Unpin, W: Sink, S: AsyncWrite + AsyncRead + Unpin + Send, { type Item = ::Item; fn poll_next( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { let this = self.project(); pin!(this.stream).poll_next(cx) } } impl JabberStream where S: AsyncWrite + AsyncRead + Unpin + Send, { pub fn to_bound_jabber(self) -> BoundJabberStream, S> { let reader = Arc::new(Mutex::new(Some(self.reader))); let writer = Arc::new(Mutex::new(Some(self.writer))); let sink = sink::unfold(writer.clone(), |writer, s: Stanza| async move { write(writer, s).await }); let stream = stream::unfold(reader.clone(), |reader| async { read(reader).await }); BoundJabberStream { sink, stream, writer, reader, } } } pub async fn write( writer: Arc>>>>, stanza: Stanza, ) -> Result>>>>, Error> { { if let Some(writer) = writer.lock().await.as_mut() { writer.write(&stanza).await?; } else { return Err(Error::StreamClosed); } } Ok(writer) } pub async fn read( reader: Arc>>>>, ) -> Option<( Result, Arc>>>>, )> { let stanza: Result; { if let Some(reader) = reader.lock().await.as_mut() { stanza = reader.read().await.map_err(|e| e.into()); } else { stanza = Err(Error::StreamClosed) }; } Some((stanza, reader)) }