165 lines
4.6 KiB
Rust
165 lines
4.6 KiB
Rust
use std::str;
|
|
use std::sync::Arc;
|
|
|
|
use peanuts::element::{FromElement, IntoElement};
|
|
use peanuts::{Reader, Writer};
|
|
use rsasl::prelude::SASLConfig;
|
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
|
|
use tokio_native_tls::native_tls::TlsConnector;
|
|
use tracing::{debug, info, instrument, trace};
|
|
use trust_dns_resolver::proto::rr::domain::IntoLabel;
|
|
|
|
use crate::connection::{Tls, Unencrypted};
|
|
use crate::error::Error;
|
|
use crate::stanza::starttls::{Proceed, StartTls};
|
|
use crate::stanza::stream::{Features, Stream};
|
|
use crate::stanza::XML_VERSION;
|
|
use crate::Result;
|
|
use crate::JID;
|
|
|
|
pub struct Jabber<S>
|
|
where
|
|
S: AsyncRead + AsyncWrite + Unpin,
|
|
{
|
|
reader: Reader<ReadHalf<S>>,
|
|
writer: Writer<WriteHalf<S>>,
|
|
jid: Option<JID>,
|
|
auth: Option<Arc<SASLConfig>>,
|
|
server: String,
|
|
}
|
|
|
|
impl<S> Jabber<S>
|
|
where
|
|
S: AsyncRead + AsyncWrite + Unpin,
|
|
{
|
|
pub fn new(
|
|
reader: ReadHalf<S>,
|
|
writer: WriteHalf<S>,
|
|
jid: Option<JID>,
|
|
auth: Option<Arc<SASLConfig>>,
|
|
server: String,
|
|
) -> Self {
|
|
let reader = Reader::new(reader);
|
|
let writer = Writer::new(writer);
|
|
Self {
|
|
reader,
|
|
writer,
|
|
jid,
|
|
auth,
|
|
server,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<S> Jabber<S>
|
|
where
|
|
S: AsyncRead + AsyncWrite + Unpin + Send,
|
|
Jabber<S>: std::fmt::Debug,
|
|
{
|
|
// pub async fn negotiate(self) -> Result<Jabber<S>> {}
|
|
|
|
#[instrument]
|
|
pub async fn start_stream(&mut self) -> Result<()> {
|
|
// client to server
|
|
|
|
// declaration
|
|
self.writer.write_declaration(XML_VERSION).await?;
|
|
|
|
// opening stream element
|
|
let server = self.server.clone().try_into()?;
|
|
let stream = Stream::new_client(None, server, None, "en".to_string());
|
|
self.writer.write_start(&stream).await?;
|
|
|
|
// server to client
|
|
|
|
// may or may not send a declaration
|
|
let decl = self.reader.read_prolog().await?;
|
|
|
|
// receive stream element and validate
|
|
let stream: Stream = self.reader.read_start().await?;
|
|
debug!("got stream: {:?}", stream);
|
|
if let Some(from) = stream.from {
|
|
self.server = from.to_string()
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn get_features(&mut self) -> Result<Features> {
|
|
debug!("getting features");
|
|
let features: Features = self.reader.read().await?;
|
|
debug!("got features: {:?}", features);
|
|
Ok(features)
|
|
}
|
|
|
|
pub fn into_inner(self) -> S {
|
|
self.reader.into_inner().unsplit(self.writer.into_inner())
|
|
}
|
|
}
|
|
|
|
impl Jabber<Unencrypted> {
|
|
pub async fn starttls(mut self) -> Result<Jabber<Tls>> {
|
|
self.writer
|
|
.write_full(&StartTls { required: false })
|
|
.await?;
|
|
let proceed: Proceed = self.reader.read().await?;
|
|
debug!("got proceed: {:?}", proceed);
|
|
let connector = TlsConnector::new().unwrap();
|
|
let stream = self.reader.into_inner().unsplit(self.writer.into_inner());
|
|
if let Ok(tlsstream) = tokio_native_tls::TlsConnector::from(connector)
|
|
.connect(&self.server, stream)
|
|
.await
|
|
{
|
|
let (read, write) = tokio::io::split(tlsstream);
|
|
let client = Jabber::new(
|
|
read,
|
|
write,
|
|
self.jid.to_owned(),
|
|
self.auth.to_owned(),
|
|
self.server.to_owned(),
|
|
);
|
|
return Ok(client);
|
|
} else {
|
|
return Err(Error::Connection);
|
|
}
|
|
}
|
|
}
|
|
|
|
impl std::fmt::Debug for Jabber<Tls> {
    /// Prints the connection kind plus `jid`, `auth` and `server`.
    ///
    /// The reader and writer halves are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            jid, auth, server, ..
        } = self;
        let mut out = f.debug_struct("Jabber");
        out.field("connection", &"tls");
        out.field("jid", jid);
        out.field("auth", auth);
        out.field("server", server);
        out.finish()
    }
}
|
|
|
|
impl std::fmt::Debug for Jabber<Unencrypted> {
    /// Prints the connection kind plus `jid`, `auth` and `server`.
    ///
    /// The reader and writer halves are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            jid, auth, server, ..
        } = self;
        let mut out = f.debug_struct("Jabber");
        out.field("connection", &"unencrypted");
        out.field("jid", jid);
        out.field("auth", auth);
        out.field("server", server);
        out.finish()
    }
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::connection::Connection;
    use test_log::test;

    /// Connects to `blos.sm` and opens an XML stream on whichever connection
    /// variant is returned.
    ///
    /// Requires network access.
    #[test(tokio::test)]
    async fn start_stream() {
        match Connection::connect("blos.sm").await.unwrap() {
            Connection::Unencrypted(mut c) => c.start_stream().await.unwrap(),
            Connection::Encrypted(mut c) => c.start_stream().await.unwrap(),
        }
    }
}
|