implement client

This commit is contained in:
cel 🌸 2024-12-04 02:09:07 +00:00
parent e0373c0520
commit 4886396044
5 changed files with 213 additions and 86 deletions

View File

@ -1,10 +1,11 @@
use std::sync::Arc;
use std::{pin::pin, sync::Arc, task::Poll};
use futures::{Sink, Stream};
use futures::{Sink, Stream, StreamExt};
use rsasl::config::SASLConfig;
use crate::{
connection::{Tls, Unencrypted},
jid::ParseError,
stanza::{
client::Stanza,
sasl::Mechanisms,
@ -15,14 +16,146 @@ use crate::{
// feed it client stanzas, receive client stanzas
pub struct JabberClient {
connection: JabberState,
connection: ConnectionState,
jid: JID,
password: Arc<SASLConfig>,
server: String,
}
pub enum JabberState {
impl JabberClient {
pub fn new(
jid: impl TryInto<JID, Error = ParseError>,
password: impl ToString,
) -> Result<JabberClient> {
let jid = jid.try_into()?;
let sasl_config = SASLConfig::with_credentials(
None,
jid.localpart.clone().ok_or(Error::NoLocalpart)?,
password.to_string(),
)?;
Ok(JabberClient {
connection: ConnectionState::Disconnected,
jid: jid.clone(),
password: sasl_config,
server: jid.domainpart,
})
}
pub async fn connect(&mut self) -> Result<()> {
match &self.connection {
ConnectionState::Disconnected => {
self.connection = ConnectionState::Disconnected
.connect(&mut self.jid, self.password.clone(), &mut self.server)
.await?;
Ok(())
}
ConnectionState::Connecting(_connecting) => Err(Error::AlreadyConnecting),
ConnectionState::Connected(_jabber_stream) => Ok(()),
}
}
}
impl Stream for JabberClient {
type Item = Result<Stanza>;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let mut client = pin!(self);
match &mut client.connection {
ConnectionState::Disconnected => Poll::Pending,
ConnectionState::Connecting(_connecting) => Poll::Pending,
ConnectionState::Connected(jabber_stream) => jabber_stream.poll_next_unpin(cx),
}
}
}
pub enum ConnectionState {
Disconnected,
Connecting(Connecting),
Connected(JabberStream<Tls>),
}
impl ConnectionState {
pub async fn connect(
mut self,
jid: &mut JID,
auth: Arc<SASLConfig>,
server: &mut String,
) -> Result<Self> {
loop {
match self {
ConnectionState::Disconnected => {
self = ConnectionState::Connecting(Connecting::start(&server).await?);
}
ConnectionState::Connecting(connecting) => match connecting {
Connecting::InsecureConnectionEstablised(tcp_stream) => {
self = ConnectionState::Connecting(Connecting::InsecureStreamStarted(
JabberStream::start_stream(tcp_stream, server).await?,
))
}
Connecting::InsecureStreamStarted(jabber_stream) => {
self = ConnectionState::Connecting(Connecting::InsecureGotFeatures(
jabber_stream.get_features().await?,
))
}
Connecting::InsecureGotFeatures((features, jabber_stream)) => {
match features.negotiate()? {
Feature::StartTls(_start_tls) => {
self =
ConnectionState::Connecting(Connecting::StartTls(jabber_stream))
}
// TODO: better error
_ => return Err(Error::TlsRequired),
}
}
Connecting::StartTls(jabber_stream) => {
self = ConnectionState::Connecting(Connecting::ConnectionEstablished(
jabber_stream.starttls(&server).await?,
))
}
Connecting::ConnectionEstablished(tls_stream) => {
self = ConnectionState::Connecting(Connecting::StreamStarted(
JabberStream::start_stream(tls_stream, server).await?,
))
}
Connecting::StreamStarted(jabber_stream) => {
self = ConnectionState::Connecting(Connecting::GotFeatures(
jabber_stream.get_features().await?,
))
}
Connecting::GotFeatures((features, jabber_stream)) => {
match features.negotiate()? {
Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls),
Feature::Sasl(mechanisms) => {
self = ConnectionState::Connecting(Connecting::Sasl(
mechanisms,
jabber_stream,
))
}
Feature::Bind => {
self = ConnectionState::Connecting(Connecting::Bind(jabber_stream))
}
Feature::Unknown => return Err(Error::Unsupported),
}
}
Connecting::Sasl(mechanisms, jabber_stream) => {
self = ConnectionState::Connecting(Connecting::ConnectionEstablished(
jabber_stream.sasl(mechanisms, auth.clone()).await?,
))
}
Connecting::Bind(jabber_stream) => {
self = ConnectionState::Connected(jabber_stream.bind(jid).await?)
}
},
connected => return Ok(connected),
}
}
}
}
pub enum Connecting {
InsecureConnectionEstablised(Unencrypted),
InsecureStreamStarted(JabberStream<Unencrypted>),
InsecureGotFeatures((Features, JabberStream<Unencrypted>)),
@ -32,67 +165,15 @@ pub enum JabberState {
GotFeatures((Features, JabberStream<Tls>)),
Sasl(Mechanisms, JabberStream<Tls>),
Bind(JabberStream<Tls>),
// when it's bound, can stream stanzas and sink stanzas
Bound(JabberStream<Tls>),
}
impl JabberState {
pub async fn advance_state(
self,
jid: &mut JID,
auth: Arc<SASLConfig>,
server: &mut String,
) -> Result<JabberState> {
match self {
JabberState::Disconnected => match Connection::connect(server).await? {
Connection::Encrypted(tls_stream) => {
Ok(JabberState::ConnectionEstablished(tls_stream))
}
impl Connecting {
pub async fn start(server: &str) -> Result<Self> {
match Connection::connect(server).await? {
Connection::Encrypted(tls_stream) => Ok(Connecting::ConnectionEstablished(tls_stream)),
Connection::Unencrypted(tcp_stream) => {
Ok(JabberState::InsecureConnectionEstablised(tcp_stream))
Ok(Connecting::InsecureConnectionEstablised(tcp_stream))
}
},
JabberState::InsecureConnectionEstablised(tcp_stream) => Ok({
JabberState::InsecureStreamStarted(
JabberStream::start_stream(tcp_stream, server).await?,
)
}),
JabberState::InsecureStreamStarted(jabber_stream) => Ok(
JabberState::InsecureGotFeatures(jabber_stream.get_features().await?),
),
JabberState::InsecureGotFeatures((features, jabber_stream)) => {
match features.negotiate()? {
Feature::StartTls(_start_tls) => Ok(JabberState::StartTls(jabber_stream)),
// TODO: better error
_ => return Err(Error::TlsRequired),
}
}
JabberState::StartTls(jabber_stream) => Ok(JabberState::ConnectionEstablished(
jabber_stream.starttls(server).await?,
)),
JabberState::ConnectionEstablished(tls_stream) => Ok(JabberState::StreamStarted(
JabberStream::start_stream(tls_stream, server).await?,
)),
JabberState::StreamStarted(jabber_stream) => Ok(JabberState::GotFeatures(
jabber_stream.get_features().await?,
)),
JabberState::GotFeatures((features, jabber_stream)) => match features.negotiate()? {
Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls),
Feature::Sasl(mechanisms) => {
return Ok(JabberState::Sasl(mechanisms, jabber_stream))
}
Feature::Bind => return Ok(JabberState::Bind(jabber_stream)),
Feature::Unknown => return Err(Error::Unsupported),
},
JabberState::Sasl(mechanisms, jabber_stream) => {
return Ok(JabberState::ConnectionEstablished(
jabber_stream.sasl(mechanisms, auth).await?,
))
}
JabberState::Bind(jabber_stream) => {
Ok(JabberState::Bound(jabber_stream.bind(jid).await?))
}
JabberState::Bound(jabber_stream) => Ok(JabberState::Bound(jabber_stream)),
}
}
}
@ -126,7 +207,7 @@ impl Features {
}
}
pub enum InsecureJabberConnection {
pub enum InsecureConnecting {
Disconnected,
ConnectionEstablished(Connection),
PreStarttls(JabberStream<Unencrypted>),
@ -136,17 +217,6 @@ pub enum InsecureJabberConnection {
Bound(JabberStream<Tls>),
}
impl Stream for JabberClient {
type Item = Stanza;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
todo!()
}
}
impl Sink<Stanza> for JabberClient {
type Error = Error;
@ -178,3 +248,19 @@ impl Sink<Stanza> for JabberClient {
todo!()
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use super::JabberClient;
use test_log::test;
use tokio::time::sleep;
#[test(tokio::test)]
async fn login() {
let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap();
client.connect().await.unwrap();
sleep(Duration::from_secs(5)).await
}
}

View File

@ -13,8 +13,11 @@ pub enum Error {
TlsRequired,
AlreadyTls,
Unsupported,
NoLocalpart,
AlreadyConnecting,
UnexpectedElement(peanuts::Element),
XML(peanuts::Error),
Deserialization(peanuts::DeserializeError),
SASL(SASLError),
JID(ParseError),
Authentication(Failure),
@ -34,6 +37,12 @@ impl From<rsasl::prelude::SASLError> for Error {
}
}
impl From<peanuts::DeserializeError> for Error {
fn from(e: peanuts::DeserializeError) -> Self {
Error::Deserialization(e)
}
}
impl From<MechanismNameError> for Error {
fn from(e: MechanismNameError) -> Self {
Self::SASL(SASLError::MechanismName(e))

View File

@ -1,8 +1,10 @@
use std::pin::pin;
use std::str::{self, FromStr};
use std::sync::Arc;
use async_recursion::async_recursion;
use peanuts::element::IntoElement;
use futures::StreamExt;
use peanuts::element::{FromContent, IntoElement};
use peanuts::{Reader, Writer};
use rsasl::prelude::{Mechname, SASLClient, SASLConfig};
use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
@ -13,6 +15,7 @@ use crate::connection::{Tls, Unencrypted};
use crate::error::Error;
use crate::stanza::bind::{Bind, BindType, FullJidType, ResourceType};
use crate::stanza::client::iq::{Iq, IqType, Query};
use crate::stanza::client::Stanza;
use crate::stanza::sasl::{Auth, Challenge, Mechanisms, Response, ServerResponse};
use crate::stanza::starttls::{Proceed, StartTls};
use crate::stanza::stream::{Feature, Features, Stream};
@ -26,6 +29,22 @@ pub struct JabberStream<S> {
writer: Writer<WriteHalf<S>>,
}
impl<S: AsyncRead> futures::Stream for JabberStream<S> {
type Item = Result<Stanza>;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
pin!(self).reader.poll_next_unpin(cx).map(|content| {
content.map(|content| -> Result<Stanza> {
let stanza = content.map(|content| Stanza::from_content(content))?;
Ok(stanza?)
})
})
}
}
impl<S> JabberStream<S>
where
S: AsyncRead + AsyncWrite + Unpin + Send + std::fmt::Debug,

View File

@ -29,8 +29,8 @@ pub async fn login<J: AsRef<str>, P: AsRef<str>>(jid: J, password: P) -> Result<
#[cfg(test)]
mod tests {
#[tokio::test]
async fn test_login() {
crate::login("test@blos.sm/clown", "slayed").await.unwrap();
}
// #[tokio::test]
// async fn test_login() {
// crate::login("test@blos.sm/clown", "slayed").await.unwrap();
// }
}

View File

@ -1,7 +1,7 @@
use iq::Iq;
use message::Message;
use peanuts::{
element::{FromElement, IntoElement},
element::{Content, ContentBuilder, FromContent, FromElement, IntoContent, IntoElement},
DeserializeError,
};
use presence::Presence;
@ -20,6 +20,18 @@ pub enum Stanza {
Presence(Presence),
Iq(Iq),
Error(StreamError),
OtherContent(Content),
}
impl FromContent for Stanza {
fn from_content(content: Content) -> peanuts::element::DeserializeResult<Self> {
match content {
Content::Element(element) => Ok(Stanza::from_element(element)?),
Content::Text(_) => Ok(Stanza::OtherContent(content)),
Content::PI => Ok(Stanza::OtherContent(content)),
Content::Comment(_) => Ok(Stanza::OtherContent(content)),
}
}
}
impl FromElement for Stanza {
@ -36,13 +48,14 @@ impl FromElement for Stanza {
}
}
impl IntoElement for Stanza {
fn builder(&self) -> peanuts::element::ElementBuilder {
impl IntoContent for Stanza {
fn builder(&self) -> peanuts::element::ContentBuilder {
match self {
Stanza::Message(message) => message.builder(),
Stanza::Presence(presence) => presence.builder(),
Stanza::Iq(iq) => iq.builder(),
Stanza::Error(error) => error.builder(),
Stanza::Message(message) => <Message as IntoContent>::builder(message),
Stanza::Presence(presence) => <Presence as IntoContent>::builder(presence),
Stanza::Iq(iq) => <Iq as IntoContent>::builder(iq),
Stanza::Error(error) => <StreamError as IntoContent>::builder(error),
Stanza::OtherContent(_content) => ContentBuilder::Comment("other-content".to_string()),
}
}
}