implement client

This commit is contained in:
cel 🌸 2024-12-04 02:09:07 +00:00
parent e0373c0520
commit 4886396044
5 changed files with 213 additions and 86 deletions

View File

@ -1,10 +1,11 @@
use std::sync::Arc; use std::{pin::pin, sync::Arc, task::Poll};
use futures::{Sink, Stream}; use futures::{Sink, Stream, StreamExt};
use rsasl::config::SASLConfig; use rsasl::config::SASLConfig;
use crate::{ use crate::{
connection::{Tls, Unencrypted}, connection::{Tls, Unencrypted},
jid::ParseError,
stanza::{ stanza::{
client::Stanza, client::Stanza,
sasl::Mechanisms, sasl::Mechanisms,
@ -15,14 +16,146 @@ use crate::{
// feed it client stanzas, receive client stanzas // feed it client stanzas, receive client stanzas
pub struct JabberClient { pub struct JabberClient {
connection: JabberState, connection: ConnectionState,
jid: JID, jid: JID,
password: Arc<SASLConfig>, password: Arc<SASLConfig>,
server: String, server: String,
} }
pub enum JabberState { impl JabberClient {
pub fn new(
jid: impl TryInto<JID, Error = ParseError>,
password: impl ToString,
) -> Result<JabberClient> {
let jid = jid.try_into()?;
let sasl_config = SASLConfig::with_credentials(
None,
jid.localpart.clone().ok_or(Error::NoLocalpart)?,
password.to_string(),
)?;
Ok(JabberClient {
connection: ConnectionState::Disconnected,
jid: jid.clone(),
password: sasl_config,
server: jid.domainpart,
})
}
pub async fn connect(&mut self) -> Result<()> {
match &self.connection {
ConnectionState::Disconnected => {
self.connection = ConnectionState::Disconnected
.connect(&mut self.jid, self.password.clone(), &mut self.server)
.await?;
Ok(())
}
ConnectionState::Connecting(_connecting) => Err(Error::AlreadyConnecting),
ConnectionState::Connected(_jabber_stream) => Ok(()),
}
}
}
impl Stream for JabberClient {
type Item = Result<Stanza>;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let mut client = pin!(self);
match &mut client.connection {
ConnectionState::Disconnected => Poll::Pending,
ConnectionState::Connecting(_connecting) => Poll::Pending,
ConnectionState::Connected(jabber_stream) => jabber_stream.poll_next_unpin(cx),
}
}
}
pub enum ConnectionState {
Disconnected, Disconnected,
Connecting(Connecting),
Connected(JabberStream<Tls>),
}
impl ConnectionState {
pub async fn connect(
mut self,
jid: &mut JID,
auth: Arc<SASLConfig>,
server: &mut String,
) -> Result<Self> {
loop {
match self {
ConnectionState::Disconnected => {
self = ConnectionState::Connecting(Connecting::start(&server).await?);
}
ConnectionState::Connecting(connecting) => match connecting {
Connecting::InsecureConnectionEstablised(tcp_stream) => {
self = ConnectionState::Connecting(Connecting::InsecureStreamStarted(
JabberStream::start_stream(tcp_stream, server).await?,
))
}
Connecting::InsecureStreamStarted(jabber_stream) => {
self = ConnectionState::Connecting(Connecting::InsecureGotFeatures(
jabber_stream.get_features().await?,
))
}
Connecting::InsecureGotFeatures((features, jabber_stream)) => {
match features.negotiate()? {
Feature::StartTls(_start_tls) => {
self =
ConnectionState::Connecting(Connecting::StartTls(jabber_stream))
}
// TODO: better error
_ => return Err(Error::TlsRequired),
}
}
Connecting::StartTls(jabber_stream) => {
self = ConnectionState::Connecting(Connecting::ConnectionEstablished(
jabber_stream.starttls(&server).await?,
))
}
Connecting::ConnectionEstablished(tls_stream) => {
self = ConnectionState::Connecting(Connecting::StreamStarted(
JabberStream::start_stream(tls_stream, server).await?,
))
}
Connecting::StreamStarted(jabber_stream) => {
self = ConnectionState::Connecting(Connecting::GotFeatures(
jabber_stream.get_features().await?,
))
}
Connecting::GotFeatures((features, jabber_stream)) => {
match features.negotiate()? {
Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls),
Feature::Sasl(mechanisms) => {
self = ConnectionState::Connecting(Connecting::Sasl(
mechanisms,
jabber_stream,
))
}
Feature::Bind => {
self = ConnectionState::Connecting(Connecting::Bind(jabber_stream))
}
Feature::Unknown => return Err(Error::Unsupported),
}
}
Connecting::Sasl(mechanisms, jabber_stream) => {
self = ConnectionState::Connecting(Connecting::ConnectionEstablished(
jabber_stream.sasl(mechanisms, auth.clone()).await?,
))
}
Connecting::Bind(jabber_stream) => {
self = ConnectionState::Connected(jabber_stream.bind(jid).await?)
}
},
connected => return Ok(connected),
}
}
}
}
pub enum Connecting {
InsecureConnectionEstablised(Unencrypted), InsecureConnectionEstablised(Unencrypted),
InsecureStreamStarted(JabberStream<Unencrypted>), InsecureStreamStarted(JabberStream<Unencrypted>),
InsecureGotFeatures((Features, JabberStream<Unencrypted>)), InsecureGotFeatures((Features, JabberStream<Unencrypted>)),
@ -32,67 +165,15 @@ pub enum JabberState {
GotFeatures((Features, JabberStream<Tls>)), GotFeatures((Features, JabberStream<Tls>)),
Sasl(Mechanisms, JabberStream<Tls>), Sasl(Mechanisms, JabberStream<Tls>),
Bind(JabberStream<Tls>), Bind(JabberStream<Tls>),
// when it's bound, can stream stanzas and sink stanzas
Bound(JabberStream<Tls>),
} }
impl JabberState { impl Connecting {
pub async fn advance_state( pub async fn start(server: &str) -> Result<Self> {
self, match Connection::connect(server).await? {
jid: &mut JID, Connection::Encrypted(tls_stream) => Ok(Connecting::ConnectionEstablished(tls_stream)),
auth: Arc<SASLConfig>, Connection::Unencrypted(tcp_stream) => {
server: &mut String, Ok(Connecting::InsecureConnectionEstablised(tcp_stream))
) -> Result<JabberState> {
match self {
JabberState::Disconnected => match Connection::connect(server).await? {
Connection::Encrypted(tls_stream) => {
Ok(JabberState::ConnectionEstablished(tls_stream))
}
Connection::Unencrypted(tcp_stream) => {
Ok(JabberState::InsecureConnectionEstablised(tcp_stream))
}
},
JabberState::InsecureConnectionEstablised(tcp_stream) => Ok({
JabberState::InsecureStreamStarted(
JabberStream::start_stream(tcp_stream, server).await?,
)
}),
JabberState::InsecureStreamStarted(jabber_stream) => Ok(
JabberState::InsecureGotFeatures(jabber_stream.get_features().await?),
),
JabberState::InsecureGotFeatures((features, jabber_stream)) => {
match features.negotiate()? {
Feature::StartTls(_start_tls) => Ok(JabberState::StartTls(jabber_stream)),
// TODO: better error
_ => return Err(Error::TlsRequired),
}
} }
JabberState::StartTls(jabber_stream) => Ok(JabberState::ConnectionEstablished(
jabber_stream.starttls(server).await?,
)),
JabberState::ConnectionEstablished(tls_stream) => Ok(JabberState::StreamStarted(
JabberStream::start_stream(tls_stream, server).await?,
)),
JabberState::StreamStarted(jabber_stream) => Ok(JabberState::GotFeatures(
jabber_stream.get_features().await?,
)),
JabberState::GotFeatures((features, jabber_stream)) => match features.negotiate()? {
Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls),
Feature::Sasl(mechanisms) => {
return Ok(JabberState::Sasl(mechanisms, jabber_stream))
}
Feature::Bind => return Ok(JabberState::Bind(jabber_stream)),
Feature::Unknown => return Err(Error::Unsupported),
},
JabberState::Sasl(mechanisms, jabber_stream) => {
return Ok(JabberState::ConnectionEstablished(
jabber_stream.sasl(mechanisms, auth).await?,
))
}
JabberState::Bind(jabber_stream) => {
Ok(JabberState::Bound(jabber_stream.bind(jid).await?))
}
JabberState::Bound(jabber_stream) => Ok(JabberState::Bound(jabber_stream)),
} }
} }
} }
@ -126,7 +207,7 @@ impl Features {
} }
} }
pub enum InsecureJabberConnection { pub enum InsecureConnecting {
Disconnected, Disconnected,
ConnectionEstablished(Connection), ConnectionEstablished(Connection),
PreStarttls(JabberStream<Unencrypted>), PreStarttls(JabberStream<Unencrypted>),
@ -136,17 +217,6 @@ pub enum InsecureJabberConnection {
Bound(JabberStream<Tls>), Bound(JabberStream<Tls>),
} }
impl Stream for JabberClient {
type Item = Stanza;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
todo!()
}
}
impl Sink<Stanza> for JabberClient { impl Sink<Stanza> for JabberClient {
type Error = Error; type Error = Error;
@ -178,3 +248,19 @@ impl Sink<Stanza> for JabberClient {
todo!() todo!()
} }
} }
#[cfg(test)]
mod tests {
use std::time::Duration;
use super::JabberClient;
use test_log::test;
use tokio::time::sleep;
#[test(tokio::test)]
async fn login() {
let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap();
client.connect().await.unwrap();
sleep(Duration::from_secs(5)).await
}
}

