use std::future::ready; use std::pin::pin; use std::pin::Pin; use std::sync::Arc; use std::task::Poll; use futures::ready; use futures::FutureExt; use futures::{sink, stream, Sink, Stream}; use peanuts::{Reader, Writer}; use pin_project::pin_project; use stanza::client::Stanza; use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf}; use tokio::sync::Mutex; use tokio::task::JoinHandle; use crate::Error; use super::JabberStream; #[pin_project] pub struct BoundJabberStream where S: AsyncWrite + AsyncRead + Unpin + Send, { reader: Arc>>>, writer: Arc>>>, write_handle: Option>>, read_handle: Option>>, } impl BoundJabberStream where S: AsyncWrite + AsyncRead + Unpin + Send, { // TODO: look into biased mutex, to close stream ASAP // TODO: put into connection // pub async fn close_stream(self) -> Result, Error> { // let reader = self.reader.lock().await.into_self(); // let writer = self.writer.lock().await.into_self(); // // TODO: writer // return Ok(JabberStream { reader, writer }); // } } pub trait JabberStreamTrait: AsyncWrite + AsyncRead + Unpin + Send {} impl Sink for BoundJabberStream where S: AsyncWrite + AsyncRead + Unpin + Send + 'static, { type Error = Error; fn poll_ready( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { self.poll_flush(cx) } fn start_send(self: std::pin::Pin<&mut Self>, item: Stanza) -> Result<(), Self::Error> { let this = self.project(); if let Some(_write_handle) = this.write_handle { panic!("start_send called without poll_ready") } else { // TODO: switch to buffer of one rather than thread spawning and joining *this.write_handle = Some(tokio::spawn(write(this.writer.clone(), item))); Ok(()) } } fn poll_flush( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { let this = self.project(); Poll::Ready(if let Some(join_handle) = this.write_handle.as_mut() { match ready!(join_handle.poll_unpin(cx)) { Ok(state) => { *this.write_handle = None; state } Err(err) => { *this.write_handle = None; Err(err.into()) } } } else { Ok(()) }) } fn poll_close( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { self.poll_flush(cx) } } impl Stream for BoundJabberStream where S: AsyncWrite + AsyncRead + Unpin + Send + 'static, { type Item = Result; fn poll_next( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { let this = self.project(); loop { if let Some(join_handle) = this.read_handle.as_mut() { let stanza = ready!(join_handle.poll_unpin(cx)); if let Ok(item) = stanza { *this.read_handle = None; return Poll::Ready(Some(item)); } else if let Err(err) = stanza { return Poll::Ready(Some(Err(err.into()))); } } else { *this.read_handle = Some(tokio::spawn(read(this.reader.clone()))) } } } } impl JabberStream where S: AsyncWrite + AsyncRead + Unpin + Send, { pub fn to_bound_jabber(self) -> BoundJabberStream { let reader = Arc::new(Mutex::new(self.reader)); let writer = Arc::new(Mutex::new(self.writer)); BoundJabberStream { writer, reader, write_handle: None, read_handle: None, } } } pub async fn write( writer: Arc>>>, stanza: Stanza, ) -> Result<(), Error> { { let mut writer = writer.lock().await; writer.write(&stanza).await?; } Ok(()) } pub async fn read( reader: Arc>>>, ) -> Result { let stanza: Result; { let mut reader = reader.lock().await; stanza = reader.read().await.map_err(|e| e.into()); } stanza }