implement sink and stream properly UNFOLD UNFOLD

This commit is contained in:
cel 🌸 2024-12-06 06:31:20 +00:00
parent aaf34b5bca
commit 595d165479
6 changed files with 181 additions and 35 deletions

View File

@ -4,6 +4,7 @@
## TODO:
- [ ] how to know if stanza has been sent
- [ ] error states for all negotiation parts
- [ ] better errors
- [x] rename structs
@ -74,6 +75,7 @@ need more research:
- [ ] message editing
- [ ] xep-0308: last message correction (should not be used for older than last message according to spec)
- [ ] chat read markers
- [ ] xep-0490: message displayed synchronization
- [ ] xep-0333: displayed markers
- [ ] message styling
- [ ] xep-0393: message styling

View File

@ -22,6 +22,9 @@ stanza = { version = "0.1.0", path = "../stanza" }
peanuts = { version = "0.1.0", path = "../../peanuts" }
jid = { version = "0.1.0", path = "../jid" }
futures = "0.3.31"
take_mut = "0.2.2"
pin-project-lite = "0.2.15"
pin-project = "1.1.7"
[dev-dependencies]
test-log = { version = "0.2", features = ["trace"] }

View File

@ -56,6 +56,14 @@ impl JabberClient {
}
}
pub(crate) fn inner(self) -> Result<JabberStream<Tls>> {
match self.connection {
ConnectionState::Disconnected => return Err(Error::Disconnected),
ConnectionState::Connecting(_connecting) => return Err(Error::Connecting),
ConnectionState::Connected(jabber_stream) => return Ok(jabber_stream),
}
}
pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> {
match &mut self.connection {
ConnectionState::Disconnected => return Err(Error::Disconnected),
@ -67,22 +75,6 @@ impl JabberClient {
}
}
impl Stream for JabberClient {
type Item = Result<Stanza>;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let mut client = pin!(self);
match &mut client.connection {
ConnectionState::Disconnected => Poll::Pending,
ConnectionState::Connecting(_connecting) => Poll::Pending,
ConnectionState::Connected(jabber_stream) => jabber_stream.poll_next_unpin(cx),
}
}
}
pub enum ConnectionState {
Disconnected,
Connecting(Connecting),

View File

@ -16,6 +16,7 @@ pub enum Error {
Unsupported,
NoLocalpart,
AlreadyConnecting,
StreamClosed,
UnexpectedElement(peanuts::Element),
XML(peanuts::Error),
Deserialization(peanuts::DeserializeError),

View File

@ -2,7 +2,7 @@ use std::pin::pin;
use std::str::{self, FromStr};
use std::sync::Arc;
use futures::StreamExt;
use futures::{sink, stream, StreamExt};
use jid::JID;
use peanuts::element::{FromContent, IntoElement};
use peanuts::{Reader, Writer};
@ -22,28 +22,14 @@ use crate::connection::{Tls, Unencrypted};
use crate::error::Error;
use crate::Result;
pub mod bound_stream;
// open stream (streams started)
pub struct JabberStream<S> {
reader: Reader<ReadHalf<S>>,
writer: Writer<WriteHalf<S>>,
}
impl<S: AsyncRead> futures::Stream for JabberStream<S> {
type Item = Result<Stanza>;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
pin!(self).reader.poll_next_unpin(cx).map(|content| {
content.map(|content| -> Result<Stanza> {
let stanza = content.map(|content| Stanza::from_content(content))?;
Ok(stanza?)
})
})
}
}
impl<S> JabberStream<S>
where
S: AsyncRead + AsyncWrite + Unpin + Send + std::fmt::Debug,
@ -327,7 +313,8 @@ mod tests {
use std::time::Duration;
use super::*;
use crate::connection::Connection;
use crate::{client::JabberClient, connection::Connection};
use futures::sink;
use test_log::test;
use tokio::time::sleep;
@ -378,7 +365,15 @@ mod tests {
}
#[tokio::test]
async fn negotiate() {
async fn sink() {
let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap();
client.connect().await.unwrap();
let stream = client.inner().unwrap();
let sink = sink::unfold(stream, |mut stream, stanza: Stanza| async move {
stream.writer.write(&stanza).await?;
Ok::<JabberStream<Tls>, Error>(stream)
});
todo!()
// let _jabber = Connection::connect_user("test@blos.sm", "slayed".to_string())
// .await
// .unwrap()

View File

@ -0,0 +1,153 @@
use std::pin::pin;
use std::pin::Pin;
use std::sync::Arc;
use futures::{sink, stream, Sink, Stream};
use peanuts::{Reader, Writer};
use pin_project::pin_project;
use stanza::client::Stanza;
use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
use tokio::sync::Mutex;
use crate::Error;
use super::JabberStream;
#[pin_project]
pub struct BoundJabberStream<R, W, S>
where
R: Stream,
W: Sink<Stanza>,
S: AsyncWrite + AsyncRead + Unpin + Send,
{
reader: Arc<Mutex<Option<Reader<ReadHalf<S>>>>>,
writer: Arc<Mutex<Option<Writer<WriteHalf<S>>>>>,
stream: R,
sink: W,
}
impl<R, W, S> BoundJabberStream<R, W, S>
where
R: Stream,
W: Sink<Stanza>,
S: AsyncWrite + AsyncRead + Unpin + Send,
{
// TODO: look into biased mutex, to close stream ASAP
pub async fn close_stream(self) -> Result<JabberStream<S>, Error> {
if let Some(reader) = self.reader.lock().await.take() {
if let Some(writer) = self.writer.lock().await.take() {
// TODO: writer </stream:stream>
return Ok(JabberStream { reader, writer });
}
}
return Err(Error::StreamClosed);
}
}
pub trait JabberStreamTrait: AsyncWrite + AsyncRead + Unpin + Send {}
impl<R, W, S> Sink<Stanza> for BoundJabberStream<R, W, S>
where
R: Stream,
W: Sink<Stanza> + Unpin,
S: AsyncWrite + AsyncRead + Unpin + Send,
{
type Error = <W as Sink<Stanza>>::Error;
fn poll_ready(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
let this = self.project();
pin!(this.sink).poll_ready(cx)
}
fn start_send(self: std::pin::Pin<&mut Self>, item: Stanza) -> Result<(), Self::Error> {
let this = self.project();
pin!(this.sink).start_send(item)
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
let this = self.project();
pin!(this.sink).poll_flush(cx)
}
fn poll_close(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
let this = self.project();
pin!(this.sink).poll_close(cx)
}
}
impl<R, W, S> Stream for BoundJabberStream<R, W, S>
where
R: Stream + Unpin,
W: Sink<Stanza>,
S: AsyncWrite + AsyncRead + Unpin + Send,
{
type Item = <R as Stream>::Item;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let this = self.project();
pin!(this.stream).poll_next(cx)
}
}
impl<S> JabberStream<S>
where
S: AsyncWrite + AsyncRead + Unpin + Send,
{
pub fn to_bound_jabber(self) -> BoundJabberStream<impl Stream, impl Sink<Stanza>, S> {
let reader = Arc::new(Mutex::new(Some(self.reader)));
let writer = Arc::new(Mutex::new(Some(self.writer)));
let sink = sink::unfold(writer.clone(), |writer, s: Stanza| async move {
write(writer, s).await
});
let stream = stream::unfold(reader.clone(), |reader| async { read(reader).await });
BoundJabberStream {
sink,
stream,
writer,
reader,
}
}
}
pub async fn write<W: AsyncWrite + Unpin + Send>(
writer: Arc<Mutex<Option<Writer<WriteHalf<W>>>>>,
stanza: Stanza,
) -> Result<Arc<Mutex<Option<Writer<WriteHalf<W>>>>>, Error> {
{
if let Some(writer) = writer.lock().await.as_mut() {
writer.write(&stanza).await?;
} else {
return Err(Error::StreamClosed);
}
}
Ok(writer)
}
pub async fn read<R: AsyncRead + Unpin + Send>(
reader: Arc<Mutex<Option<Reader<ReadHalf<R>>>>>,
) -> Option<(
Result<Stanza, Error>,
Arc<Mutex<Option<Reader<ReadHalf<R>>>>>,
)> {
let stanza: Result<Stanza, Error>;
{
if let Some(reader) = reader.lock().await.as_mut() {
stanza = reader.read().await.map_err(|e| e.into());
} else {
stanza = Err(Error::StreamClosed)
};
}
Some((stanza, reader))
}