View File

@ -13,8 +13,11 @@ pub enum Error {
TlsRequired, TlsRequired,
AlreadyTls, AlreadyTls,
Unsupported, Unsupported,
NoLocalpart,
AlreadyConnecting,
UnexpectedElement(peanuts::Element), UnexpectedElement(peanuts::Element),
XML(peanuts::Error), XML(peanuts::Error),
Deserialization(peanuts::DeserializeError),
SASL(SASLError), SASL(SASLError),
JID(ParseError), JID(ParseError),
Authentication(Failure), Authentication(Failure),
@ -34,6 +37,12 @@ impl From<rsasl::prelude::SASLError> for Error {
} }
} }
impl From<peanuts::DeserializeError> for Error {
fn from(e: peanuts::DeserializeError) -> Self {
Error::Deserialization(e)
}
}
impl From<MechanismNameError> for Error { impl From<MechanismNameError> for Error {
fn from(e: MechanismNameError) -> Self { fn from(e: MechanismNameError) -> Self {
Self::SASL(SASLError::MechanismName(e)) Self::SASL(SASLError::MechanismName(e))

View File

@ -1,8 +1,10 @@
use std::pin::pin;
use std::str::{self, FromStr}; use std::str::{self, FromStr};
use std::sync::Arc; use std::sync::Arc;
use async_recursion::async_recursion; use async_recursion::async_recursion;
use peanuts::element::IntoElement; use futures::StreamExt;
use peanuts::element::{FromContent, IntoElement};
use peanuts::{Reader, Writer}; use peanuts::{Reader, Writer};
use rsasl::prelude::{Mechname, SASLClient, SASLConfig}; use rsasl::prelude::{Mechname, SASLClient, SASLConfig};
use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf}; use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
@ -13,6 +15,7 @@ use crate::connection::{Tls, Unencrypted};
use crate::error::Error; use crate::error::Error;
use crate::stanza::bind::{Bind, BindType, FullJidType, ResourceType}; use crate::stanza::bind::{Bind, BindType, FullJidType, ResourceType};
use crate::stanza::client::iq::{Iq, IqType, Query}; use crate::stanza::client::iq::{Iq, IqType, Query};
use crate::stanza::client::Stanza;
use crate::stanza::sasl::{Auth, Challenge, Mechanisms, Response, ServerResponse}; use crate::stanza::sasl::{Auth, Challenge, Mechanisms, Response, ServerResponse};
use crate::stanza::starttls::{Proceed, StartTls}; use crate::stanza::starttls::{Proceed, StartTls};
use crate::stanza::stream::{Feature, Features, Stream}; use crate::stanza::stream::{Feature, Features, Stream};
@ -26,6 +29,22 @@ pub struct JabberStream<S> {
writer: Writer<WriteHalf<S>>, writer: Writer<WriteHalf<S>>,
} }
impl<S: AsyncRead> futures::Stream for JabberStream<S> {
type Item = Result<Stanza>;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
pin!(self).reader.poll_next_unpin(cx).map(|content| {
content.map(|content| -> Result<Stanza> {
let stanza = content.map(|content| Stanza::from_content(content))?;
Ok(stanza?)
})
})
}
}
impl<S> JabberStream<S> impl<S> JabberStream<S>
where where
S: AsyncRead + AsyncWrite + Unpin + Send + std::fmt::Debug, S: AsyncRead + AsyncWrite + Unpin + Send + std::fmt::Debug,

View File

@ -29,8 +29,8 @@ pub async fn login<J: AsRef<str>, P: AsRef<str>>(jid: J, password: P) -> Result<
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
#[tokio::test] // #[tokio::test]
async fn test_login() { // async fn test_login() {
crate::login("test@blos.sm/clown", "slayed").await.unwrap(); // crate::login("test@blos.sm/clown", "slayed").await.unwrap();
} // }
} }

