Compare commits

..

2 Commits

Author SHA1 Message Date
cel 🌸 4886396044 implement client 2024-12-04 02:09:07 +00:00
cel 🌸 e0373c0520 WIP: connecting fsm 2024-12-03 23:57:04 +00:00
7 changed files with 532 additions and 287 deletions

View File

@ -19,6 +19,7 @@ tracing = "0.1.40"
trust-dns-resolver = "0.22.0" trust-dns-resolver = "0.22.0"
try_map = "0.3.1" try_map = "0.3.1"
peanuts = { version = "0.1.0", path = "../peanuts" } peanuts = { version = "0.1.0", path = "../peanuts" }
futures = "0.3.31"
[dev-dependencies] [dev-dependencies]
test-log = { version = "0.2", features = ["trace"] } test-log = { version = "0.2", features = ["trace"] }

266
src/client.rs Normal file
View File

@ -0,0 +1,266 @@
use std::{pin::pin, sync::Arc, task::Poll};
use futures::{Sink, Stream, StreamExt};
use rsasl::config::SASLConfig;
use crate::{
connection::{Tls, Unencrypted},
jid::ParseError,
stanza::{
client::Stanza,
sasl::Mechanisms,
stream::{Feature, Features},
},
Connection, Error, JabberStream, Result, JID,
};
// feed it client stanzas, receive client stanzas
pub struct JabberClient {
connection: ConnectionState,
jid: JID,
password: Arc<SASLConfig>,
server: String,
}
impl JabberClient {
pub fn new(
jid: impl TryInto<JID, Error = ParseError>,
password: impl ToString,
) -> Result<JabberClient> {
let jid = jid.try_into()?;
let sasl_config = SASLConfig::with_credentials(
None,
jid.localpart.clone().ok_or(Error::NoLocalpart)?,
password.to_string(),
)?;
Ok(JabberClient {
connection: ConnectionState::Disconnected,
jid: jid.clone(),
password: sasl_config,
server: jid.domainpart,
})
}
pub async fn connect(&mut self) -> Result<()> {
match &self.connection {
ConnectionState::Disconnected => {
self.connection = ConnectionState::Disconnected
.connect(&mut self.jid, self.password.clone(), &mut self.server)
.await?;
Ok(())
}
ConnectionState::Connecting(_connecting) => Err(Error::AlreadyConnecting),
ConnectionState::Connected(_jabber_stream) => Ok(()),
}
}
}
impl Stream for JabberClient {
type Item = Result<Stanza>;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let mut client = pin!(self);
match &mut client.connection {
ConnectionState::Disconnected => Poll::Pending,
ConnectionState::Connecting(_connecting) => Poll::Pending,
ConnectionState::Connected(jabber_stream) => jabber_stream.poll_next_unpin(cx),
}
}
}
pub enum ConnectionState {
Disconnected,
Connecting(Connecting),
Connected(JabberStream<Tls>),
}
impl ConnectionState {
pub async fn connect(
mut self,
jid: &mut JID,
auth: Arc<SASLConfig>,
server: &mut String,
) -> Result<Self> {
loop {
match self {
ConnectionState::Disconnected => {
self = ConnectionState::Connecting(Connecting::start(&server).await?);
}
ConnectionState::Connecting(connecting) => match connecting {
Connecting::InsecureConnectionEstablised(tcp_stream) => {
self = ConnectionState::Connecting(Connecting::InsecureStreamStarted(
JabberStream::start_stream(tcp_stream, server).await?,
))
}
Connecting::InsecureStreamStarted(jabber_stream) => {
self = ConnectionState::Connecting(Connecting::InsecureGotFeatures(
jabber_stream.get_features().await?,
))
}
Connecting::InsecureGotFeatures((features, jabber_stream)) => {
match features.negotiate()? {
Feature::StartTls(_start_tls) => {
self =
ConnectionState::Connecting(Connecting::StartTls(jabber_stream))
}
// TODO: better error
_ => return Err(Error::TlsRequired),
}
}
Connecting::StartTls(jabber_stream) => {
self = ConnectionState::Connecting(Connecting::ConnectionEstablished(
jabber_stream.starttls(&server).await?,
))
}
Connecting::ConnectionEstablished(tls_stream) => {
self = ConnectionState::Connecting(Connecting::StreamStarted(
JabberStream::start_stream(tls_stream, server).await?,
))
}
Connecting::StreamStarted(jabber_stream) => {
self = ConnectionState::Connecting(Connecting::GotFeatures(
jabber_stream.get_features().await?,
))
}
Connecting::GotFeatures((features, jabber_stream)) => {
match features.negotiate()? {
Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls),
Feature::Sasl(mechanisms) => {
self = ConnectionState::Connecting(Connecting::Sasl(
mechanisms,
jabber_stream,
))
}
Feature::Bind => {
self = ConnectionState::Connecting(Connecting::Bind(jabber_stream))
}
Feature::Unknown => return Err(Error::Unsupported),
}
}
Connecting::Sasl(mechanisms, jabber_stream) => {
self = ConnectionState::Connecting(Connecting::ConnectionEstablished(
jabber_stream.sasl(mechanisms, auth.clone()).await?,
))
}
Connecting::Bind(jabber_stream) => {
self = ConnectionState::Connected(jabber_stream.bind(jid).await?)
}
},
connected => return Ok(connected),
}
}
}
}
pub enum Connecting {
InsecureConnectionEstablised(Unencrypted),
InsecureStreamStarted(JabberStream<Unencrypted>),
InsecureGotFeatures((Features, JabberStream<Unencrypted>)),
StartTls(JabberStream<Unencrypted>),
ConnectionEstablished(Tls),
StreamStarted(JabberStream<Tls>),
GotFeatures((Features, JabberStream<Tls>)),
Sasl(Mechanisms, JabberStream<Tls>),
Bind(JabberStream<Tls>),
}
impl Connecting {
pub async fn start(server: &str) -> Result<Self> {
match Connection::connect(server).await? {
Connection::Encrypted(tls_stream) => Ok(Connecting::ConnectionEstablished(tls_stream)),
Connection::Unencrypted(tcp_stream) => {
Ok(Connecting::InsecureConnectionEstablised(tcp_stream))
}
}
}
}
impl Features {
pub fn negotiate(self) -> Result<Feature> {
if let Some(Feature::StartTls(s)) = self
.features
.iter()
.find(|feature| matches!(feature, Feature::StartTls(_s)))
{
// TODO: avoid clone
return Ok(Feature::StartTls(s.clone()));
} else if let Some(Feature::Sasl(mechanisms)) = self
.features
.iter()
.find(|feature| matches!(feature, Feature::Sasl(_)))
{
// TODO: avoid clone
return Ok(Feature::Sasl(mechanisms.clone()));
} else if let Some(Feature::Bind) = self
.features
.into_iter()
.find(|feature| matches!(feature, Feature::Bind))
{
Ok(Feature::Bind)
} else {
// TODO: better error
return Err(Error::Negotiation);
}
}
}
pub enum InsecureConnecting {
Disconnected,
ConnectionEstablished(Connection),
PreStarttls(JabberStream<Unencrypted>),
PreAuthenticated(JabberStream<Tls>),
Authenticated(Tls),
PreBound(JabberStream<Tls>),
Bound(JabberStream<Tls>),
}
impl Sink<Stanza> for JabberClient {
type Error = Error;
fn poll_ready(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::result::Result<(), Self::Error>> {
todo!()
}
fn start_send(
self: std::pin::Pin<&mut Self>,
item: Stanza,
) -> std::result::Result<(), Self::Error> {
todo!()
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::result::Result<(), Self::Error>> {
todo!()
}
fn poll_close(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::result::Result<(), Self::Error>> {
todo!()
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use super::JabberClient;
use test_log::test;
use tokio::time::sleep;
#[test(tokio::test)]
async fn login() {
let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap();
client.connect().await.unwrap();
sleep(Duration::from_secs(5)).await
}
}

View File

@ -10,7 +10,6 @@ use tokio_native_tls::native_tls::TlsConnector;
use tokio_native_tls::TlsStream; use tokio_native_tls::TlsStream;
use tracing::{debug, info, instrument, trace}; use tracing::{debug, info, instrument, trace};
use crate::Jabber;
use crate::Result; use crate::Result;
use crate::{Error, JID}; use crate::{Error, JID};
@ -19,69 +18,51 @@ pub type Unencrypted = TcpStream;
#[derive(Debug)] #[derive(Debug)]
pub enum Connection { pub enum Connection {
Encrypted(Jabber<Tls>), Encrypted(Tls),
Unencrypted(Jabber<Unencrypted>), Unencrypted(Unencrypted),
} }
impl Connection { impl Connection {
#[instrument] // #[instrument]
/// stream not started /// stream not started
pub async fn ensure_tls(self) -> Result<Jabber<Tls>> { // pub async fn ensure_tls(self) -> Result<Jabber<Tls>> {
match self { // match self {
Connection::Encrypted(j) => Ok(j), // Connection::Encrypted(j) => Ok(j),
Connection::Unencrypted(mut j) => { // Connection::Unencrypted(mut j) => {
j.start_stream().await?; // j.start_stream().await?;
info!("upgrading connection to tls"); // info!("upgrading connection to tls");
j.get_features().await?; // j.get_features().await?;
let j = j.starttls().await?; // let j = j.starttls().await?;
Ok(j) // Ok(j)
} // }
} // }
} // }
pub async fn connect_user(jid: impl AsRef<str>, password: String) -> Result<Self> { pub async fn connect_user(jid: impl AsRef<str>) -> Result<Self> {
let jid: JID = JID::from_str(jid.as_ref())?; let jid: JID = JID::from_str(jid.as_ref())?;
let server = jid.domainpart.clone(); let server = jid.domainpart.clone();
let auth = SASLConfig::with_credentials(None, jid.localpart.clone().unwrap(), password)?; Self::connect(&server).await
println!("auth: {:?}", auth);
Self::connect(&server, Some(jid), Some(auth)).await
} }
#[instrument] #[instrument]
pub async fn connect( pub async fn connect(server: impl AsRef<str> + std::fmt::Debug) -> Result<Self> {
server: &str, info!("connecting to {}", server.as_ref());
jid: Option<JID>, let sockets = Self::get_sockets(server.as_ref()).await;
auth: Option<Arc<SASLConfig>>,
) -> Result<Self> {
info!("connecting to {}", server);
let sockets = Self::get_sockets(&server).await;
debug!("discovered sockets: {:?}", sockets); debug!("discovered sockets: {:?}", sockets);
for (socket_addr, tls) in sockets { for (socket_addr, tls) in sockets {
match tls { match tls {
true => { true => {
if let Ok(connection) = Self::connect_tls(socket_addr, &server).await { if let Ok(connection) = Self::connect_tls(socket_addr, server.as_ref()).await {
info!("connected via encrypted stream to {}", socket_addr); info!("connected via encrypted stream to {}", socket_addr);
let (readhalf, writehalf) = tokio::io::split(connection); // let (readhalf, writehalf) = tokio::io::split(connection);
return Ok(Self::Encrypted(Jabber::new( return Ok(Self::Encrypted(connection));
readhalf,
writehalf,
jid,
auth,
server.to_owned(),
)));
} }
} }
false => { false => {
if let Ok(connection) = Self::connect_unencrypted(socket_addr).await { if let Ok(connection) = Self::connect_unencrypted(socket_addr).await {
info!("connected via unencrypted stream to {}", socket_addr); info!("connected via unencrypted stream to {}", socket_addr);
let (readhalf, writehalf) = tokio::io::split(connection); // let (readhalf, writehalf) = tokio::io::split(connection);
return Ok(Self::Unencrypted(Jabber::new( return Ok(Self::Unencrypted(connection));
readhalf,
writehalf,
jid,
auth,
server.to_owned(),
)));
} }
} }
} }
@ -188,16 +169,16 @@ mod tests {
#[test(tokio::test)] #[test(tokio::test)]
async fn connect() { async fn connect() {
Connection::connect("blos.sm", None, None).await.unwrap(); Connection::connect("blos.sm").await.unwrap();
} }
#[test(tokio::test)] // #[test(tokio::test)]
async fn test_tls() { // async fn test_tls() {
Connection::connect("blos.sm", None, None) // Connection::connect("blos.sm", None, None)
.await // .await
.unwrap() // .unwrap()
.ensure_tls() // .ensure_tls()
.await // .await
.unwrap(); // .unwrap();
} // }
} }

View File

@ -8,24 +8,16 @@ use crate::{jid::ParseError, stanza::sasl::Failure};
#[derive(Debug)] #[derive(Debug)]
pub enum Error { pub enum Error {
Connection, Connection,
BadStream,
StartTlsUnavailable,
TlsNegotiation,
Utf8Decode, Utf8Decode,
NoFeatures,
UnknownNamespace,
UnknownAttribute,
NoID,
NoType,
IDMismatch,
BindError,
ParseError,
Negotiation, Negotiation,
TlsRequired, TlsRequired,
UnexpectedEnd, AlreadyTls,
Unsupported,
NoLocalpart,
AlreadyConnecting,
UnexpectedElement(peanuts::Element), UnexpectedElement(peanuts::Element),
UnexpectedText,
XML(peanuts::Error), XML(peanuts::Error),
Deserialization(peanuts::DeserializeError),
SASL(SASLError), SASL(SASLError),
JID(ParseError), JID(ParseError),
Authentication(Failure), Authentication(Failure),
@ -37,8 +29,6 @@ pub enum Error {
pub enum SASLError { pub enum SASLError {
SASL(rsasl::prelude::SASLError), SASL(rsasl::prelude::SASLError),
MechanismName(MechanismNameError), MechanismName(MechanismNameError),
NoChallenge,
NoSuccess,
} }
impl From<rsasl::prelude::SASLError> for Error { impl From<rsasl::prelude::SASLError> for Error {
@ -47,6 +37,12 @@ impl From<rsasl::prelude::SASLError> for Error {
} }
} }
impl From<peanuts::DeserializeError> for Error {
fn from(e: peanuts::DeserializeError) -> Self {
Error::Deserialization(e)
}
}
impl From<MechanismNameError> for Error { impl From<MechanismNameError> for Error {
fn from(e: MechanismNameError) -> Self { fn from(e: MechanismNameError) -> Self {
Self::SASL(SASLError::MechanismName(e)) Self::SASL(SASLError::MechanismName(e))

View File

@ -1,8 +1,10 @@
use std::str; use std::pin::pin;
use std::str::{self, FromStr};
use std::sync::Arc; use std::sync::Arc;
use async_recursion::async_recursion; use async_recursion::async_recursion;
use peanuts::element::IntoElement; use futures::StreamExt;
use peanuts::element::{FromContent, IntoElement};
use peanuts::{Reader, Writer}; use peanuts::{Reader, Writer};
use rsasl::prelude::{Mechname, SASLClient, SASLConfig}; use rsasl::prelude::{Mechname, SASLClient, SASLConfig};
use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf}; use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
@ -13,6 +15,7 @@ use crate::connection::{Tls, Unencrypted};
use crate::error::Error; use crate::error::Error;
use crate::stanza::bind::{Bind, BindType, FullJidType, ResourceType}; use crate::stanza::bind::{Bind, BindType, FullJidType, ResourceType};
use crate::stanza::client::iq::{Iq, IqType, Query}; use crate::stanza::client::iq::{Iq, IqType, Query};
use crate::stanza::client::Stanza;
use crate::stanza::sasl::{Auth, Challenge, Mechanisms, Response, ServerResponse}; use crate::stanza::sasl::{Auth, Challenge, Mechanisms, Response, ServerResponse};
use crate::stanza::starttls::{Proceed, StartTls}; use crate::stanza::starttls::{Proceed, StartTls};
use crate::stanza::stream::{Feature, Features, Stream}; use crate::stanza::stream::{Feature, Features, Stream};
@ -20,47 +23,34 @@ use crate::stanza::XML_VERSION;
use crate::JID; use crate::JID;
use crate::{Connection, Result}; use crate::{Connection, Result};
pub struct Jabber<S> { // open stream (streams started)
pub struct JabberStream<S> {
reader: Reader<ReadHalf<S>>, reader: Reader<ReadHalf<S>>,
writer: Writer<WriteHalf<S>>, writer: Writer<WriteHalf<S>>,
jid: Option<JID>,
auth: Option<Arc<SASLConfig>>,
server: String,
} }
impl<S> Jabber<S> impl<S: AsyncRead> futures::Stream for JabberStream<S> {
where type Item = Result<Stanza>;
S: AsyncRead + AsyncWrite + Unpin,
{ fn poll_next(
pub fn new( self: std::pin::Pin<&mut Self>,
reader: ReadHalf<S>, cx: &mut std::task::Context<'_>,
writer: WriteHalf<S>, ) -> std::task::Poll<Option<Self::Item>> {
jid: Option<JID>, pin!(self).reader.poll_next_unpin(cx).map(|content| {
auth: Option<Arc<SASLConfig>>, content.map(|content| -> Result<Stanza> {
server: String, let stanza = content.map(|content| Stanza::from_content(content))?;
) -> Self { Ok(stanza?)
let reader = Reader::new(reader); })
let writer = Writer::new(writer); })
Self {
reader,
writer,
jid,
auth,
server,
}
} }
} }
impl<S> Jabber<S> impl<S> JabberStream<S>
where where
S: AsyncRead + AsyncWrite + Unpin + Send, S: AsyncRead + AsyncWrite + Unpin + Send + std::fmt::Debug,
Jabber<S>: std::fmt::Debug, JabberStream<S>: std::fmt::Debug,
{ {
pub async fn sasl( pub async fn sasl(mut self, mechanisms: Mechanisms, sasl_config: Arc<SASLConfig>) -> Result<S> {
&mut self,
mechanisms: Mechanisms,
sasl_config: Arc<SASLConfig>,
) -> Result<()> {
let sasl = SASLClient::new(sasl_config); let sasl = SASLClient::new(sasl_config);
let mut offered_mechs: Vec<&Mechname> = Vec::new(); let mut offered_mechs: Vec<&Mechname> = Vec::new();
for mechanism in &mechanisms.mechanisms { for mechanism in &mechanisms.mechanisms {
@ -143,12 +133,15 @@ where
} }
} }
} }
Ok(()) let writer = self.writer.into_inner();
let reader = self.reader.into_inner();
let stream = reader.unsplit(writer);
Ok(stream)
} }
pub async fn bind(&mut self) -> Result<()> { pub async fn bind(mut self, jid: &mut JID) -> Result<Self> {
let iq_id = nanoid::nanoid!(); let iq_id = nanoid::nanoid!();
if let Some(resource) = self.jid.clone().unwrap().resourcepart { if let Some(resource) = &jid.resourcepart {
let iq = Iq { let iq = Iq {
from: None, from: None,
id: iq_id.clone(), id: iq_id.clone(),
@ -156,7 +149,7 @@ where
r#type: IqType::Set, r#type: IqType::Set,
lang: None, lang: None,
query: Some(Query::Bind(Bind { query: Some(Query::Bind(Bind {
r#type: Some(BindType::Resource(ResourceType(resource))), r#type: Some(BindType::Resource(ResourceType(resource.to_string()))),
})), })),
errors: Vec::new(), errors: Vec::new(),
}; };
@ -171,12 +164,12 @@ where
lang: _, lang: _,
query: query:
Some(Query::Bind(Bind { Some(Query::Bind(Bind {
r#type: Some(BindType::Jid(FullJidType(jid))), r#type: Some(BindType::Jid(FullJidType(new_jid))),
})), })),
errors: _, errors: _,
} if id == iq_id => { } if id == iq_id => {
self.jid = Some(jid); *jid = new_jid;
return Ok(()); return Ok(self);
} }
Iq { Iq {
from: _, from: _,
@ -214,12 +207,12 @@ where
lang: _, lang: _,
query: query:
Some(Query::Bind(Bind { Some(Query::Bind(Bind {
r#type: Some(BindType::Jid(FullJidType(jid))), r#type: Some(BindType::Jid(FullJidType(new_jid))),
})), })),
errors: _, errors: _,
} if id == iq_id => { } if id == iq_id => {
self.jid = Some(jid); *jid = new_jid;
return Ok(()); return Ok(self);
} }
Iq { Iq {
from: _, from: _,
@ -240,39 +233,44 @@ where
} }
#[instrument] #[instrument]
pub async fn start_stream(&mut self) -> Result<()> { pub async fn start_stream(connection: S, server: &mut String) -> Result<Self> {
// client to server // client to server
let (reader, writer) = tokio::io::split(connection);
let mut reader = Reader::new(reader);
let mut writer = Writer::new(writer);
// declaration // declaration
self.writer.write_declaration(XML_VERSION).await?; writer.write_declaration(XML_VERSION).await?;
// opening stream element // opening stream element
let server = self.server.clone().try_into()?; let stream = Stream::new_client(
let stream = Stream::new_client(None, server, None, "en".to_string()); None,
self.writer.write_start(&stream).await?; JID::from_str(server.as_ref())?,
None,
"en".to_string(),
);
writer.write_start(&stream).await?;
// server to client // server to client
// may or may not send a declaration // may or may not send a declaration
let _decl = self.reader.read_prolog().await?; let _decl = reader.read_prolog().await?;
// receive stream element and validate // receive stream element and validate
let text = str::from_utf8(self.reader.buffer.data()).unwrap(); let stream: Stream = reader.read_start().await?;
debug!("data: {}", text);
let stream: Stream = self.reader.read_start().await?;
debug!("got stream: {:?}", stream); debug!("got stream: {:?}", stream);
if let Some(from) = stream.from { if let Some(from) = stream.from {
self.server = from.to_string() *server = from.to_string();
} }
Ok(()) Ok(Self { reader, writer })
} }
pub async fn get_features(&mut self) -> Result<Features> { pub async fn get_features(mut self) -> Result<(Features, Self)> {
debug!("getting features"); debug!("getting features");
let features: Features = self.reader.read().await?; let features: Features = self.reader.read().await?;
debug!("got features: {:?}", features); debug!("got features: {:?}", features);
Ok(features) Ok((features, self))
} }
pub fn into_inner(self) -> S { pub fn into_inner(self) -> S {
@ -280,89 +278,89 @@ where
} }
} }
impl Jabber<Unencrypted> { impl JabberStream<Unencrypted> {
pub async fn negotiate<S: AsyncRead + AsyncWrite + Unpin>(mut self) -> Result<Jabber<Tls>> { // pub async fn negotiate<S: AsyncRead + AsyncWrite + Unpin>(
self.start_stream().await?; // mut self,
// TODO: timeout // features: Features,
let features = self.get_features().await?.features; // ) -> Result<Feature> {
if let Some(Feature::StartTls(_)) = features // // TODO: timeout
.iter() // if let Some(Feature::StartTls(_)) = features
.find(|feature| matches!(feature, Feature::StartTls(_s))) // .features
{ // .iter()
let jabber = self.starttls().await?; // .find(|feature| matches!(feature, Feature::StartTls(_s)))
let jabber = jabber.negotiate().await?; // {
return Ok(jabber); // return Ok(self);
} else { // } else {
// TODO: better error // // TODO: better error
return Err(Error::TlsRequired); // return Err(Error::TlsRequired);
} // }
// }
// #[async_recursion]
// pub async fn negotiate_tls_optional(mut self) -> Result<Connection> {
// self.start_stream().await?;
// // TODO: timeout
// let features = self.get_features().await?.features;
// if let Some(Feature::StartTls(_)) = features
// .iter()
// .find(|feature| matches!(feature, Feature::StartTls(_s)))
// {
// let jabber = self.starttls().await?;
// let jabber = jabber.negotiate().await?;
// return Ok(Connection::Encrypted(jabber));
// } else if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = (
// self.auth.clone(),
// features
// .iter()
// .find(|feature| matches!(feature, Feature::Sasl(_))),
// ) {
// self.sasl(mechanisms.clone(), sasl_config).await?;
// let jabber = self.negotiate_tls_optional().await?;
// Ok(jabber)
// } else if let Some(Feature::Bind) = features
// .iter()
// .find(|feature| matches!(feature, Feature::Bind))
// {
// self.bind().await?;
// Ok(Connection::Unencrypted(self))
// } else {
// // TODO: better error
// return Err(Error::Negotiation);
// }
// }
} }
#[async_recursion] impl JabberStream<Tls> {
pub async fn negotiate_tls_optional(mut self) -> Result<Connection> { // #[async_recursion]
self.start_stream().await?; // pub async fn negotiate(mut self) -> Result<JabberStream<Tls>> {
// TODO: timeout // self.start_stream().await?;
let features = self.get_features().await?.features; // let features = self.get_features().await?.features;
if let Some(Feature::StartTls(_)) = features
.iter() // if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = (
.find(|feature| matches!(feature, Feature::StartTls(_s))) // self.auth.clone(),
{ // features
let jabber = self.starttls().await?; // .iter()
let jabber = jabber.negotiate().await?; // .find(|feature| matches!(feature, Feature::Sasl(_))),
return Ok(Connection::Encrypted(jabber)); // ) {
} else if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = ( // // TODO: avoid clone
self.auth.clone(), // self.sasl(mechanisms.clone(), sasl_config).await?;
features // let jabber = self.negotiate().await?;
.iter() // Ok(jabber)
.find(|feature| matches!(feature, Feature::Sasl(_))), // } else if let Some(Feature::Bind) = features
) { // .iter()
self.sasl(mechanisms.clone(), sasl_config).await?; // .find(|feature| matches!(feature, Feature::Bind))
let jabber = self.negotiate_tls_optional().await?; // {
Ok(jabber) // self.bind().await?;
} else if let Some(Feature::Bind) = features // Ok(self)
.iter() // } else {
.find(|feature| matches!(feature, Feature::Bind)) // // TODO: better error
{ // return Err(Error::Negotiation);
self.bind().await?; // }
Ok(Connection::Unencrypted(self)) // }
} else {
// TODO: better error
return Err(Error::Negotiation);
}
}
} }
impl Jabber<Tls> { impl JabberStream<Unencrypted> {
#[async_recursion] pub async fn starttls(mut self, domain: impl AsRef<str>) -> Result<Tls> {
pub async fn negotiate(mut self) -> Result<Jabber<Tls>> {
self.start_stream().await?;
let features = self.get_features().await?.features;
if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = (
self.auth.clone(),
features
.iter()
.find(|feature| matches!(feature, Feature::Sasl(_))),
) {
// TODO: avoid clone
self.sasl(mechanisms.clone(), sasl_config).await?;
let jabber = self.negotiate().await?;
Ok(jabber)
} else if let Some(Feature::Bind) = features
.iter()
.find(|feature| matches!(feature, Feature::Bind))
{
self.bind().await?;
Ok(self)
} else {
// TODO: better error
return Err(Error::Negotiation);
}
}
}
impl Jabber<Unencrypted> {
pub async fn starttls(mut self) -> Result<Jabber<Tls>> {
self.writer self.writer
.write_full(&StartTls { required: false }) .write_full(&StartTls { required: false })
.await?; .await?;
@ -370,43 +368,31 @@ impl Jabber<Unencrypted> {
debug!("got proceed: {:?}", proceed); debug!("got proceed: {:?}", proceed);
let connector = TlsConnector::new().unwrap(); let connector = TlsConnector::new().unwrap();
let stream = self.reader.into_inner().unsplit(self.writer.into_inner()); let stream = self.reader.into_inner().unsplit(self.writer.into_inner());
if let Ok(tlsstream) = tokio_native_tls::TlsConnector::from(connector) if let Ok(tls_stream) = tokio_native_tls::TlsConnector::from(connector)
.connect(&self.server, stream) .connect(domain.as_ref(), stream)
.await .await
{ {
let (read, write) = tokio::io::split(tlsstream); // let (read, write) = tokio::io::split(tlsstream);
let client = Jabber::new( // let client = JabberStream::new(read, write);
read, return Ok(tls_stream);
write,
self.jid.to_owned(),
self.auth.to_owned(),
self.server.to_owned(),
);
return Ok(client);
} else { } else {
return Err(Error::Connection); return Err(Error::Connection);
} }
} }
} }
impl std::fmt::Debug for Jabber<Tls> { impl std::fmt::Debug for JabberStream<Tls> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Jabber") f.debug_struct("Jabber")
.field("connection", &"tls") .field("connection", &"tls")
.field("jid", &self.jid)
.field("auth", &self.auth)
.field("server", &self.server)
.finish() .finish()
} }
} }
impl std::fmt::Debug for Jabber<Unencrypted> { impl std::fmt::Debug for JabberStream<Unencrypted> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Jabber") f.debug_struct("Jabber")
.field("connection", &"unencrypted") .field("connection", &"unencrypted")
.field("jid", &self.jid)
.field("auth", &self.auth)
.field("server", &self.server)
.finish() .finish()
} }
} }
@ -422,61 +408,61 @@ mod tests {
#[test(tokio::test)] #[test(tokio::test)]
async fn start_stream() { async fn start_stream() {
let connection = Connection::connect("blos.sm", None, None).await.unwrap(); // let connection = Connection::connect("blos.sm", None, None).await.unwrap();
match connection { // match connection {
Connection::Encrypted(mut c) => c.start_stream().await.unwrap(), // Connection::Encrypted(mut c) => c.start_stream().await.unwrap(),
Connection::Unencrypted(mut c) => c.start_stream().await.unwrap(), // Connection::Unencrypted(mut c) => c.start_stream().await.unwrap(),
} // }
} }
#[test(tokio::test)] #[test(tokio::test)]
async fn sasl() { async fn sasl() {
let mut jabber = Connection::connect_user("test@blos.sm", "slayed".to_string()) // let mut jabber = Connection::connect_user("test@blos.sm", "slayed".to_string())
.await // .await
.unwrap() // .unwrap()
.ensure_tls() // .ensure_tls()
.await // .await
.unwrap(); // .unwrap();
let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
println!("data: {}", text); // println!("data: {}", text);
jabber.start_stream().await.unwrap(); // jabber.start_stream().await.unwrap();
let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
println!("data: {}", text); // println!("data: {}", text);
jabber.reader.read_buf().await.unwrap(); // jabber.reader.read_buf().await.unwrap();
let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
println!("data: {}", text); // println!("data: {}", text);
let features = jabber.get_features().await.unwrap(); // let features = jabber.get_features().await.unwrap();
let (sasl_config, feature) = ( // let (sasl_config, feature) = (
jabber.auth.clone().unwrap(), // jabber.auth.clone().unwrap(),
features // features
.features // .features
.iter() // .iter()
.find(|feature| matches!(feature, Feature::Sasl(_))) // .find(|feature| matches!(feature, Feature::Sasl(_)))
.unwrap(), // .unwrap(),
); // );
match feature { // match feature {
Feature::StartTls(_start_tls) => todo!(), // Feature::StartTls(_start_tls) => todo!(),
Feature::Sasl(mechanisms) => { // Feature::Sasl(mechanisms) => {
jabber.sasl(mechanisms.clone(), sasl_config).await.unwrap(); // jabber.sasl(mechanisms.clone(), sasl_config).await.unwrap();
} // }
Feature::Bind => todo!(), // Feature::Bind => todo!(),
Feature::Unknown => todo!(), // Feature::Unknown => todo!(),
} // }
} }
#[tokio::test] #[tokio::test]
async fn negotiate() { async fn negotiate() {
let _jabber = Connection::connect_user("test@blos.sm", "slayed".to_string()) // let _jabber = Connection::connect_user("test@blos.sm", "slayed".to_string())
.await // .await
.unwrap() // .unwrap()
.ensure_tls() // .ensure_tls()
.await // .await
.unwrap() // .unwrap()
.negotiate() // .negotiate()
.await // .await
.unwrap(); // .unwrap();
sleep(Duration::from_secs(5)).await // sleep(Duration::from_secs(5)).await
} }
} }

View File

@ -2,6 +2,7 @@
// #![feature(let_chains)] // #![feature(let_chains)]
// TODO: logging (dropped errors) // TODO: logging (dropped errors)
pub mod client;
pub mod connection; pub mod connection;
pub mod error; pub mod error;
pub mod jabber; pub mod jabber;
@ -11,24 +12,25 @@ pub mod stanza;
pub use connection::Connection; pub use connection::Connection;
use connection::Tls; use connection::Tls;
pub use error::Error; pub use error::Error;
pub use jabber::Jabber; pub use jabber::JabberStream;
pub use jid::JID; pub use jid::JID;
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
pub async fn login<J: AsRef<str>, P: AsRef<str>>(jid: J, password: P) -> Result<Jabber<Tls>> { pub async fn login<J: AsRef<str>, P: AsRef<str>>(jid: J, password: P) -> Result<JabberStream<Tls>> {
Ok(Connection::connect_user(jid, password.as_ref().to_string()) todo!()
.await? // Ok(Connection::connect_user(jid, password.as_ref().to_string())
.ensure_tls() // .await?
.await? // .ensure_tls()
.negotiate() // .await?
.await?) // .negotiate()
// .await?)
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
#[tokio::test] // #[tokio::test]
async fn test_login() { // async fn test_login() {
crate::login("test@blos.sm/clown", "slayed").await.unwrap(); // crate::login("test@blos.sm/clown", "slayed").await.unwrap();
} // }
} }

View File

@ -1,7 +1,7 @@
use iq::Iq; use iq::Iq;
use message::Message; use message::Message;
use peanuts::{ use peanuts::{
element::{FromElement, IntoElement}, element::{Content, ContentBuilder, FromContent, FromElement, IntoContent, IntoElement},
DeserializeError, DeserializeError,
}; };
use presence::Presence; use presence::Presence;
@ -20,6 +20,18 @@ pub enum Stanza {
Presence(Presence), Presence(Presence),
Iq(Iq), Iq(Iq),
Error(StreamError), Error(StreamError),
OtherContent(Content),
}
impl FromContent for Stanza {
fn from_content(content: Content) -> peanuts::element::DeserializeResult<Self> {
match content {
Content::Element(element) => Ok(Stanza::from_element(element)?),
Content::Text(_) => Ok(Stanza::OtherContent(content)),
Content::PI => Ok(Stanza::OtherContent(content)),
Content::Comment(_) => Ok(Stanza::OtherContent(content)),
}
}
} }
impl FromElement for Stanza { impl FromElement for Stanza {
@ -36,13 +48,14 @@ impl FromElement for Stanza {
} }
} }
impl IntoElement for Stanza { impl IntoContent for Stanza {
fn builder(&self) -> peanuts::element::ElementBuilder { fn builder(&self) -> peanuts::element::ContentBuilder {
match self { match self {
Stanza::Message(message) => message.builder(), Stanza::Message(message) => <Message as IntoContent>::builder(message),
Stanza::Presence(presence) => presence.builder(), Stanza::Presence(presence) => <Presence as IntoContent>::builder(presence),
Stanza::Iq(iq) => iq.builder(), Stanza::Iq(iq) => <Iq as IntoContent>::builder(iq),
Stanza::Error(error) => error.builder(), Stanza::Error(error) => <StreamError as IntoContent>::builder(error),
Stanza::OtherContent(_content) => ContentBuilder::Comment("other-content".to_string()),
} }
} }
} }