154 lines
4.1 KiB
Rust
154 lines
4.1 KiB
Rust
|
use std::pin::pin;
|
||
|
use std::pin::Pin;
|
||
|
use std::sync::Arc;
|
||
|
|
||
|
use futures::{sink, stream, Sink, Stream};
|
||
|
use peanuts::{Reader, Writer};
|
||
|
use pin_project::pin_project;
|
||
|
use stanza::client::Stanza;
|
||
|
use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
|
||
|
use tokio::sync::Mutex;
|
||
|
|
||
|
use crate::Error;
|
||
|
|
||
|
use super::JabberStream;
|
||
|
|
||
|
#[pin_project]
|
||
|
pub struct BoundJabberStream<R, W, S>
|
||
|
where
|
||
|
R: Stream,
|
||
|
W: Sink<Stanza>,
|
||
|
S: AsyncWrite + AsyncRead + Unpin + Send,
|
||
|
{
|
||
|
reader: Arc<Mutex<Option<Reader<ReadHalf<S>>>>>,
|
||
|
writer: Arc<Mutex<Option<Writer<WriteHalf<S>>>>>,
|
||
|
stream: R,
|
||
|
sink: W,
|
||
|
}
|
||
|
|
||
|
impl<R, W, S> BoundJabberStream<R, W, S>
|
||
|
where
|
||
|
R: Stream,
|
||
|
W: Sink<Stanza>,
|
||
|
S: AsyncWrite + AsyncRead + Unpin + Send,
|
||
|
{
|
||
|
// TODO: look into biased mutex, to close stream ASAP
|
||
|
pub async fn close_stream(self) -> Result<JabberStream<S>, Error> {
|
||
|
if let Some(reader) = self.reader.lock().await.take() {
|
||
|
if let Some(writer) = self.writer.lock().await.take() {
|
||
|
// TODO: writer </stream:stream>
|
||
|
return Ok(JabberStream { reader, writer });
|
||
|
}
|
||
|
}
|
||
|
return Err(Error::StreamClosed);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
pub trait JabberStreamTrait: AsyncWrite + AsyncRead + Unpin + Send {}
|
||
|
|
||
|
impl<R, W, S> Sink<Stanza> for BoundJabberStream<R, W, S>
|
||
|
where
|
||
|
R: Stream,
|
||
|
W: Sink<Stanza> + Unpin,
|
||
|
S: AsyncWrite + AsyncRead + Unpin + Send,
|
||
|
{
|
||
|
type Error = <W as Sink<Stanza>>::Error;
|
||
|
|
||
|
fn poll_ready(
|
||
|
self: std::pin::Pin<&mut Self>,
|
||
|
cx: &mut std::task::Context<'_>,
|
||
|
) -> std::task::Poll<Result<(), Self::Error>> {
|
||
|
let this = self.project();
|
||
|
pin!(this.sink).poll_ready(cx)
|
||
|
}
|
||
|
|
||
|
fn start_send(self: std::pin::Pin<&mut Self>, item: Stanza) -> Result<(), Self::Error> {
|
||
|
let this = self.project();
|
||
|
pin!(this.sink).start_send(item)
|
||
|
}
|
||
|
|
||
|
fn poll_flush(
|
||
|
self: std::pin::Pin<&mut Self>,
|
||
|
cx: &mut std::task::Context<'_>,
|
||
|
) -> std::task::Poll<Result<(), Self::Error>> {
|
||
|
let this = self.project();
|
||
|
pin!(this.sink).poll_flush(cx)
|
||
|
}
|
||
|
|
||
|
fn poll_close(
|
||
|
self: std::pin::Pin<&mut Self>,
|
||
|
cx: &mut std::task::Context<'_>,
|
||
|
) -> std::task::Poll<Result<(), Self::Error>> {
|
||
|
let this = self.project();
|
||
|
pin!(this.sink).poll_close(cx)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl<R, W, S> Stream for BoundJabberStream<R, W, S>
|
||
|
where
|
||
|
R: Stream + Unpin,
|
||
|
W: Sink<Stanza>,
|
||
|
S: AsyncWrite + AsyncRead + Unpin + Send,
|
||
|
{
|
||
|
type Item = <R as Stream>::Item;
|
||
|
|
||
|
fn poll_next(
|
||
|
self: Pin<&mut Self>,
|
||
|
cx: &mut std::task::Context<'_>,
|
||
|
) -> std::task::Poll<Option<Self::Item>> {
|
||
|
let this = self.project();
|
||
|
pin!(this.stream).poll_next(cx)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl<S> JabberStream<S>
|
||
|
where
|
||
|
S: AsyncWrite + AsyncRead + Unpin + Send,
|
||
|
{
|
||
|
pub fn to_bound_jabber(self) -> BoundJabberStream<impl Stream, impl Sink<Stanza>, S> {
|
||
|
let reader = Arc::new(Mutex::new(Some(self.reader)));
|
||
|
let writer = Arc::new(Mutex::new(Some(self.writer)));
|
||
|
let sink = sink::unfold(writer.clone(), |writer, s: Stanza| async move {
|
||
|
write(writer, s).await
|
||
|
});
|
||
|
let stream = stream::unfold(reader.clone(), |reader| async { read(reader).await });
|
||
|
BoundJabberStream {
|
||
|
sink,
|
||
|
stream,
|
||
|
writer,
|
||
|
reader,
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
pub async fn write<W: AsyncWrite + Unpin + Send>(
|
||
|
writer: Arc<Mutex<Option<Writer<WriteHalf<W>>>>>,
|
||
|
stanza: Stanza,
|
||
|
) -> Result<Arc<Mutex<Option<Writer<WriteHalf<W>>>>>, Error> {
|
||
|
{
|
||
|
if let Some(writer) = writer.lock().await.as_mut() {
|
||
|
writer.write(&stanza).await?;
|
||
|
} else {
|
||
|
return Err(Error::StreamClosed);
|
||
|
}
|
||
|
}
|
||
|
Ok(writer)
|
||
|
}
|
||
|
|
||
|
pub async fn read<R: AsyncRead + Unpin + Send>(
|
||
|
reader: Arc<Mutex<Option<Reader<ReadHalf<R>>>>>,
|
||
|
) -> Option<(
|
||
|
Result<Stanza, Error>,
|
||
|
Arc<Mutex<Option<Reader<ReadHalf<R>>>>>,
|
||
|
)> {
|
||
|
let stanza: Result<Stanza, Error>;
|
||
|
{
|
||
|
if let Some(reader) = reader.lock().await.as_mut() {
|
||
|
stanza = reader.read().await.map_err(|e| e.into());
|
||
|
} else {
|
||
|
stanza = Err(Error::StreamClosed)
|
||
|
};
|
||
|
}
|
||
|
Some((stanza, reader))
|
||
|
}
|