implement stream splitting and closing
This commit is contained in:
parent
0e5f09b2bd
commit
e6c97ab828
|
@ -12,7 +12,12 @@ async-trait = "0.1.68"
|
||||||
lazy_static = "1.4.0"
|
lazy_static = "1.4.0"
|
||||||
nanoid = "0.4.0"
|
nanoid = "0.4.0"
|
||||||
# TODO: remove unneeded features
|
# TODO: remove unneeded features
|
||||||
rsasl = { version = "2.0.1", path = "../../rsasl", default_features = false, features = ["provider_base64", "plain", "config_builder", "scram-sha-1"] }
|
rsasl = { version = "2.0.1", default_features = false, features = [
|
||||||
|
"provider_base64",
|
||||||
|
"plain",
|
||||||
|
"config_builder",
|
||||||
|
"scram-sha-1",
|
||||||
|
] }
|
||||||
tokio = { version = "1.28", features = ["full"] }
|
tokio = { version = "1.28", features = ["full"] }
|
||||||
tokio-native-tls = "0.3.1"
|
tokio-native-tls = "0.3.1"
|
||||||
tracing = "0.1.40"
|
tracing = "0.1.40"
|
||||||
|
@ -29,4 +34,7 @@ pin-project = "1.1.7"
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
test-log = { version = "0.2", features = ["trace"] }
|
test-log = { version = "0.2", features = ["trace"] }
|
||||||
env_logger = "*"
|
env_logger = "*"
|
||||||
tracing-subscriber = {version = "0.3", default-features = false, features = ["env-filter", "fmt"]}
|
tracing-subscriber = { version = "0.3", default-features = false, features = [
|
||||||
|
"env-filter",
|
||||||
|
"fmt",
|
||||||
|
] }
|
||||||
|
|
|
@ -18,13 +18,13 @@ use tokio::sync::Mutex;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
connection::{Tls, Unencrypted},
|
connection::{Tls, Unencrypted},
|
||||||
jabber_stream::bound_stream::BoundJabberStream,
|
jabber_stream::bound_stream::{BoundJabberReader, BoundJabberStream},
|
||||||
Connection, Error, JabberStream, Result, JID,
|
Connection, Error, JabberStream, Result, JID,
|
||||||
};
|
};
|
||||||
|
|
||||||
// feed it client stanzas, receive client stanzas
|
// feed it client stanzas, receive client stanzas
|
||||||
pub struct JabberClient {
|
pub struct JabberClient {
|
||||||
connection: ConnectionState,
|
connection: Option<BoundJabberStream<Tls>>,
|
||||||
jid: JID,
|
jid: JID,
|
||||||
// TODO: have reconnection be handled by another part, so creds don't need to be stored in object
|
// TODO: have reconnection be handled by another part, so creds don't need to be stored in object
|
||||||
password: Arc<SASLConfig>,
|
password: Arc<SASLConfig>,
|
||||||
|
@ -43,7 +43,7 @@ impl JabberClient {
|
||||||
password.to_string(),
|
password.to_string(),
|
||||||
)?;
|
)?;
|
||||||
Ok(JabberClient {
|
Ok(JabberClient {
|
||||||
connection: ConnectionState::Disconnected,
|
connection: None,
|
||||||
jid: jid.clone(),
|
jid: jid.clone(),
|
||||||
password: sasl_config,
|
password: sasl_config,
|
||||||
server: jid.domainpart,
|
server: jid.domainpart,
|
||||||
|
@ -56,25 +56,19 @@ impl JabberClient {
|
||||||
|
|
||||||
pub async fn connect(&mut self) -> Result<()> {
|
pub async fn connect(&mut self) -> Result<()> {
|
||||||
match &self.connection {
|
match &self.connection {
|
||||||
ConnectionState::Disconnected => {
|
Some(_) => Ok(()),
|
||||||
// TODO: actually set the self.connection as it is connecting, make more asynchronous (mutex while connecting?)
|
None => {
|
||||||
// perhaps use take_mut?
|
self.connection = Some(
|
||||||
self.connection = ConnectionState::Disconnected
|
connect_and_login(&mut self.jid, self.password.clone(), &mut self.server)
|
||||||
.connect(&mut self.jid, self.password.clone(), &mut self.server)
|
.await?,
|
||||||
.await?;
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
ConnectionState::Connecting(_connecting) => Err(Error::AlreadyConnecting),
|
|
||||||
ConnectionState::Connected(_jabber_stream) => Ok(()),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn inner(self) -> Result<BoundJabberStream<Tls>> {
|
pub(crate) fn into_inner(self) -> Result<BoundJabberStream<Tls>> {
|
||||||
match self.connection {
|
self.connection.ok_or(Error::Disconnected)
|
||||||
ConnectionState::Disconnected => return Err(Error::Disconnected),
|
|
||||||
ConnectionState::Connecting(_connecting) => return Err(Error::Connecting),
|
|
||||||
ConnectionState::Connected(jabber_stream) => return Ok(jabber_stream),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> {
|
// pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> {
|
||||||
|
@ -88,203 +82,59 @@ impl JabberClient {
|
||||||
// }
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Sink<Stanza> for JabberClient {
|
pub async fn connect_and_login(
|
||||||
type Error = Error;
|
jid: &mut JID,
|
||||||
|
auth: Arc<SASLConfig>,
|
||||||
fn poll_ready(
|
server: &mut String,
|
||||||
self: std::pin::Pin<&mut Self>,
|
) -> Result<BoundJabberStream<Tls>> {
|
||||||
cx: &mut std::task::Context<'_>,
|
let mut conn_state = Connecting::start(&server).await?;
|
||||||
) -> Poll<std::result::Result<(), Self::Error>> {
|
loop {
|
||||||
self.get_mut().connection.poll_ready_unpin(cx)
|
match conn_state {
|
||||||
}
|
Connecting::InsecureConnectionEstablised(tcp_stream) => {
|
||||||
|
conn_state = Connecting::InsecureStreamStarted(
|
||||||
fn start_send(
|
JabberStream::start_stream(tcp_stream, server).await?,
|
||||||
self: std::pin::Pin<&mut Self>,
|
)
|
||||||
item: Stanza,
|
|
||||||
) -> std::result::Result<(), Self::Error> {
|
|
||||||
self.get_mut().connection.start_send_unpin(item)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_flush(
|
|
||||||
self: std::pin::Pin<&mut Self>,
|
|
||||||
cx: &mut std::task::Context<'_>,
|
|
||||||
) -> Poll<std::result::Result<(), Self::Error>> {
|
|
||||||
self.get_mut().connection.poll_flush_unpin(cx)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_close(
|
|
||||||
self: std::pin::Pin<&mut Self>,
|
|
||||||
cx: &mut std::task::Context<'_>,
|
|
||||||
) -> Poll<std::result::Result<(), Self::Error>> {
|
|
||||||
self.get_mut().connection.poll_flush_unpin(cx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Stream for JabberClient {
|
|
||||||
type Item = Result<Stanza>;
|
|
||||||
|
|
||||||
fn poll_next(
|
|
||||||
self: std::pin::Pin<&mut Self>,
|
|
||||||
cx: &mut std::task::Context<'_>,
|
|
||||||
) -> Poll<Option<Self::Item>> {
|
|
||||||
self.get_mut().connection.poll_next_unpin(cx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub enum ConnectionState {
|
|
||||||
Disconnected,
|
|
||||||
Connecting(Connecting),
|
|
||||||
Connected(BoundJabberStream<Tls>),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Sink<Stanza> for ConnectionState {
|
|
||||||
type Error = Error;
|
|
||||||
|
|
||||||
fn poll_ready(
|
|
||||||
self: std::pin::Pin<&mut Self>,
|
|
||||||
cx: &mut std::task::Context<'_>,
|
|
||||||
) -> Poll<std::result::Result<(), Self::Error>> {
|
|
||||||
match self.get_mut() {
|
|
||||||
ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)),
|
|
||||||
ConnectionState::Connecting(_connecting) => Poll::Pending,
|
|
||||||
ConnectionState::Connected(bound_jabber_stream) => {
|
|
||||||
bound_jabber_stream.poll_ready_unpin(cx)
|
|
||||||
}
|
}
|
||||||
}
|
Connecting::InsecureStreamStarted(jabber_stream) => {
|
||||||
}
|
conn_state = Connecting::InsecureGotFeatures(jabber_stream.get_features().await?)
|
||||||
|
|
||||||
fn start_send(
|
|
||||||
self: std::pin::Pin<&mut Self>,
|
|
||||||
item: Stanza,
|
|
||||||
) -> std::result::Result<(), Self::Error> {
|
|
||||||
match self.get_mut() {
|
|
||||||
ConnectionState::Disconnected => Err(Error::Disconnected),
|
|
||||||
ConnectionState::Connecting(_connecting) => Err(Error::Connecting),
|
|
||||||
ConnectionState::Connected(bound_jabber_stream) => {
|
|
||||||
bound_jabber_stream.start_send_unpin(item)
|
|
||||||
}
|
}
|
||||||
}
|
Connecting::InsecureGotFeatures((features, jabber_stream)) => {
|
||||||
}
|
match features.negotiate().ok_or(Error::Negotiation)? {
|
||||||
|
Feature::StartTls(_start_tls) => {
|
||||||
fn poll_flush(
|
conn_state = Connecting::StartTls(jabber_stream)
|
||||||
self: std::pin::Pin<&mut Self>,
|
}
|
||||||
cx: &mut std::task::Context<'_>,
|
// TODO: better error
|
||||||
) -> Poll<std::result::Result<(), Self::Error>> {
|
_ => return Err(Error::TlsRequired),
|
||||||
match self.get_mut() {
|
|
||||||
ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)),
|
|
||||||
ConnectionState::Connecting(_connecting) => Poll::Pending,
|
|
||||||
ConnectionState::Connected(bound_jabber_stream) => {
|
|
||||||
bound_jabber_stream.poll_flush_unpin(cx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_close(
|
|
||||||
self: std::pin::Pin<&mut Self>,
|
|
||||||
cx: &mut std::task::Context<'_>,
|
|
||||||
) -> Poll<std::result::Result<(), Self::Error>> {
|
|
||||||
match self.get_mut() {
|
|
||||||
ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)),
|
|
||||||
ConnectionState::Connecting(_connecting) => Poll::Pending,
|
|
||||||
ConnectionState::Connected(bound_jabber_stream) => {
|
|
||||||
bound_jabber_stream.poll_close_unpin(cx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Stream for ConnectionState {
|
|
||||||
type Item = Result<Stanza>;
|
|
||||||
|
|
||||||
fn poll_next(
|
|
||||||
self: std::pin::Pin<&mut Self>,
|
|
||||||
cx: &mut std::task::Context<'_>,
|
|
||||||
) -> Poll<Option<Self::Item>> {
|
|
||||||
match self.get_mut() {
|
|
||||||
ConnectionState::Disconnected => Poll::Ready(Some(Err(Error::Disconnected))),
|
|
||||||
ConnectionState::Connecting(_connecting) => Poll::Pending,
|
|
||||||
ConnectionState::Connected(bound_jabber_stream) => {
|
|
||||||
bound_jabber_stream.poll_next_unpin(cx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ConnectionState {
|
|
||||||
pub async fn connect(
|
|
||||||
mut self,
|
|
||||||
jid: &mut JID,
|
|
||||||
auth: Arc<SASLConfig>,
|
|
||||||
server: &mut String,
|
|
||||||
) -> Result<Self> {
|
|
||||||
loop {
|
|
||||||
match self {
|
|
||||||
ConnectionState::Disconnected => {
|
|
||||||
self = ConnectionState::Connecting(Connecting::start(&server).await?);
|
|
||||||
}
|
}
|
||||||
ConnectionState::Connecting(connecting) => match connecting {
|
}
|
||||||
Connecting::InsecureConnectionEstablised(tcp_stream) => {
|
Connecting::StartTls(jabber_stream) => {
|
||||||
self = ConnectionState::Connecting(Connecting::InsecureStreamStarted(
|
conn_state =
|
||||||
JabberStream::start_stream(tcp_stream, server).await?,
|
Connecting::ConnectionEstablished(jabber_stream.starttls(&server).await?)
|
||||||
))
|
}
|
||||||
|
Connecting::ConnectionEstablished(tls_stream) => {
|
||||||
|
conn_state =
|
||||||
|
Connecting::StreamStarted(JabberStream::start_stream(tls_stream, server).await?)
|
||||||
|
}
|
||||||
|
Connecting::StreamStarted(jabber_stream) => {
|
||||||
|
conn_state = Connecting::GotFeatures(jabber_stream.get_features().await?)
|
||||||
|
}
|
||||||
|
Connecting::GotFeatures((features, jabber_stream)) => {
|
||||||
|
match features.negotiate().ok_or(Error::Negotiation)? {
|
||||||
|
Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls),
|
||||||
|
Feature::Sasl(mechanisms) => {
|
||||||
|
conn_state = Connecting::Sasl(mechanisms, jabber_stream)
|
||||||
}
|
}
|
||||||
Connecting::InsecureStreamStarted(jabber_stream) => {
|
Feature::Bind => conn_state = Connecting::Bind(jabber_stream),
|
||||||
self = ConnectionState::Connecting(Connecting::InsecureGotFeatures(
|
Feature::Unknown => return Err(Error::Unsupported),
|
||||||
jabber_stream.get_features().await?,
|
}
|
||||||
))
|
}
|
||||||
}
|
Connecting::Sasl(mechanisms, jabber_stream) => {
|
||||||
Connecting::InsecureGotFeatures((features, jabber_stream)) => {
|
conn_state = Connecting::ConnectionEstablished(
|
||||||
match features.negotiate().ok_or(Error::Negotiation)? {
|
jabber_stream.sasl(mechanisms, auth.clone()).await?,
|
||||||
Feature::StartTls(_start_tls) => {
|
)
|
||||||
self =
|
}
|
||||||
ConnectionState::Connecting(Connecting::StartTls(jabber_stream))
|
Connecting::Bind(jabber_stream) => {
|
||||||
}
|
return Ok(jabber_stream.bind(jid).await?.to_bound_jabber());
|
||||||
// TODO: better error
|
|
||||||
_ => return Err(Error::TlsRequired),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Connecting::StartTls(jabber_stream) => {
|
|
||||||
self = ConnectionState::Connecting(Connecting::ConnectionEstablished(
|
|
||||||
jabber_stream.starttls(&server).await?,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
Connecting::ConnectionEstablished(tls_stream) => {
|
|
||||||
self = ConnectionState::Connecting(Connecting::StreamStarted(
|
|
||||||
JabberStream::start_stream(tls_stream, server).await?,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
Connecting::StreamStarted(jabber_stream) => {
|
|
||||||
self = ConnectionState::Connecting(Connecting::GotFeatures(
|
|
||||||
jabber_stream.get_features().await?,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
Connecting::GotFeatures((features, jabber_stream)) => {
|
|
||||||
match features.negotiate().ok_or(Error::Negotiation)? {
|
|
||||||
Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls),
|
|
||||||
Feature::Sasl(mechanisms) => {
|
|
||||||
self = ConnectionState::Connecting(Connecting::Sasl(
|
|
||||||
mechanisms,
|
|
||||||
jabber_stream,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
Feature::Bind => {
|
|
||||||
self = ConnectionState::Connecting(Connecting::Bind(jabber_stream))
|
|
||||||
}
|
|
||||||
Feature::Unknown => return Err(Error::Unsupported),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Connecting::Sasl(mechanisms, jabber_stream) => {
|
|
||||||
self = ConnectionState::Connecting(Connecting::ConnectionEstablished(
|
|
||||||
jabber_stream.sasl(mechanisms, auth.clone()).await?,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
Connecting::Bind(jabber_stream) => {
|
|
||||||
self = ConnectionState::Connected(
|
|
||||||
jabber_stream.bind(jid).await?.to_bound_jabber(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
connected => return Ok(connected),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -354,12 +204,12 @@ mod tests {
|
||||||
sleep(Duration::from_secs(5)).await;
|
sleep(Duration::from_secs(5)).await;
|
||||||
let jid = client.jid.clone();
|
let jid = client.jid.clone();
|
||||||
let server = client.server.clone();
|
let server = client.server.clone();
|
||||||
let (mut write, mut read) = client.split();
|
let (mut read, mut write) = client.into_inner().unwrap().split();
|
||||||
|
|
||||||
tokio::join!(
|
tokio::join!(
|
||||||
async {
|
async {
|
||||||
write
|
write
|
||||||
.send(Stanza::Iq(Iq {
|
.write(&Stanza::Iq(Iq {
|
||||||
from: Some(jid.clone()),
|
from: Some(jid.clone()),
|
||||||
id: "c2s1".to_string(),
|
id: "c2s1".to_string(),
|
||||||
to: Some(server.clone().try_into().unwrap()),
|
to: Some(server.clone().try_into().unwrap()),
|
||||||
|
@ -368,9 +218,10 @@ mod tests {
|
||||||
query: Some(Query::Ping(Ping)),
|
query: Some(Query::Ping(Ping)),
|
||||||
errors: Vec::new(),
|
errors: Vec::new(),
|
||||||
}))
|
}))
|
||||||
.await;
|
.await
|
||||||
|
.unwrap();
|
||||||
write
|
write
|
||||||
.send(Stanza::Iq(Iq {
|
.write(&Stanza::Iq(Iq {
|
||||||
from: Some(jid.clone()),
|
from: Some(jid.clone()),
|
||||||
id: "c2s2".to_string(),
|
id: "c2s2".to_string(),
|
||||||
to: Some(server.clone().try_into().unwrap()),
|
to: Some(server.clone().try_into().unwrap()),
|
||||||
|
@ -379,11 +230,13 @@ mod tests {
|
||||||
query: Some(Query::Ping(Ping)),
|
query: Some(Query::Ping(Ping)),
|
||||||
errors: Vec::new(),
|
errors: Vec::new(),
|
||||||
}))
|
}))
|
||||||
.await;
|
.await
|
||||||
|
.unwrap();
|
||||||
},
|
},
|
||||||
async {
|
async {
|
||||||
while let Some(stanza) = read.next().await {
|
for _ in 0..2 {
|
||||||
info!("{:#?}", stanza);
|
let stanza = read.read::<Stanza>().await.unwrap();
|
||||||
|
info!("ping reply: {:#?}", stanza);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
|
@ -26,8 +26,103 @@ pub mod bound_stream;
|
||||||
|
|
||||||
// open stream (streams started)
|
// open stream (streams started)
|
||||||
pub struct JabberStream<S> {
|
pub struct JabberStream<S> {
|
||||||
reader: Reader<ReadHalf<S>>,
|
reader: JabberReader<S>,
|
||||||
pub(crate) writer: Writer<WriteHalf<S>>,
|
writer: JabberWriter<S>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> JabberStream<S> {
|
||||||
|
fn split(self) -> (JabberReader<S>, JabberWriter<S>) {
|
||||||
|
let reader = self.reader;
|
||||||
|
let writer = self.writer;
|
||||||
|
(reader, writer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct JabberReader<S>(Reader<ReadHalf<S>>);
|
||||||
|
|
||||||
|
impl<S> JabberReader<S> {
|
||||||
|
// TODO: consider taking a readhalf and creating peanuts::Reader here, only one inner
|
||||||
|
fn new(reader: Reader<ReadHalf<S>>) -> Self {
|
||||||
|
Self(reader)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unsplit(self, writer: JabberWriter<S>) -> JabberStream<S> {
|
||||||
|
JabberStream {
|
||||||
|
reader: self,
|
||||||
|
writer,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn into_inner(self) -> Reader<ReadHalf<S>> {
|
||||||
|
self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> JabberReader<S>
|
||||||
|
where
|
||||||
|
S: AsyncRead + Unpin,
|
||||||
|
{
|
||||||
|
pub async fn try_close(&mut self) -> Result<()> {
|
||||||
|
self.read_end_tag().await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> std::ops::Deref for JabberReader<S> {
|
||||||
|
type Target = Reader<ReadHalf<S>>;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> std::ops::DerefMut for JabberReader<S> {
|
||||||
|
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||||
|
&mut self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct JabberWriter<S>(Writer<WriteHalf<S>>);
|
||||||
|
|
||||||
|
impl<S> JabberWriter<S> {
|
||||||
|
fn new(writer: Writer<WriteHalf<S>>) -> Self {
|
||||||
|
Self(writer)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unsplit(self, reader: JabberReader<S>) -> JabberStream<S> {
|
||||||
|
JabberStream {
|
||||||
|
reader,
|
||||||
|
writer: self,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn into_inner(self) -> Writer<WriteHalf<S>> {
|
||||||
|
self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> JabberWriter<S>
|
||||||
|
where
|
||||||
|
S: AsyncWrite + Unpin + Send,
|
||||||
|
{
|
||||||
|
pub async fn try_close(&mut self) -> Result<()> {
|
||||||
|
self.write_end().await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> std::ops::Deref for JabberWriter<S> {
|
||||||
|
type Target = Writer<WriteHalf<S>>;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> std::ops::DerefMut for JabberWriter<S> {
|
||||||
|
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||||
|
&mut self.0
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S> JabberStream<S>
|
impl<S> JabberStream<S>
|
||||||
|
@ -119,8 +214,8 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let writer = self.writer.into_inner();
|
let writer = self.writer.into_inner().into_inner();
|
||||||
let reader = self.reader.into_inner();
|
let reader = self.reader.into_inner().into_inner();
|
||||||
let stream = reader.unsplit(writer);
|
let stream = reader.unsplit(writer);
|
||||||
Ok(stream)
|
Ok(stream)
|
||||||
}
|
}
|
||||||
|
@ -223,8 +318,8 @@ where
|
||||||
pub async fn start_stream(connection: S, server: &mut String) -> Result<Self> {
|
pub async fn start_stream(connection: S, server: &mut String) -> Result<Self> {
|
||||||
// client to server
|
// client to server
|
||||||
let (reader, writer) = tokio::io::split(connection);
|
let (reader, writer) = tokio::io::split(connection);
|
||||||
let mut reader = Reader::new(reader);
|
let mut reader = JabberReader::new(Reader::new(reader));
|
||||||
let mut writer = Writer::new(writer);
|
let mut writer = JabberWriter::new(Writer::new(writer));
|
||||||
|
|
||||||
// declaration
|
// declaration
|
||||||
writer.write_declaration(XML_VERSION).await?;
|
writer.write_declaration(XML_VERSION).await?;
|
||||||
|
@ -262,7 +357,10 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn into_inner(self) -> S {
|
pub fn into_inner(self) -> S {
|
||||||
self.reader.into_inner().unsplit(self.writer.into_inner())
|
self.reader
|
||||||
|
.into_inner()
|
||||||
|
.into_inner()
|
||||||
|
.unsplit(self.writer.into_inner().into_inner())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> {
|
pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> {
|
||||||
|
@ -280,7 +378,11 @@ impl JabberStream<Unencrypted> {
|
||||||
let proceed: Proceed = self.reader.read().await?;
|
let proceed: Proceed = self.reader.read().await?;
|
||||||
debug!("got proceed: {:?}", proceed);
|
debug!("got proceed: {:?}", proceed);
|
||||||
let connector = TlsConnector::new().unwrap();
|
let connector = TlsConnector::new().unwrap();
|
||||||
let stream = self.reader.into_inner().unsplit(self.writer.into_inner());
|
let stream = self
|
||||||
|
.reader
|
||||||
|
.into_inner()
|
||||||
|
.into_inner()
|
||||||
|
.unsplit(self.writer.into_inner().into_inner());
|
||||||
if let Ok(tls_stream) = tokio_native_tls::TlsConnector::from(connector)
|
if let Ok(tls_stream) = tokio_native_tls::TlsConnector::from(connector)
|
||||||
.connect(domain.as_ref(), stream)
|
.connect(domain.as_ref(), stream)
|
||||||
.await
|
.await
|
||||||
|
|
|
@ -1,128 +1,82 @@
|
||||||
use std::future::ready;
|
use std::ops::{Deref, DerefMut};
|
||||||
use std::pin::pin;
|
|
||||||
use std::pin::Pin;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use std::task::Poll;
|
|
||||||
|
|
||||||
use futures::ready;
|
|
||||||
use futures::FutureExt;
|
|
||||||
use futures::{sink, stream, Sink, Stream};
|
|
||||||
use peanuts::{Reader, Writer};
|
use peanuts::{Reader, Writer};
|
||||||
use pin_project::pin_project;
|
|
||||||
use stanza::client::Stanza;
|
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
|
use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
|
||||||
use tokio::sync::Mutex;
|
|
||||||
use tokio::task::JoinHandle;
|
|
||||||
|
|
||||||
use crate::Error;
|
use crate::Error;
|
||||||
|
|
||||||
use super::JabberStream;
|
use super::{JabberReader, JabberStream, JabberWriter};
|
||||||
|
|
||||||
#[pin_project]
|
pub struct BoundJabberStream<S>(JabberStream<S>);
|
||||||
pub struct BoundJabberStream<S>
|
|
||||||
|
impl<S> Deref for BoundJabberStream<S>
|
||||||
where
|
where
|
||||||
S: AsyncWrite + AsyncRead + Unpin + Send,
|
S: AsyncWrite + AsyncRead + Unpin + Send,
|
||||||
{
|
{
|
||||||
reader: Arc<Mutex<Reader<ReadHalf<S>>>>,
|
type Target = JabberStream<S>;
|
||||||
writer: Arc<Mutex<Writer<WriteHalf<S>>>>,
|
|
||||||
write_handle: Option<JoinHandle<Result<(), Error>>>,
|
fn deref(&self) -> &Self::Target {
|
||||||
read_handle: Option<JoinHandle<Result<Stanza, Error>>>,
|
&self.0
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S> BoundJabberStream<S>
|
impl<S> DerefMut for BoundJabberStream<S>
|
||||||
where
|
where
|
||||||
S: AsyncWrite + AsyncRead + Unpin + Send,
|
S: AsyncWrite + AsyncRead + Unpin + Send,
|
||||||
{
|
{
|
||||||
// TODO: look into biased mutex, to close stream ASAP
|
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||||
// TODO: put into connection
|
&mut self.0
|
||||||
// pub async fn close_stream(self) -> Result<JabberStream<S>, Error> {
|
|
||||||
// let reader = self.reader.lock().await.into_self();
|
|
||||||
// let writer = self.writer.lock().await.into_self();
|
|
||||||
// // TODO: writer </stream:stream>
|
|
||||||
// return Ok(JabberStream { reader, writer });
|
|
||||||
// }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait JabberStreamTrait: AsyncWrite + AsyncRead + Unpin + Send {}
|
|
||||||
|
|
||||||
impl<S> Sink<Stanza> for BoundJabberStream<S>
|
|
||||||
where
|
|
||||||
S: AsyncWrite + AsyncRead + Unpin + Send + 'static,
|
|
||||||
{
|
|
||||||
type Error = Error;
|
|
||||||
|
|
||||||
fn poll_ready(
|
|
||||||
self: std::pin::Pin<&mut Self>,
|
|
||||||
cx: &mut std::task::Context<'_>,
|
|
||||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
|
||||||
self.poll_flush(cx)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn start_send(self: std::pin::Pin<&mut Self>, item: Stanza) -> Result<(), Self::Error> {
|
|
||||||
let this = self.project();
|
|
||||||
if let Some(_write_handle) = this.write_handle {
|
|
||||||
panic!("start_send called without poll_ready")
|
|
||||||
} else {
|
|
||||||
// TODO: switch to buffer of one rather than thread spawning and joining
|
|
||||||
*this.write_handle = Some(tokio::spawn(write(this.writer.clone(), item)));
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_flush(
|
|
||||||
self: std::pin::Pin<&mut Self>,
|
|
||||||
cx: &mut std::task::Context<'_>,
|
|
||||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
|
||||||
let this = self.project();
|
|
||||||
Poll::Ready(if let Some(join_handle) = this.write_handle.as_mut() {
|
|
||||||
match ready!(join_handle.poll_unpin(cx)) {
|
|
||||||
Ok(state) => {
|
|
||||||
*this.write_handle = None;
|
|
||||||
state
|
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
*this.write_handle = None;
|
|
||||||
Err(err.into())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Ok(())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_close(
|
|
||||||
self: std::pin::Pin<&mut Self>,
|
|
||||||
cx: &mut std::task::Context<'_>,
|
|
||||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
|
||||||
self.poll_flush(cx)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S> Stream for BoundJabberStream<S>
|
impl<S> BoundJabberStream<S> {
|
||||||
where
|
pub fn split(self) -> (BoundJabberReader<S>, BoundJabberWriter<S>) {
|
||||||
S: AsyncWrite + AsyncRead + Unpin + Send + 'static,
|
let (reader, writer) = self.0.split();
|
||||||
{
|
(BoundJabberReader(reader), BoundJabberWriter(writer))
|
||||||
type Item = Result<Stanza, Error>;
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn poll_next(
|
pub struct BoundJabberReader<S>(JabberReader<S>);
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut std::task::Context<'_>,
|
|
||||||
) -> std::task::Poll<Option<Self::Item>> {
|
|
||||||
let this = self.project();
|
|
||||||
|
|
||||||
loop {
|
impl<S> BoundJabberReader<S> {
|
||||||
if let Some(join_handle) = this.read_handle.as_mut() {
|
pub fn unsplit(self, writer: BoundJabberWriter<S>) -> BoundJabberStream<S> {
|
||||||
let stanza = ready!(join_handle.poll_unpin(cx));
|
BoundJabberStream(self.0.unsplit(writer.0))
|
||||||
if let Ok(item) = stanza {
|
}
|
||||||
*this.read_handle = None;
|
}
|
||||||
return Poll::Ready(Some(item));
|
|
||||||
} else if let Err(err) = stanza {
|
impl<S> std::ops::Deref for BoundJabberReader<S> {
|
||||||
return Poll::Ready(Some(Err(err.into())));
|
type Target = JabberReader<S>;
|
||||||
}
|
|
||||||
} else {
|
fn deref(&self) -> &Self::Target {
|
||||||
*this.read_handle = Some(tokio::spawn(read(this.reader.clone())))
|
&self.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<S> std::ops::DerefMut for BoundJabberReader<S> {
|
||||||
|
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||||
|
&mut self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct BoundJabberWriter<S>(JabberWriter<S>);
|
||||||
|
|
||||||
|
impl<S> BoundJabberWriter<S> {
|
||||||
|
pub fn unsplit(self, reader: BoundJabberReader<S>) -> BoundJabberStream<S> {
|
||||||
|
BoundJabberStream(self.0.unsplit(reader.0))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> std::ops::Deref for BoundJabberWriter<S> {
|
||||||
|
type Target = JabberWriter<S>;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> std::ops::DerefMut for BoundJabberWriter<S> {
|
||||||
|
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||||
|
&mut self.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -131,35 +85,6 @@ where
|
||||||
S: AsyncWrite + AsyncRead + Unpin + Send,
|
S: AsyncWrite + AsyncRead + Unpin + Send,
|
||||||
{
|
{
|
||||||
pub fn to_bound_jabber(self) -> BoundJabberStream<S> {
|
pub fn to_bound_jabber(self) -> BoundJabberStream<S> {
|
||||||
let reader = Arc::new(Mutex::new(self.reader));
|
BoundJabberStream(self)
|
||||||
let writer = Arc::new(Mutex::new(self.writer));
|
|
||||||
BoundJabberStream {
|
|
||||||
writer,
|
|
||||||
reader,
|
|
||||||
write_handle: None,
|
|
||||||
read_handle: None,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn write<W: AsyncWrite + Unpin + Send>(
|
|
||||||
writer: Arc<Mutex<Writer<WriteHalf<W>>>>,
|
|
||||||
stanza: Stanza,
|
|
||||||
) -> Result<(), Error> {
|
|
||||||
{
|
|
||||||
let mut writer = writer.lock().await;
|
|
||||||
writer.write(&stanza).await?;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn read<R: AsyncRead + Unpin + Send>(
|
|
||||||
reader: Arc<Mutex<Reader<ReadHalf<R>>>>,
|
|
||||||
) -> Result<Stanza, Error> {
|
|
||||||
let stanza: Result<Stanza, Error>;
|
|
||||||
{
|
|
||||||
let mut reader = reader.lock().await;
|
|
||||||
stanza = reader.read().await.map_err(|e| e.into());
|
|
||||||
}
|
|
||||||
stanza
|
|
||||||
}
|
|
||||||
|
|
|
@ -15,6 +15,7 @@ pub mod presence;
|
||||||
|
|
||||||
pub const XMLNS: &str = "jabber:client";
|
pub const XMLNS: &str = "jabber:client";
|
||||||
|
|
||||||
|
/// TODO: End tag
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum Stanza {
|
pub enum Stanza {
|
||||||
Message(Message),
|
Message(Message),
|
||||||
|
|
Loading…
Reference in New Issue