tls/starttls
This commit is contained in:
parent
f00fa833e0
commit
e7cf44efe1
|
@ -1322,9 +1322,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustls-pemfile"
|
name = "rustls-pemfile"
|
||||||
version = "1.0.2"
|
version = "1.0.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b"
|
checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"base64",
|
"base64",
|
||||||
]
|
]
|
||||||
|
@ -1360,6 +1360,7 @@ dependencies = [
|
||||||
"quick-xml",
|
"quick-xml",
|
||||||
"rcgen",
|
"rcgen",
|
||||||
"rsdns",
|
"rsdns",
|
||||||
|
"rustls-pemfile",
|
||||||
"serde",
|
"serde",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-rustls",
|
"tokio-rustls",
|
||||||
|
|
|
@ -21,3 +21,4 @@ instant-acme = "0.3.2"
|
||||||
desec = { path = "../desec" }
|
desec = { path = "../desec" }
|
||||||
rcgen = "0.11.1"
|
rcgen = "0.11.1"
|
||||||
rsdns = { version = "0.15.0", features = ["net-tokio"] }
|
rsdns = { version = "0.15.0", features = ["net-tokio"] }
|
||||||
|
rustls-pemfile = "1.0.3"
|
||||||
|
|
|
@ -1,21 +1,19 @@
|
||||||
use std::{
|
use std::{
|
||||||
|
env,
|
||||||
fs::File,
|
fs::File,
|
||||||
io::{prelude::Write, Read},
|
io::{prelude::Write, Read},
|
||||||
net::{SocketAddr, ToSocketAddrs},
|
net::ToSocketAddrs,
|
||||||
str::FromStr,
|
os::unix::prelude::PermissionsExt,
|
||||||
time::Duration,
|
time::Duration,
|
||||||
vec,
|
vec,
|
||||||
};
|
};
|
||||||
|
|
||||||
use desec::{
|
use desec::dns::{RRSet, RRSetPatch, Record};
|
||||||
dns::{RRSet, RRSetPatch, Record},
|
|
||||||
Session,
|
|
||||||
};
|
|
||||||
use instant_acme::{
|
use instant_acme::{
|
||||||
Account, AuthorizationStatus, ChallengeType, Identifier, LetsEncrypt, NewAccount, NewOrder,
|
Account, AuthorizationStatus, ChallengeType, Identifier, LetsEncrypt, NewAccount, NewOrder,
|
||||||
Order, OrderStatus,
|
Order, OrderStatus,
|
||||||
};
|
};
|
||||||
use log::{debug, error, info, warn};
|
use log::{debug, error, info};
|
||||||
use rcgen::{Certificate, CertificateParams, DistinguishedName};
|
use rcgen::{Certificate, CertificateParams, DistinguishedName};
|
||||||
use rsdns::{
|
use rsdns::{
|
||||||
clients::{tokio::Client, ClientConfig},
|
clients::{tokio::Client, ClientConfig},
|
||||||
|
@ -23,18 +21,21 @@ use rsdns::{
|
||||||
records::data,
|
records::data,
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use tokio::io::AsyncWriteExt;
|
||||||
|
use tokio_rustls::rustls;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub enum CertStore {
|
pub enum CertStore {
|
||||||
Provision,
|
Provision,
|
||||||
Existing(CertificatePEM),
|
Existing(CertPaths),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
pub domain: String,
|
pub domain: String,
|
||||||
pub subdomain: Option<String>,
|
pub subdomain: Option<String>,
|
||||||
pub port: u16,
|
pub insecure_port: u16,
|
||||||
|
pub tls_port: u16,
|
||||||
pub cert_store: CertStore,
|
pub cert_store: CertStore,
|
||||||
pub desec_cfg: DesecConfig,
|
pub desec_cfg: DesecConfig,
|
||||||
#[serde(skip)]
|
#[serde(skip)]
|
||||||
|
@ -44,18 +45,56 @@ pub struct Config {
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct DesecConfig {
|
pub struct DesecConfig {
|
||||||
pub username: String,
|
pub username: String,
|
||||||
pub password: String,
|
#[serde(skip)]
|
||||||
|
pub password: Option<String>,
|
||||||
pub name_servers: Vec<String>,
|
pub name_servers: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct CertificatePEM {
|
pub struct CertPaths {
|
||||||
pub cert_chain_pem: String,
|
pub cert_chain_path: String,
|
||||||
pub private_key_pem: String,
|
pub private_key_path: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CertPaths {
|
||||||
|
pub fn into_config(&self) -> Result<rustls::ServerConfig, anyhow::Error> {
|
||||||
|
let mut buf = std::io::BufReader::new(File::open(&self.cert_chain_path)?);
|
||||||
|
let certs = rustls_pemfile::certs(&mut buf)?
|
||||||
|
.into_iter()
|
||||||
|
.map(tokio_rustls::rustls::Certificate)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut reader = std::io::BufReader::new(File::open(&self.private_key_path)?);
|
||||||
|
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut reader)?;
|
||||||
|
let key = match keys.len() {
|
||||||
|
0 => {
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"No PKCS8-encoded private key found in {}",
|
||||||
|
&self.private_key_path,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
1 => rustls::PrivateKey(keys.remove(0)),
|
||||||
|
_ => {
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"More than one PKCS8-encoded private key found in {}",
|
||||||
|
&self.private_key_path,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(rustls::ServerConfig::builder()
|
||||||
|
.with_safe_defaults()
|
||||||
|
.with_no_client_auth()
|
||||||
|
.with_single_cert(certs, key)?)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const CONFIG_PATHS: [&str; 3] = [DEFAULT_PATH, "/etc/salut.toml", "/usr/local/etc/salut.toml"];
|
const CONFIG_PATHS: [&str; 3] = [DEFAULT_PATH, "/etc/salut.toml", "/usr/local/etc/salut.toml"];
|
||||||
|
const CERT_PATHS: [&str; 3] = ["/etc/salut/", "/usr/local/etc/salut/", "./"];
|
||||||
|
const CERT_FILENAME: &str = "salut.cert";
|
||||||
|
const CERT_PRIVKEY_FILENAME: &str = "salut.pk";
|
||||||
const DNS_QUERY_WAIT: Duration = Duration::from_millis(250);
|
const DNS_QUERY_WAIT: Duration = Duration::from_millis(250);
|
||||||
|
const DNS_PASSWORD_VAR: &str = "DESEC_PASSWORD";
|
||||||
pub const DEFAULT_PATH: &str = "salut.toml";
|
pub const DEFAULT_PATH: &str = "salut.toml";
|
||||||
|
|
||||||
impl Default for Config {
|
impl Default for Config {
|
||||||
|
@ -63,11 +102,12 @@ impl Default for Config {
|
||||||
Self {
|
Self {
|
||||||
domain: String::new(),
|
domain: String::new(),
|
||||||
subdomain: Some(String::new()),
|
subdomain: Some(String::new()),
|
||||||
port: 5222,
|
insecure_port: 5222,
|
||||||
|
tls_port: 5223,
|
||||||
cert_store: CertStore::Provision,
|
cert_store: CertStore::Provision,
|
||||||
desec_cfg: DesecConfig {
|
desec_cfg: DesecConfig {
|
||||||
username: String::new(),
|
username: String::new(),
|
||||||
password: String::new(),
|
password: None,
|
||||||
name_servers: vec!["ns1.desec.io".into(), "ns2.desec.org".into()],
|
name_servers: vec!["ns1.desec.io".into(), "ns2.desec.org".into()],
|
||||||
},
|
},
|
||||||
original_path: DEFAULT_PATH,
|
original_path: DEFAULT_PATH,
|
||||||
|
@ -83,6 +123,9 @@ impl Config {
|
||||||
file.read_to_string(&mut cfg)?;
|
file.read_to_string(&mut cfg)?;
|
||||||
let mut cfg: Self = toml::from_str(&cfg)?;
|
let mut cfg: Self = toml::from_str(&cfg)?;
|
||||||
cfg.original_path = path;
|
cfg.original_path = path;
|
||||||
|
if let Ok(pass) = env::var(DNS_PASSWORD_VAR) {
|
||||||
|
cfg.desec_cfg.password = Some(pass);
|
||||||
|
}
|
||||||
return Ok(cfg);
|
return Ok(cfg);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -111,7 +154,7 @@ const ACME_PREFIX: &str = "_acme-challenge";
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
// Returns existing certificate or provisions a new one via DNS challenge using DeSEC
|
// Returns existing certificate or provisions a new one via DNS challenge using DeSEC
|
||||||
pub async fn certificate(&self) -> Result<CertificatePEM, anyhow::Error> {
|
pub async fn certificate(&self) -> Result<CertPaths, anyhow::Error> {
|
||||||
let desec_cfg = match self.cert_store.clone() {
|
let desec_cfg = match self.cert_store.clone() {
|
||||||
CertStore::Provision => self.desec_cfg.clone(),
|
CertStore::Provision => self.desec_cfg.clone(),
|
||||||
CertStore::Existing(existing) => return Ok(existing),
|
CertStore::Existing(existing) => return Ok(existing),
|
||||||
|
@ -141,7 +184,11 @@ impl Config {
|
||||||
assert!(matches!(state.status, OrderStatus::Pending));
|
assert!(matches!(state.status, OrderStatus::Pending));
|
||||||
|
|
||||||
debug!("logging into desec as <{}>", &desec_cfg.username);
|
debug!("logging into desec as <{}>", &desec_cfg.username);
|
||||||
let dns = desec::Session::login(&desec_cfg.username, &desec_cfg.password).await?;
|
let desec_pass = match desec_cfg.password {
|
||||||
|
Some(pass) => pass,
|
||||||
|
None => panic!("need desec password for provisioning the TLS certificate, please set {DNS_PASSWORD_VAR}"),
|
||||||
|
};
|
||||||
|
let dns = desec::Session::login(&desec_cfg.username, &desec_pass).await?;
|
||||||
debug!("querying existing TXT records");
|
debug!("querying existing TXT records");
|
||||||
let existing_records: Vec<RRSet> = dns
|
let existing_records: Vec<RRSet> = dns
|
||||||
.get_rrsets(&self.domain, Some(vec![Record::TXT]))
|
.get_rrsets(&self.domain, Some(vec![Record::TXT]))
|
||||||
|
@ -258,16 +305,12 @@ impl Config {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let cert = CertificatePEM {
|
|
||||||
cert_chain_pem,
|
|
||||||
private_key_pem: cert.serialize_private_key_pem(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut new_cfg = self.clone();
|
let mut new_cfg = self.clone();
|
||||||
new_cfg.cert_store = CertStore::Existing(cert.clone());
|
let paths = save_certs_pam(cert_chain_pem, cert.serialize_private_key_pem()).await?;
|
||||||
|
new_cfg.cert_store = CertStore::Existing(paths.clone());
|
||||||
new_cfg.save(self.original_path)?;
|
new_cfg.save(self.original_path)?;
|
||||||
|
|
||||||
Ok(cert)
|
Ok(paths)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn wait_challenges(
|
async fn wait_challenges(
|
||||||
|
@ -362,3 +405,36 @@ impl Config {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
async fn save_certs_pam(
|
||||||
|
cert_chain: String,
|
||||||
|
private_key: String,
|
||||||
|
) -> Result<CertPaths, anyhow::Error> {
|
||||||
|
for try_path in CERT_PATHS {
|
||||||
|
if let Err(_) = std::fs::create_dir_all(try_path) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let cert_chain_path = format!("{try_path}{CERT_FILENAME}");
|
||||||
|
let mut cert_file = match tokio::fs::File::create(&cert_chain_path).await {
|
||||||
|
Ok(f) => f,
|
||||||
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
let private_key_path = format!("{try_path}{CERT_PRIVKEY_FILENAME}");
|
||||||
|
let mut pk_file = tokio::fs::File::create(&private_key_path).await?;
|
||||||
|
pk_file.metadata().await?.permissions().set_mode(600);
|
||||||
|
cert_file.metadata().await?.permissions().set_mode(600);
|
||||||
|
|
||||||
|
cert_file.write_all(cert_chain.as_bytes()).await?;
|
||||||
|
pk_file.write_all(private_key.as_bytes()).await?;
|
||||||
|
cert_file.flush().await?;
|
||||||
|
pk_file.flush().await?;
|
||||||
|
|
||||||
|
return Ok(CertPaths {
|
||||||
|
cert_chain_path,
|
||||||
|
private_key_path,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(anyhow::anyhow!(
|
||||||
|
"could not create/save cert files in {CERT_PATHS:?}"
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
use enum_display::EnumDisplay;
|
use enum_display::EnumDisplay;
|
||||||
use std::string::FromUtf8Error;
|
use std::string::FromUtf8Error;
|
||||||
|
use tokio_rustls::rustls;
|
||||||
|
|
||||||
use log::error;
|
use log::error;
|
||||||
use quick_xml::events::attributes::AttrError;
|
use quick_xml::events::attributes::AttrError;
|
||||||
|
@ -126,6 +127,37 @@ pub enum StreamError {
|
||||||
/// the server.
|
/// the server.
|
||||||
UnsupportedVersion,
|
UnsupportedVersion,
|
||||||
}
|
}
|
||||||
|
impl From<rustls::Error> for StreamError {
|
||||||
|
fn from(value: rustls::Error) -> Self {
|
||||||
|
match value {
|
||||||
|
rustls::Error::InappropriateMessage {
|
||||||
|
expect_types,
|
||||||
|
got_type,
|
||||||
|
} => Self::BadFormat,
|
||||||
|
rustls::Error::InappropriateHandshakeMessage {
|
||||||
|
expect_types,
|
||||||
|
got_type,
|
||||||
|
} => Self::BadFormat,
|
||||||
|
rustls::Error::InvalidMessage(_) => Self::BadFormat,
|
||||||
|
rustls::Error::NoCertificatesPresented => Self::InternalServerError,
|
||||||
|
rustls::Error::UnsupportedNameType => Self::BadFormat,
|
||||||
|
rustls::Error::DecryptError => Self::InternalServerError,
|
||||||
|
rustls::Error::EncryptError => Self::InternalServerError,
|
||||||
|
rustls::Error::PeerIncompatible(_) => Self::BadFormat,
|
||||||
|
rustls::Error::PeerMisbehaved(_) => Self::BadFormat,
|
||||||
|
rustls::Error::AlertReceived(_) => Self::BadFormat,
|
||||||
|
rustls::Error::InvalidCertificate(_) => Self::InternalServerError,
|
||||||
|
_ => Self::BadFormat,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for StreamError {
|
||||||
|
fn from(value: std::io::Error) -> Self {
|
||||||
|
error!("io error: {value}");
|
||||||
|
Self::InternalServerError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl From<FromUtf8Error> for StreamError {
|
impl From<FromUtf8Error> for StreamError {
|
||||||
fn from(_: FromUtf8Error) -> Self {
|
fn from(_: FromUtf8Error) -> Self {
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use std::process;
|
use std::{process, sync::Arc};
|
||||||
|
|
||||||
use log::{error, info};
|
use log::{error, info};
|
||||||
|
|
||||||
|
@ -33,10 +33,27 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
info!("checking for certificates");
|
info!("checking for certificates");
|
||||||
let certs = cfg.certificate().await.expect("getting certificates");
|
let certs = Arc::new(
|
||||||
|
cfg.certificate()
|
||||||
|
.await
|
||||||
|
.expect("getting certificates")
|
||||||
|
.into_config()?,
|
||||||
|
);
|
||||||
|
|
||||||
let host = cfg.hostname();
|
let host = cfg.hostname();
|
||||||
info!("listening on {host}:{}!", cfg.port);
|
info!(
|
||||||
server::listen(host, cfg.port).await.unwrap();
|
"listening on {host} {} (plain) and {} (tls)!",
|
||||||
|
cfg.insecure_port, cfg.tls_port
|
||||||
|
);
|
||||||
|
let cert_clone = certs.clone();
|
||||||
|
let host_clone = host.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
server::listen_tls(host_clone, cfg.tls_port, cert_clone)
|
||||||
|
.await
|
||||||
|
.expect("TLS listener")
|
||||||
|
});
|
||||||
|
server::listen_starttls(host, cfg.insecure_port, certs)
|
||||||
|
.await
|
||||||
|
.expect("STARTTLS listener");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,14 +1,18 @@
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use quick_xml::{
|
use quick_xml::{
|
||||||
events::{BytesStart, Event},
|
events::{BytesStart, Event},
|
||||||
Writer,
|
Writer,
|
||||||
};
|
};
|
||||||
use tokio::io::{AsyncBufRead, AsyncWrite};
|
use tokio::{io::AsyncWrite, net::TcpStream};
|
||||||
use tokio_rustls::rustls;
|
use tokio_rustls::{rustls, server::TlsStream, TlsAcceptor};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
error::StreamError,
|
error::StreamError,
|
||||||
|
streamstart::{StartTLSResult, StreamStart},
|
||||||
tag::{self, Tag},
|
tag::{self, Tag},
|
||||||
|
tls::stream::{self, TLSStream},
|
||||||
};
|
};
|
||||||
|
|
||||||
pub enum Step {
|
pub enum Step {
|
||||||
|
@ -36,15 +40,20 @@ impl Tag for Step {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn start_tls<R, W>(
|
pub(crate) enum TcpConnOrTLS {
|
||||||
reader: R,
|
TCP(TcpStream, StreamError),
|
||||||
writer: W,
|
/// Failed is returned when there's a failure starting the TLS stream
|
||||||
|
/// thus the TCP stream is consumed.
|
||||||
|
Failed(StreamError),
|
||||||
|
TLS(TlsStream<TcpStream>),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn start_tls(
|
||||||
|
conn: TcpStream,
|
||||||
start_tls_event: BytesStart<'_>,
|
start_tls_event: BytesStart<'_>,
|
||||||
) -> Result<Step, StreamError>
|
tls_config: Arc<rustls::ServerConfig>,
|
||||||
where
|
) -> TcpConnOrTLS {
|
||||||
R: AsyncBufRead + Unpin,
|
let mut conn = conn;
|
||||||
W: AsyncWrite + Unpin + Send,
|
|
||||||
{
|
|
||||||
match start_tls_event.try_get_attribute("xmlns") {
|
match start_tls_event.try_get_attribute("xmlns") {
|
||||||
Ok(namespace) => {
|
Ok(namespace) => {
|
||||||
if &namespace
|
if &namespace
|
||||||
|
@ -52,28 +61,26 @@ where
|
||||||
.unwrap_or_default()
|
.unwrap_or_default()
|
||||||
!= tag::TLS_NAMESPACE
|
!= tag::TLS_NAMESPACE
|
||||||
{
|
{
|
||||||
return Ok(Step::Failure);
|
if let Err(err) = Step::Failure.write_tag(&mut conn).await {
|
||||||
|
return TcpConnOrTLS::TCP(conn, err);
|
||||||
|
}
|
||||||
|
return TcpConnOrTLS::TCP(conn, StreamError::InvalidNamespace);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(_) => return Ok(Step::Failure),
|
Err(_) => {
|
||||||
|
if let Err(err) = Step::Failure.write_tag(&mut conn).await {
|
||||||
|
return TcpConnOrTLS::TCP(conn, err);
|
||||||
|
}
|
||||||
|
return TcpConnOrTLS::TCP(conn, StreamError::BadFormat);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// let config = rustls::ServerConfig::builder()
|
if let Err(err) = Step::Proceed.write_tag(&mut conn).await {
|
||||||
// .with_safe_defaults()
|
return TcpConnOrTLS::TCP(conn, err);
|
||||||
// .with_no_client_auth()
|
};
|
||||||
// .with_single_cert(certs, keys.remove(0))
|
let acceptor = TlsAcceptor::from(tls_config);
|
||||||
// .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
|
match acceptor.accept(conn).await {
|
||||||
|
Ok(tls_conn) => TcpConnOrTLS::TLS(tls_conn),
|
||||||
Step::Proceed.write_tag(writer).await?;
|
Err(err) => TcpConnOrTLS::Failed(err.into()),
|
||||||
|
}
|
||||||
// match TlsConnector::builder(). {
|
|
||||||
// Ok(conn) => conn.,
|
|
||||||
// Err(err) => {
|
|
||||||
// error!("getting a tls connector: {err}");
|
|
||||||
// return Ok(Step::Failure);
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
std::thread::sleep(std::time::Duration::from_secs(3));
|
|
||||||
todo!()
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,15 +1,53 @@
|
||||||
use log::info;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use log::{error, info};
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
|
use tokio_rustls::{rustls, TlsAcceptor};
|
||||||
|
|
||||||
use crate::streamstart;
|
use crate::{streamstart, tls::stream};
|
||||||
|
|
||||||
pub async fn listen(hostname: String, port: u16) -> Result<(), anyhow::Error> {
|
pub async fn listen_starttls(
|
||||||
let listener = TcpListener::bind(("0.0.0.0", port)).await?;
|
hostname: String,
|
||||||
|
insecure_port: u16,
|
||||||
|
tls_config: Arc<rustls::ServerConfig>,
|
||||||
|
) -> Result<(), anyhow::Error> {
|
||||||
|
let listener = TcpListener::bind(("0.0.0.0", insecure_port)).await?;
|
||||||
loop {
|
loop {
|
||||||
match listener.accept().await {
|
match listener.accept().await {
|
||||||
Ok(conn) => {
|
Ok(conn) => {
|
||||||
info!("opening connection from {}", conn.1);
|
info!("opening starttls connection from {}", conn.1);
|
||||||
streamstart::spawn(hostname.clone(), conn);
|
streamstart::spawn(hostname.clone(), conn, tls_config.clone());
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("listening: {e}");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn listen_tls(
|
||||||
|
hostname: String,
|
||||||
|
tls_port: u16,
|
||||||
|
tls_config: Arc<rustls::ServerConfig>,
|
||||||
|
) -> Result<(), anyhow::Error> {
|
||||||
|
let listener = TcpListener::bind(("0.0.0.0", tls_port)).await?;
|
||||||
|
let acceptor = TlsAcceptor::from(tls_config);
|
||||||
|
loop {
|
||||||
|
match listener.accept().await {
|
||||||
|
Ok(conn) => {
|
||||||
|
info!("opening TLS connection from {}", conn.1);
|
||||||
|
match acceptor.accept(conn.0).await {
|
||||||
|
Ok(conn) => {
|
||||||
|
let cloned_hostname = hostname.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
stream::TLSStream::new(conn, cloned_hostname)
|
||||||
|
.start_stream()
|
||||||
|
.await
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Err(err) => error!("TLS accept error for {}: {err}", conn.1),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!("listening: {e}");
|
eprintln!("listening: {e}");
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use std::net::SocketAddr;
|
use std::{net::SocketAddr, sync::Arc};
|
||||||
|
|
||||||
use log::{error, info};
|
use log::{error, info};
|
||||||
use quick_xml::{
|
use quick_xml::{
|
||||||
|
@ -6,34 +6,42 @@ use quick_xml::{
|
||||||
Reader, Writer,
|
Reader, Writer,
|
||||||
};
|
};
|
||||||
use tokio::{
|
use tokio::{
|
||||||
io::{AsyncWrite, AsyncWriteExt, BufReader},
|
io::{AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf},
|
||||||
net::{
|
net::TcpStream,
|
||||||
tcp::{ReadHalf, WriteHalf},
|
|
||||||
TcpStream,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
use tokio_rustls::rustls;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
error::StreamError,
|
error::StreamError,
|
||||||
feature::Feature,
|
feature::Feature,
|
||||||
negotiator::{self, Step},
|
negotiator::{self, Step, TcpConnOrTLS},
|
||||||
tag::{self, Tag},
|
tag::{self, Tag},
|
||||||
|
tls::stream,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
pub(crate) enum StartTLSResult {
|
||||||
|
Success(stream::TLSStream),
|
||||||
|
Failure(StreamStart, StreamError),
|
||||||
|
/// TLSFailure is returned when the TLS negotiation failed, and the TCP
|
||||||
|
/// stream is to be dropped.
|
||||||
|
TLSFailure(StreamError),
|
||||||
|
}
|
||||||
|
|
||||||
type Result<T> = std::result::Result<T, StreamError>;
|
type Result<T> = std::result::Result<T, StreamError>;
|
||||||
|
|
||||||
const FEATURES: &'static [Feature] = &[Feature::start_tls(true)];
|
const FEATURES: &'static [Feature] = &[Feature::start_tls(true)];
|
||||||
|
|
||||||
struct StreamStart<'a> {
|
pub struct StreamStart {
|
||||||
reader: Reader<BufReader<ReadHalf<'a>>>,
|
reader: Reader<BufReader<ReadHalf<TcpStream>>>,
|
||||||
writer: Writer<WriteHalf<'a>>,
|
writer: Writer<WriteHalf<TcpStream>>,
|
||||||
buffer: Vec<u8>,
|
buffer: Vec<u8>,
|
||||||
hostname: String,
|
hostname: String,
|
||||||
|
tls_config: Arc<rustls::ServerConfig>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> StreamStart<'a> {
|
impl StreamStart {
|
||||||
fn new(stream: &'a mut TcpStream, hostname: String) -> Self {
|
fn new(tcp_stream: TcpStream, hostname: String, tls_config: Arc<rustls::ServerConfig>) -> Self {
|
||||||
let (read, write) = stream.split();
|
let (read, write) = tokio::io::split(tcp_stream);
|
||||||
let (reader, writer) = (
|
let (reader, writer) = (
|
||||||
Reader::from_reader(BufReader::new(read)),
|
Reader::from_reader(BufReader::new(read)),
|
||||||
Writer::new(write),
|
Writer::new(write),
|
||||||
|
@ -43,85 +51,118 @@ impl<'a> StreamStart<'a> {
|
||||||
reader,
|
reader,
|
||||||
writer,
|
writer,
|
||||||
hostname,
|
hostname,
|
||||||
|
tls_config,
|
||||||
buffer: vec![],
|
buffer: vec![],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn start_stream(mut self) {
|
async fn start_stream(self) {
|
||||||
match self.negotiate_stream().await {
|
match self.negotiate_stream().await {
|
||||||
Ok(_) => {}
|
StartTLSResult::Success(tls_stream) => tls_stream.start_stream().await,
|
||||||
Err(err) => {
|
StartTLSResult::Failure(mut conn, err) => {
|
||||||
if let Err(err2) = error(self.writer.get_mut(), err).await {
|
if let Err(err2) = error(conn.writer.get_mut(), err).await {
|
||||||
error!("error writing error: {err2}");
|
error!("error writing error: {err2}");
|
||||||
return;
|
return;
|
||||||
} else {
|
} else {
|
||||||
info!("wrote error {err}")
|
info!("wrote error {err}")
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Err(e) = self.writer.get_mut().write_all(b"</stream:stream>").await {
|
if let Err(e) = conn.writer.get_mut().write_all(b"</stream:stream>").await {
|
||||||
error!("writing end to stream: {e}")
|
error!("writing end to stream: {e}")
|
||||||
}
|
}
|
||||||
if let Err(e) = self.writer.get_mut().shutdown().await {
|
if let Err(e) = conn.writer.get_mut().shutdown().await {
|
||||||
error!("shutting down stream: {e}")
|
error!("shutting down stream: {e}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
StartTLSResult::TLSFailure(err) => {
|
||||||
|
error!("TLS negotiation failure: {err}. Dropping connection.")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn negotiate_stream(&mut self) -> Result<()> {
|
async fn negotiate_stream(mut self) -> StartTLSResult {
|
||||||
let attrs = loop {
|
let attrs = loop {
|
||||||
match self.reader.read_event_into_async(&mut self.buffer).await? {
|
let event = match self.reader.read_event_into_async(&mut self.buffer).await {
|
||||||
|
Ok(event) => event,
|
||||||
|
Err(err) => return StartTLSResult::Failure(self, err.into()),
|
||||||
|
};
|
||||||
|
match event {
|
||||||
Event::Start(start) => {
|
Event::Start(start) => {
|
||||||
if start.name().as_ref() == tag::STREAM_ELEMENT_NAME {
|
if start.name().as_ref() == tag::STREAM_ELEMENT_NAME {
|
||||||
let attrs: StreamAttrs = start.attributes().try_into()?;
|
let attrs: StreamAttrs = match start.attributes().try_into() {
|
||||||
|
Ok(a) => a,
|
||||||
|
Err(err) => return StartTLSResult::Failure(self, err.into()),
|
||||||
|
};
|
||||||
if attrs.namespace != XMLNamespace::JabberClient {
|
if attrs.namespace != XMLNamespace::JabberClient {
|
||||||
return Err(StreamError::InvalidNamespace);
|
return StartTLSResult::Failure(self, StreamError::InvalidNamespace);
|
||||||
}
|
}
|
||||||
break attrs;
|
break attrs;
|
||||||
} else {
|
} else {
|
||||||
info!("element: {:?}", start);
|
info!("element: {:?}", start);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Event::End(_) => return Err(StreamError::BadFormat),
|
Event::End(_) => return StartTLSResult::Failure(self, StreamError::BadFormat),
|
||||||
Event::Eof => return Err(StreamError::BadFormat),
|
Event::Eof => return StartTLSResult::Failure(self, StreamError::BadFormat),
|
||||||
_ => continue,
|
_ => continue,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
info!("starting negotiation with: {attrs:?}");
|
info!("starting negotiation with: {attrs:?}");
|
||||||
self.write_stream_header(StreamAttrs {
|
if let Err(err) = self
|
||||||
from: attrs.to.clone(),
|
.write_stream_header(StreamAttrs {
|
||||||
to: attrs.from,
|
from: attrs.to.clone(),
|
||||||
namespace: XMLNamespace::JabberClient,
|
to: attrs.from,
|
||||||
})
|
namespace: XMLNamespace::JabberClient,
|
||||||
.await?;
|
})
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
return StartTLSResult::Failure(self, err);
|
||||||
|
};
|
||||||
if attrs.to != self.hostname {
|
if attrs.to != self.hostname {
|
||||||
return Err(StreamError::HostUnknown);
|
return StartTLSResult::Failure(self, StreamError::HostUnknown);
|
||||||
}
|
}
|
||||||
self.send_features().await?;
|
if let Err(err) = self.send_features().await {
|
||||||
|
return StartTLSResult::Failure(self, err);
|
||||||
|
};
|
||||||
loop {
|
loop {
|
||||||
match self.reader.read_event_into_async(&mut self.buffer).await? {
|
let event = match self.reader.read_event_into_async(&mut self.buffer).await {
|
||||||
|
Ok(event) => event,
|
||||||
|
Err(err) => return StartTLSResult::Failure(self, err.into()),
|
||||||
|
};
|
||||||
|
match event {
|
||||||
Event::Empty(empty) => match empty.name().as_ref() {
|
Event::Empty(empty) => match empty.name().as_ref() {
|
||||||
tag::STARTTLS => {
|
tag::STARTTLS => {
|
||||||
info!("starttls negotiation");
|
info!("starttls negotiation");
|
||||||
if let Step::Failure = negotiator::start_tls(
|
let hostname = self.hostname;
|
||||||
self.reader.get_mut(),
|
let tls_config = self.tls_config.clone();
|
||||||
self.writer.get_mut(),
|
return match negotiator::start_tls(
|
||||||
|
self.reader
|
||||||
|
.into_inner()
|
||||||
|
.into_inner()
|
||||||
|
.unsplit(self.writer.into_inner()),
|
||||||
empty,
|
empty,
|
||||||
|
self.tls_config.clone(),
|
||||||
)
|
)
|
||||||
.await?
|
.await
|
||||||
{
|
{
|
||||||
return Step::Failure.write_tag(self.writer.get_mut()).await;
|
TcpConnOrTLS::TLS(conn) => {
|
||||||
|
StartTLSResult::Success(stream::TLSStream::new(conn, hostname))
|
||||||
|
}
|
||||||
|
TcpConnOrTLS::TCP(conn, err) => StartTLSResult::Failure(
|
||||||
|
StreamStart::new(conn, hostname, tls_config),
|
||||||
|
err,
|
||||||
|
),
|
||||||
|
TcpConnOrTLS::Failed(err) => StartTLSResult::TLSFailure(err),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
_ => return Err(StreamError::UnsupportedFeature),
|
_ => return StartTLSResult::Failure(self, StreamError::UnsupportedFeature),
|
||||||
},
|
},
|
||||||
Event::End(_) => return Err(StreamError::BadFormat),
|
Event::End(_) => return StartTLSResult::Failure(self, StreamError::BadFormat),
|
||||||
Event::Eof => return Err(StreamError::BadFormat),
|
Event::Eof => return StartTLSResult::Failure(self, StreamError::BadFormat),
|
||||||
_ => continue,
|
_ => continue,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Err(StreamError::InternalServerError)
|
StartTLSResult::Failure(self, StreamError::InternalServerError)
|
||||||
}
|
}
|
||||||
async fn write_stream_header(&mut self, req: StreamAttrs) -> Result<()> {
|
async fn write_stream_header(&mut self, req: StreamAttrs) -> Result<()> {
|
||||||
self.writer
|
self.writer
|
||||||
|
@ -156,9 +197,15 @@ impl<'a> StreamStart<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn spawn(hostname: String, (mut stream, _): (TcpStream, SocketAddr)) {
|
pub fn spawn(
|
||||||
|
hostname: String,
|
||||||
|
(mut stream, _): (TcpStream, SocketAddr),
|
||||||
|
tls_config: Arc<rustls::ServerConfig>,
|
||||||
|
) {
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
StreamStart::new(&mut stream, hostname).start_stream().await;
|
StreamStart::new(stream, hostname, tls_config)
|
||||||
|
.start_stream()
|
||||||
|
.await;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
|
pub mod stream;
|
||||||
|
|
|
@ -0,0 +1,32 @@
|
||||||
|
use quick_xml::{Reader, Writer};
|
||||||
|
use tokio::{
|
||||||
|
io::{BufReader, ReadHalf, WriteHalf},
|
||||||
|
net::TcpStream,
|
||||||
|
};
|
||||||
|
use tokio_rustls::server::TlsStream;
|
||||||
|
|
||||||
|
pub struct TLSStream {
|
||||||
|
reader: Reader<BufReader<ReadHalf<TlsStream<TcpStream>>>>,
|
||||||
|
writer: Writer<WriteHalf<TlsStream<TcpStream>>>,
|
||||||
|
buffer: Vec<u8>,
|
||||||
|
hostname: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TLSStream {
|
||||||
|
pub fn new(conn: TlsStream<TcpStream>, hostname: String) -> Self {
|
||||||
|
let (read, write) = tokio::io::split(conn);
|
||||||
|
let (reader, writer) = (
|
||||||
|
Reader::from_reader(BufReader::new(read)),
|
||||||
|
Writer::new(write),
|
||||||
|
);
|
||||||
|
|
||||||
|
Self {
|
||||||
|
reader,
|
||||||
|
writer,
|
||||||
|
hostname,
|
||||||
|
buffer: vec![],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn start_stream(self) {}
|
||||||
|
}
|
Loading…
Reference in New Issue