View File

@ -1,7 +1,7 @@
use iq::Iq; use iq::Iq;
use message::Message; use message::Message;
use peanuts::{ use peanuts::{
element::{FromElement, IntoElement}, element::{Content, ContentBuilder, FromContent, FromElement, IntoContent, IntoElement},
DeserializeError, DeserializeError,
}; };
use presence::Presence; use presence::Presence;
@ -20,6 +20,18 @@ pub enum Stanza {
Presence(Presence), Presence(Presence),
Iq(Iq), Iq(Iq),
Error(StreamError), Error(StreamError),
OtherContent(Content),
}
impl FromContent for Stanza {
fn from_content(content: Content) -> peanuts::element::DeserializeResult<Self> {
match content {
Content::Element(element) => Ok(Stanza::from_element(element)?),
Content::Text(_) => Ok(Stanza::OtherContent(content)),
Content::PI => Ok(Stanza::OtherContent(content)),
Content::Comment(_) => Ok(Stanza::OtherContent(content)),
}
}
} }
impl FromElement for Stanza { impl FromElement for Stanza {
@ -36,13 +48,14 @@ impl FromElement for Stanza {
} }
} }
impl IntoElement for Stanza { impl IntoContent for Stanza {
fn builder(&self) -> peanuts::element::ElementBuilder { fn builder(&self) -> peanuts::element::ContentBuilder {
match self { match self {
Stanza::Message(message) => message.builder(), Stanza::Message(message) => <Message as IntoContent>::builder(message),
Stanza::Presence(presence) => presence.builder(), Stanza::Presence(presence) => <Presence as IntoContent>::builder(presence),
Stanza::Iq(iq) => iq.builder(), Stanza::Iq(iq) => <Iq as IntoContent>::builder(iq),
Stanza::Error(error) => error.builder(), Stanza::Error(error) => <StreamError as IntoContent>::builder(error),
Stanza::OtherContent(_content) => ContentBuilder::Comment("other-content".to_string()),
} }
} }
} }