lampada/src/jabber.rs

142 lines
5.0 KiB
Rust

use std::marker::PhantomData;
use std::net::{IpAddr, SocketAddr};
use std::str::FromStr;
use std::sync::Arc;
use quick_xml::{Reader, Writer};
use rsasl::prelude::SASLConfig;
use tokio::io::BufReader;
use tokio::net::TcpStream;
use tokio_native_tls::native_tls::TlsConnector;
use crate::client::JabberClientType;
use crate::jid::JID;
use crate::{client, JabberClient};
use crate::{JabberError, Result};
/// Connection parameters for a single XMPP account.
///
/// A value is built with [`Jabber::new`] and then borrowed mutably for `'j`
/// by [`Jabber::connect`] / [`Jabber::login`], which hand that borrow to the
/// client they return.
pub struct Jabber<'j> {
    /// The account's JID, as passed to [`Jabber::new`].
    pub jid: JID,
    /// SASL configuration created from the bare JID and the password.
    pub auth: Arc<SASLConfig>,
    /// Copy of the JID's domain part; used as the TLS server name when a
    /// direct-TLS connection is attempted.
    pub server: String,
    // Carries the `'j` lifetime parameter without storing an actual reference.
    _marker: PhantomData<&'j ()>,
}
impl<'j> Jabber<'j> {
    /// Creates the connection parameters for `jid`, authenticating with
    /// `password`.
    ///
    /// The server name is the JID's domain part, and the SASL configuration is
    /// built from the bare JID and the password.
    ///
    /// # Errors
    ///
    /// Returns an error if the SASL configuration cannot be constructed.
    pub fn new(jid: JID, password: String) -> Result<Self> {
        let server = jid.domainpart.clone();
        let auth = SASLConfig::with_credentials(None, jid.as_bare().to_string(), password)?;
        println!("auth: {:?}", auth);
        Ok(Self {
            jid,
            auth,
            server,
            _marker: PhantomData,
        })
    }

    /// Connects to the server, runs `ensure_tls` on the resulting connection
    /// and then negotiates the stream, returning the ready client.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Jabber::connect`], from `ensure_tls`, or
    /// from `negotiate`.
    pub async fn login(&'j mut self) -> Result<JabberClient<'j>> {
        let mut client = self.connect().await?.ensure_tls().await?;
        println!("negotiation");
        client.negotiate().await?;
        Ok(client)
    }

    /// Collects the candidate endpoints for the JID's domain part, in the
    /// order they should be tried.
    ///
    /// Each entry is `(address, is_tls)`, where `is_tls` means the TLS
    /// handshake is performed immediately after the TCP connection.
    ///
    /// * A literal `ip:port` yields exactly that address; it is treated as
    ///   direct TLS only when the port is 5223.
    /// * A literal IP yields port 5222 (plain) followed by port 5223 (TLS).
    /// * Anything else is resolved: `_xmpp-client._tcp` SRV targets (plain),
    ///   then `_xmpps-client._tcp` SRV targets (TLS), then the domain's own
    ///   addresses on ports 5222 (plain) and 5223 (TLS) as a fallback.
    ///
    /// Lookup failures are ignored, so the returned list may be empty.
    async fn get_sockets(&self) -> Vec<(SocketAddr, bool)> {
        let mut socket_addrs = Vec::new();

        // The domain part is already a full socket address: use it verbatim.
        if let Ok(socket_addr) = SocketAddr::from_str(&self.jid.domainpart) {
            socket_addrs.push((socket_addr, socket_addr.port() == 5223));
            return socket_addrs;
        }

        // The domain part is a bare IP address: try both well-known ports.
        if let Ok(ip) = IpAddr::from_str(&self.jid.domainpart) {
            socket_addrs.push((SocketAddr::new(ip, 5222), false));
            socket_addrs.push((SocketAddr::new(ip, 5223), true));
            return socket_addrs;
        }

        // Otherwise the domain part is a host name that has to be resolved.
        if let Ok(resolver) = trust_dns_resolver::AsyncResolver::tokio_from_system_conf() {
            // Plain (STARTTLS-capable) SRV records first, then direct-TLS ones.
            for (service, is_tls) in [("_xmpp-client", false), ("_xmpps-client", true)] {
                let query = format!("{}._tcp.{}", service, self.jid.domainpart);
                if let Ok(lookup) = resolver.srv_lookup(query).await {
                    for srv in lookup {
                        // A target that does not resolve is skipped.
                        if let Ok(ips) = resolver.lookup_ip(srv.target().to_owned()).await {
                            for ip in ips {
                                socket_addrs.push((SocketAddr::new(ip, srv.port()), is_tls));
                            }
                        }
                    }
                }
            }

            // Fallback in case no SRV candidate can be connected to.
            if let Ok(ips) = resolver.lookup_ip(&self.jid.domainpart).await {
                for ip in ips {
                    socket_addrs.push((SocketAddr::new(ip, 5222), false));
                    socket_addrs.push((SocketAddr::new(ip, 5223), true));
                }
            }
        }

        socket_addrs
    }

    /// Opens a connection to the first reachable candidate from
    /// `get_sockets` and wraps it in the matching client type.
    ///
    /// Direct-TLS candidates produce `JabberClientType::Encrypted`; the others
    /// produce `JabberClientType::Unencrypted`. A candidate that cannot be
    /// reached, or whose TLS setup or handshake fails, is skipped and the next
    /// one is tried.
    ///
    /// # Errors
    ///
    /// Returns `JabberError::Connection` when no candidate could be connected
    /// to (including when there were no candidates at all).
    pub async fn connect(&'j mut self) -> Result<JabberClientType> {
        for (socket_addr, is_tls) in self.get_sockets().await {
            println!("trying {}", socket_addr);

            // An unreachable candidate must not abort the whole attempt.
            let socket = match TcpStream::connect(socket_addr).await {
                Ok(socket) => socket,
                Err(_) => continue,
            };

            if is_tls {
                // Failing to set up the TLS backend only rules out this
                // candidate; later plaintext candidates can still be tried.
                let connector = match TlsConnector::new() {
                    Ok(connector) => connector,
                    Err(_) => continue,
                };
                if let Ok(stream) = tokio_native_tls::TlsConnector::from(connector)
                    .connect(&self.server, socket)
                    .await
                {
                    let (read, writer) = tokio::io::split(stream);
                    let reader = Reader::from_reader(BufReader::new(read));
                    return Ok(JabberClientType::Encrypted(
                        client::encrypted::JabberClient::new(reader, writer, self),
                    ));
                }
            } else {
                let (read, write) = tokio::io::split(socket);
                let reader = Reader::from_reader(BufReader::new(read));
                let writer = Writer::new(write);
                return Ok(JabberClientType::Unencrypted(
                    client::unencrypted::JabberClient::new(reader, writer, self),
                ));
            }
        }
        Err(JabberError::Connection)
    }
}