2023-10-21 01:28:54 +01:00
|
|
|
use std::net::{IpAddr, SocketAddr};
|
|
|
|
use std::str;
|
|
|
|
use std::str::FromStr;
|
2024-11-29 02:11:02 +00:00
|
|
|
use std::sync::Arc;
|
2023-10-21 01:28:54 +01:00
|
|
|
|
2024-11-29 02:11:02 +00:00
|
|
|
use rsasl::config::SASLConfig;
|
2023-10-21 01:28:54 +01:00
|
|
|
use tokio::net::TcpStream;
|
|
|
|
use tokio_native_tls::native_tls::TlsConnector;
|
|
|
|
// TODO: use rustls
|
|
|
|
use tokio_native_tls::TlsStream;
|
2023-10-21 02:38:19 +01:00
|
|
|
use tracing::{debug, info, instrument, trace};
|
2023-10-21 01:28:54 +01:00
|
|
|
|
|
|
|
use crate::Jabber;
|
|
|
|
use crate::Result;
|
2024-11-29 02:11:02 +00:00
|
|
|
use crate::{Error, JID};
|
2023-10-21 01:28:54 +01:00
|
|
|
|
|
|
|
pub type Tls = TlsStream<TcpStream>;
|
|
|
|
pub type Unencrypted = TcpStream;
|
|
|
|
|
2023-10-28 21:06:42 +01:00
|
|
|
#[derive(Debug)]
|
2023-10-21 01:28:54 +01:00
|
|
|
pub enum Connection {
|
|
|
|
Encrypted(Jabber<Tls>),
|
|
|
|
Unencrypted(Jabber<Unencrypted>),
|
|
|
|
}
|
|
|
|
|
|
|
|
impl Connection {
|
2023-10-28 21:06:42 +01:00
|
|
|
#[instrument]
|
2024-11-28 19:06:20 +00:00
|
|
|
/// stream not started
|
2023-10-21 01:28:54 +01:00
|
|
|
pub async fn ensure_tls(self) -> Result<Jabber<Tls>> {
|
|
|
|
match self {
|
|
|
|
Connection::Encrypted(j) => Ok(j),
|
2023-10-28 21:06:42 +01:00
|
|
|
Connection::Unencrypted(mut j) => {
|
2024-11-24 02:04:45 +00:00
|
|
|
j.start_stream().await?;
|
2023-10-28 21:06:42 +01:00
|
|
|
info!("upgrading connection to tls");
|
2024-11-24 02:04:45 +00:00
|
|
|
j.get_features().await?;
|
|
|
|
let j = j.starttls().await?;
|
|
|
|
Ok(j)
|
2023-10-28 21:06:42 +01:00
|
|
|
}
|
2023-10-21 01:28:54 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-11-29 02:11:02 +00:00
|
|
|
pub async fn connect_user(jid: impl AsRef<str>, password: String) -> Result<Self> {
|
|
|
|
let jid: JID = JID::from_str(jid.as_ref())?;
|
|
|
|
let server = jid.domainpart.clone();
|
|
|
|
let auth = SASLConfig::with_credentials(None, jid.localpart.clone().unwrap(), password)?;
|
|
|
|
println!("auth: {:?}", auth);
|
|
|
|
Self::connect(&server, Some(jid), Some(auth)).await
|
|
|
|
}
|
2023-10-21 01:28:54 +01:00
|
|
|
|
2023-10-21 02:38:19 +01:00
|
|
|
#[instrument]
|
2024-11-29 02:11:02 +00:00
|
|
|
pub async fn connect(
|
|
|
|
server: &str,
|
|
|
|
jid: Option<JID>,
|
|
|
|
auth: Option<Arc<SASLConfig>>,
|
|
|
|
) -> Result<Self> {
|
2023-10-21 02:38:19 +01:00
|
|
|
info!("connecting to {}", server);
|
2023-10-21 01:28:54 +01:00
|
|
|
let sockets = Self::get_sockets(&server).await;
|
2023-10-21 02:38:19 +01:00
|
|
|
debug!("discovered sockets: {:?}", sockets);
|
2023-10-21 01:28:54 +01:00
|
|
|
for (socket_addr, tls) in sockets {
|
|
|
|
match tls {
|
|
|
|
true => {
|
|
|
|
if let Ok(connection) = Self::connect_tls(socket_addr, &server).await {
|
2023-10-21 02:38:19 +01:00
|
|
|
info!("connected via encrypted stream to {}", socket_addr);
|
2023-10-21 01:28:54 +01:00
|
|
|
let (readhalf, writehalf) = tokio::io::split(connection);
|
|
|
|
return Ok(Self::Encrypted(Jabber::new(
|
|
|
|
readhalf,
|
|
|
|
writehalf,
|
2024-11-29 02:11:02 +00:00
|
|
|
jid,
|
|
|
|
auth,
|
2023-10-21 01:28:54 +01:00
|
|
|
server.to_owned(),
|
|
|
|
)));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
false => {
|
|
|
|
if let Ok(connection) = Self::connect_unencrypted(socket_addr).await {
|
2023-10-21 02:38:19 +01:00
|
|
|
info!("connected via unencrypted stream to {}", socket_addr);
|
2023-10-21 01:28:54 +01:00
|
|
|
let (readhalf, writehalf) = tokio::io::split(connection);
|
|
|
|
return Ok(Self::Unencrypted(Jabber::new(
|
|
|
|
readhalf,
|
|
|
|
writehalf,
|
2024-11-29 02:11:02 +00:00
|
|
|
jid,
|
|
|
|
auth,
|
2023-10-21 01:28:54 +01:00
|
|
|
server.to_owned(),
|
|
|
|
)));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2024-11-23 22:39:44 +00:00
|
|
|
Err(Error::Connection)
|
2023-10-21 01:28:54 +01:00
|
|
|
}
|
|
|
|
|
2023-10-21 02:38:19 +01:00
|
|
|
#[instrument]
|
|
|
|
async fn get_sockets(address: &str) -> Vec<(SocketAddr, bool)> {
|
2023-10-21 01:28:54 +01:00
|
|
|
let mut socket_addrs = Vec::new();
|
|
|
|
|
|
|
|
// if it's a socket/ip then just return that
|
|
|
|
|
|
|
|
// socket
|
2023-10-21 02:38:19 +01:00
|
|
|
trace!("checking if address is a socket address");
|
|
|
|
if let Ok(socket_addr) = SocketAddr::from_str(address) {
|
|
|
|
debug!("{} is a socket address", address);
|
2023-10-21 01:28:54 +01:00
|
|
|
match socket_addr.port() {
|
|
|
|
5223 => socket_addrs.push((socket_addr, true)),
|
|
|
|
_ => socket_addrs.push((socket_addr, false)),
|
|
|
|
}
|
|
|
|
|
|
|
|
return socket_addrs;
|
|
|
|
}
|
|
|
|
// ip
|
2023-10-21 02:38:19 +01:00
|
|
|
trace!("checking if address is an ip");
|
|
|
|
if let Ok(ip) = IpAddr::from_str(address) {
|
|
|
|
debug!("{} is an ip", address);
|
2023-10-21 01:28:54 +01:00
|
|
|
socket_addrs.push((SocketAddr::new(ip, 5222), false));
|
|
|
|
socket_addrs.push((SocketAddr::new(ip, 5223), true));
|
|
|
|
return socket_addrs;
|
|
|
|
}
|
|
|
|
|
|
|
|
// otherwise resolve
|
2023-10-21 02:38:19 +01:00
|
|
|
debug!("resolving {}", address);
|
2023-10-21 01:28:54 +01:00
|
|
|
if let Ok(resolver) = trust_dns_resolver::AsyncResolver::tokio_from_system_conf() {
|
|
|
|
if let Ok(lookup) = resolver
|
2023-10-21 02:38:19 +01:00
|
|
|
.srv_lookup(format!("_xmpp-client._tcp.{}", address))
|
2023-10-21 01:28:54 +01:00
|
|
|
.await
|
|
|
|
{
|
|
|
|
for srv in lookup {
|
|
|
|
resolver
|
|
|
|
.lookup_ip(srv.target().to_owned())
|
|
|
|
.await
|
|
|
|
.map(|ips| {
|
|
|
|
for ip in ips {
|
|
|
|
socket_addrs.push((SocketAddr::new(ip, srv.port()), false))
|
|
|
|
}
|
|
|
|
});
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if let Ok(lookup) = resolver
|
2023-10-21 02:38:19 +01:00
|
|
|
.srv_lookup(format!("_xmpps-client._tcp.{}", address))
|
2023-10-21 01:28:54 +01:00
|
|
|
.await
|
|
|
|
{
|
|
|
|
for srv in lookup {
|
|
|
|
resolver
|
|
|
|
.lookup_ip(srv.target().to_owned())
|
|
|
|
.await
|
|
|
|
.map(|ips| {
|
|
|
|
for ip in ips {
|
|
|
|
socket_addrs.push((SocketAddr::new(ip, srv.port()), true))
|
|
|
|
}
|
|
|
|
});
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// in case cannot connect through SRV records
|
2023-10-21 02:38:19 +01:00
|
|
|
resolver.lookup_ip(address).await.map(|ips| {
|
2023-10-21 01:28:54 +01:00
|
|
|
for ip in ips {
|
|
|
|
socket_addrs.push((SocketAddr::new(ip, 5222), false));
|
|
|
|
socket_addrs.push((SocketAddr::new(ip, 5223), true));
|
|
|
|
}
|
|
|
|
});
|
|
|
|
}
|
|
|
|
socket_addrs
|
|
|
|
}
|
|
|
|
|
|
|
|
/// establishes a connection to the server
|
2023-10-21 02:38:19 +01:00
|
|
|
#[instrument]
|
2023-10-21 01:28:54 +01:00
|
|
|
pub async fn connect_tls(socket_addr: SocketAddr, domain_name: &str) -> Result<Tls> {
|
|
|
|
let socket = TcpStream::connect(socket_addr)
|
|
|
|
.await
|
2024-11-23 22:39:44 +00:00
|
|
|
.map_err(|_| Error::Connection)?;
|
|
|
|
let connector = TlsConnector::new().map_err(|_| Error::Connection)?;
|
2023-10-21 01:28:54 +01:00
|
|
|
tokio_native_tls::TlsConnector::from(connector)
|
|
|
|
.connect(domain_name, socket)
|
|
|
|
.await
|
2024-11-23 22:39:44 +00:00
|
|
|
.map_err(|_| Error::Connection)
|
2023-10-21 01:28:54 +01:00
|
|
|
}
|
|
|
|
|
2023-10-21 02:38:19 +01:00
|
|
|
#[instrument]
|
2023-10-21 01:28:54 +01:00
|
|
|
pub async fn connect_unencrypted(socket_addr: SocketAddr) -> Result<Unencrypted> {
|
|
|
|
TcpStream::connect(socket_addr)
|
|
|
|
.await
|
2024-11-23 22:39:44 +00:00
|
|
|
.map_err(|_| Error::Connection)
|
2023-10-21 01:28:54 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use test_log::test;

    // Both tests open real network connections to the server at `blos.sm`.

    /// Discovery plus the first successful socket connection must succeed.
    #[test(tokio::test)]
    async fn connect() {
        let outcome = Connection::connect("blos.sm", None, None).await;
        outcome.unwrap();
    }

    /// Whatever transport was obtained, it must end up as a TLS stream.
    #[test(tokio::test)]
    async fn test_tls() {
        let connection = Connection::connect("blos.sm", None, None).await.unwrap();
        let upgraded = connection.ensure_tls().await;
        upgraded.unwrap();
    }
}
|