tls/starttls

This commit is contained in:
emilis 2023-07-01 09:15:51 +01:00
parent f00fa833e0
commit e7cf44efe1
10 changed files with 363 additions and 112 deletions

5
Cargo.lock generated
View File

@ -1322,9 +1322,9 @@ dependencies = [
[[package]]
name = "rustls-pemfile"
version = "1.0.2"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b"
checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2"
dependencies = [
"base64",
]
@ -1360,6 +1360,7 @@ dependencies = [
"quick-xml",
"rcgen",
"rsdns",
"rustls-pemfile",
"serde",
"tokio",
"tokio-rustls",

View File

@ -21,3 +21,4 @@ instant-acme = "0.3.2"
desec = { path = "../desec" }
rcgen = "0.11.1"
rsdns = { version = "0.15.0", features = ["net-tokio"] }
rustls-pemfile = "1.0.3"

View File

@ -1,21 +1,19 @@
use std::{
env,
fs::File,
io::{prelude::Write, Read},
net::{SocketAddr, ToSocketAddrs},
str::FromStr,
net::ToSocketAddrs,
os::unix::prelude::PermissionsExt,
time::Duration,
vec,
};
use desec::{
dns::{RRSet, RRSetPatch, Record},
Session,
};
use desec::dns::{RRSet, RRSetPatch, Record};
use instant_acme::{
Account, AuthorizationStatus, ChallengeType, Identifier, LetsEncrypt, NewAccount, NewOrder,
Order, OrderStatus,
};
use log::{debug, error, info, warn};
use log::{debug, error, info};
use rcgen::{Certificate, CertificateParams, DistinguishedName};
use rsdns::{
clients::{tokio::Client, ClientConfig},
@ -23,18 +21,21 @@ use rsdns::{
records::data,
};
use serde::{Deserialize, Serialize};
use tokio::io::AsyncWriteExt;
use tokio_rustls::rustls;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CertStore {
Provision,
Existing(CertificatePEM),
Existing(CertPaths),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
pub domain: String,
pub subdomain: Option<String>,
pub port: u16,
pub insecure_port: u16,
pub tls_port: u16,
pub cert_store: CertStore,
pub desec_cfg: DesecConfig,
#[serde(skip)]
@ -44,18 +45,56 @@ pub struct Config {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DesecConfig {
pub username: String,
pub password: String,
#[serde(skip)]
pub password: Option<String>,
pub name_servers: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CertificatePEM {
pub cert_chain_pem: String,
pub private_key_pem: String,
pub struct CertPaths {
pub cert_chain_path: String,
pub private_key_path: String,
}
impl CertPaths {
pub fn into_config(&self) -> Result<rustls::ServerConfig, anyhow::Error> {
let mut buf = std::io::BufReader::new(File::open(&self.cert_chain_path)?);
let certs = rustls_pemfile::certs(&mut buf)?
.into_iter()
.map(tokio_rustls::rustls::Certificate)
.collect();
let mut reader = std::io::BufReader::new(File::open(&self.private_key_path)?);
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut reader)?;
let key = match keys.len() {
0 => {
return Err(anyhow::anyhow!(
"No PKCS8-encoded private key found in {}",
&self.private_key_path,
))
}
1 => rustls::PrivateKey(keys.remove(0)),
_ => {
return Err(anyhow::anyhow!(
"More than one PKCS8-encoded private key found in {}",
&self.private_key_path,
))
}
};
Ok(rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs, key)?)
}
}
const CONFIG_PATHS: [&str; 3] = [DEFAULT_PATH, "/etc/salut.toml", "/usr/local/etc/salut.toml"];
const CERT_PATHS: [&str; 3] = ["/etc/salut/", "/usr/local/etc/salut/", "./"];
const CERT_FILENAME: &str = "salut.cert";
const CERT_PRIVKEY_FILENAME: &str = "salut.pk";
const DNS_QUERY_WAIT: Duration = Duration::from_millis(250);
const DNS_PASSWORD_VAR: &str = "DESEC_PASSWORD";
pub const DEFAULT_PATH: &str = "salut.toml";
impl Default for Config {
@ -63,11 +102,12 @@ impl Default for Config {
Self {
domain: String::new(),
subdomain: Some(String::new()),
port: 5222,
insecure_port: 5222,
tls_port: 5223,
cert_store: CertStore::Provision,
desec_cfg: DesecConfig {
username: String::new(),
password: String::new(),
password: None,
name_servers: vec!["ns1.desec.io".into(), "ns2.desec.org".into()],
},
original_path: DEFAULT_PATH,
@ -83,6 +123,9 @@ impl Config {
file.read_to_string(&mut cfg)?;
let mut cfg: Self = toml::from_str(&cfg)?;
cfg.original_path = path;
if let Ok(pass) = env::var(DNS_PASSWORD_VAR) {
cfg.desec_cfg.password = Some(pass);
}
return Ok(cfg);
};
}
@ -111,7 +154,7 @@ const ACME_PREFIX: &str = "_acme-challenge";
impl Config {
// Returns existing certificate or provisions a new one via DNS challenge using DeSEC
pub async fn certificate(&self) -> Result<CertificatePEM, anyhow::Error> {
pub async fn certificate(&self) -> Result<CertPaths, anyhow::Error> {
let desec_cfg = match self.cert_store.clone() {
CertStore::Provision => self.desec_cfg.clone(),
CertStore::Existing(existing) => return Ok(existing),
@ -141,7 +184,11 @@ impl Config {
assert!(matches!(state.status, OrderStatus::Pending));
debug!("logging into desec as <{}>", &desec_cfg.username);
let dns = desec::Session::login(&desec_cfg.username, &desec_cfg.password).await?;
let desec_pass = match desec_cfg.password {
Some(pass) => pass,
None => panic!("need desec password for provisioning the TLS certificate, please set {DNS_PASSWORD_VAR}"),
};
let dns = desec::Session::login(&desec_cfg.username, &desec_pass).await?;
debug!("querying existing TXT records");
let existing_records: Vec<RRSet> = dns
.get_rrsets(&self.domain, Some(vec![Record::TXT]))
@ -258,16 +305,12 @@ impl Config {
}
};
let cert = CertificatePEM {
cert_chain_pem,
private_key_pem: cert.serialize_private_key_pem(),
};
let mut new_cfg = self.clone();
new_cfg.cert_store = CertStore::Existing(cert.clone());
let paths = save_certs_pam(cert_chain_pem, cert.serialize_private_key_pem()).await?;
new_cfg.cert_store = CertStore::Existing(paths.clone());
new_cfg.save(self.original_path)?;
Ok(cert)
Ok(paths)
}
async fn wait_challenges(
@ -362,3 +405,36 @@ impl Config {
}
}
}
async fn save_certs_pam(
cert_chain: String,
private_key: String,
) -> Result<CertPaths, anyhow::Error> {
for try_path in CERT_PATHS {
if let Err(_) = std::fs::create_dir_all(try_path) {
continue;
}
let cert_chain_path = format!("{try_path}{CERT_FILENAME}");
let mut cert_file = match tokio::fs::File::create(&cert_chain_path).await {
Ok(f) => f,
Err(_) => continue,
};
let private_key_path = format!("{try_path}{CERT_PRIVKEY_FILENAME}");
let mut pk_file = tokio::fs::File::create(&private_key_path).await?;
pk_file.metadata().await?.permissions().set_mode(600);
cert_file.metadata().await?.permissions().set_mode(600);
cert_file.write_all(cert_chain.as_bytes()).await?;
pk_file.write_all(private_key.as_bytes()).await?;
cert_file.flush().await?;
pk_file.flush().await?;
return Ok(CertPaths {
cert_chain_path,
private_key_path,
});
}
Err(anyhow::anyhow!(
"could not create/save cert files in {CERT_PATHS:?}"
))
}

View File

@ -1,5 +1,6 @@
use enum_display::EnumDisplay;
use std::string::FromUtf8Error;
use tokio_rustls::rustls;
use log::error;
use quick_xml::events::attributes::AttrError;
@ -126,6 +127,37 @@ pub enum StreamError {
/// the server.
UnsupportedVersion,
}
impl From<rustls::Error> for StreamError {
fn from(value: rustls::Error) -> Self {
match value {
rustls::Error::InappropriateMessage {
expect_types,
got_type,
} => Self::BadFormat,
rustls::Error::InappropriateHandshakeMessage {
expect_types,
got_type,
} => Self::BadFormat,
rustls::Error::InvalidMessage(_) => Self::BadFormat,
rustls::Error::NoCertificatesPresented => Self::InternalServerError,
rustls::Error::UnsupportedNameType => Self::BadFormat,
rustls::Error::DecryptError => Self::InternalServerError,
rustls::Error::EncryptError => Self::InternalServerError,
rustls::Error::PeerIncompatible(_) => Self::BadFormat,
rustls::Error::PeerMisbehaved(_) => Self::BadFormat,
rustls::Error::AlertReceived(_) => Self::BadFormat,
rustls::Error::InvalidCertificate(_) => Self::InternalServerError,
_ => Self::BadFormat,
}
}
}
impl From<std::io::Error> for StreamError {
fn from(value: std::io::Error) -> Self {
error!("io error: {value}");
Self::InternalServerError
}
}
impl From<FromUtf8Error> for StreamError {
fn from(_: FromUtf8Error) -> Self {

View File

@ -1,4 +1,4 @@
use std::process;
use std::{process, sync::Arc};
use log::{error, info};
@ -33,10 +33,27 @@ async fn main() -> Result<(), anyhow::Error> {
}
};
info!("checking for certificates");
let certs = cfg.certificate().await.expect("getting certificates");
let certs = Arc::new(
cfg.certificate()
.await
.expect("getting certificates")
.into_config()?,
);
let host = cfg.hostname();
info!("listening on {host}:{}!", cfg.port);
server::listen(host, cfg.port).await.unwrap();
info!(
"listening on {host} {} (plain) and {} (tls)!",
cfg.insecure_port, cfg.tls_port
);
let cert_clone = certs.clone();
let host_clone = host.clone();
tokio::spawn(async move {
server::listen_tls(host_clone, cfg.tls_port, cert_clone)
.await
.expect("TLS listener")
});
server::listen_starttls(host, cfg.insecure_port, certs)
.await
.expect("STARTTLS listener");
Ok(())
}

View File

@ -1,14 +1,18 @@
use std::sync::Arc;
use async_trait::async_trait;
use quick_xml::{
events::{BytesStart, Event},
Writer,
};
use tokio::io::{AsyncBufRead, AsyncWrite};
use tokio_rustls::rustls;
use tokio::{io::AsyncWrite, net::TcpStream};
use tokio_rustls::{rustls, server::TlsStream, TlsAcceptor};
use crate::{
error::StreamError,
streamstart::{StartTLSResult, StreamStart},
tag::{self, Tag},
tls::stream::{self, TLSStream},
};
pub enum Step {
@ -36,15 +40,20 @@ impl Tag for Step {
}
}
pub async fn start_tls<R, W>(
reader: R,
writer: W,
pub(crate) enum TcpConnOrTLS {
TCP(TcpStream, StreamError),
/// Failed is returned when there's a failure starting the TLS stream
/// thus the TCP stream is consumed.
Failed(StreamError),
TLS(TlsStream<TcpStream>),
}
pub(crate) async fn start_tls(
conn: TcpStream,
start_tls_event: BytesStart<'_>,
) -> Result<Step, StreamError>
where
R: AsyncBufRead + Unpin,
W: AsyncWrite + Unpin + Send,
{
tls_config: Arc<rustls::ServerConfig>,
) -> TcpConnOrTLS {
let mut conn = conn;
match start_tls_event.try_get_attribute("xmlns") {
Ok(namespace) => {
if &namespace
@ -52,28 +61,26 @@ where
.unwrap_or_default()
!= tag::TLS_NAMESPACE
{
return Ok(Step::Failure);
if let Err(err) = Step::Failure.write_tag(&mut conn).await {
return TcpConnOrTLS::TCP(conn, err);
}
return TcpConnOrTLS::TCP(conn, StreamError::InvalidNamespace);
}
}
Err(_) => return Ok(Step::Failure),
Err(_) => {
if let Err(err) = Step::Failure.write_tag(&mut conn).await {
return TcpConnOrTLS::TCP(conn, err);
}
return TcpConnOrTLS::TCP(conn, StreamError::BadFormat);
}
}
// let config = rustls::ServerConfig::builder()
// .with_safe_defaults()
// .with_no_client_auth()
// .with_single_cert(certs, keys.remove(0))
// .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
Step::Proceed.write_tag(writer).await?;
// match TlsConnector::builder(). {
// Ok(conn) => conn.,
// Err(err) => {
// error!("getting a tls connector: {err}");
// return Ok(Step::Failure);
// }
// }
std::thread::sleep(std::time::Duration::from_secs(3));
todo!()
if let Err(err) = Step::Proceed.write_tag(&mut conn).await {
return TcpConnOrTLS::TCP(conn, err);
};
let acceptor = TlsAcceptor::from(tls_config);
match acceptor.accept(conn).await {
Ok(tls_conn) => TcpConnOrTLS::TLS(tls_conn),
Err(err) => TcpConnOrTLS::Failed(err.into()),
}
}

View File

@ -1,15 +1,53 @@
use log::info;
use std::sync::Arc;
use log::{error, info};
use tokio::net::TcpListener;
use tokio_rustls::{rustls, TlsAcceptor};
use crate::streamstart;
use crate::{streamstart, tls::stream};
pub async fn listen(hostname: String, port: u16) -> Result<(), anyhow::Error> {
let listener = TcpListener::bind(("0.0.0.0", port)).await?;
pub async fn listen_starttls(
hostname: String,
insecure_port: u16,
tls_config: Arc<rustls::ServerConfig>,
) -> Result<(), anyhow::Error> {
let listener = TcpListener::bind(("0.0.0.0", insecure_port)).await?;
loop {
match listener.accept().await {
Ok(conn) => {
info!("opening connection from {}", conn.1);
streamstart::spawn(hostname.clone(), conn);
info!("opening starttls connection from {}", conn.1);
streamstart::spawn(hostname.clone(), conn, tls_config.clone());
}
Err(e) => {
eprintln!("listening: {e}");
continue;
}
}
}
}
pub async fn listen_tls(
hostname: String,
tls_port: u16,
tls_config: Arc<rustls::ServerConfig>,
) -> Result<(), anyhow::Error> {
let listener = TcpListener::bind(("0.0.0.0", tls_port)).await?;
let acceptor = TlsAcceptor::from(tls_config);
loop {
match listener.accept().await {
Ok(conn) => {
info!("opening TLS connection from {}", conn.1);
match acceptor.accept(conn.0).await {
Ok(conn) => {
let cloned_hostname = hostname.clone();
tokio::spawn(async move {
stream::TLSStream::new(conn, cloned_hostname)
.start_stream()
.await
});
}
Err(err) => error!("TLS accept error for {}: {err}", conn.1),
}
}
Err(e) => {
eprintln!("listening: {e}");

View File

@ -1,4 +1,4 @@
use std::net::SocketAddr;
use std::{net::SocketAddr, sync::Arc};
use log::{error, info};
use quick_xml::{
@ -6,34 +6,42 @@ use quick_xml::{
Reader, Writer,
};
use tokio::{
io::{AsyncWrite, AsyncWriteExt, BufReader},
net::{
tcp::{ReadHalf, WriteHalf},
TcpStream,
},
io::{AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf},
net::TcpStream,
};
use tokio_rustls::rustls;
use crate::{
error::StreamError,
feature::Feature,
negotiator::{self, Step},
negotiator::{self, Step, TcpConnOrTLS},
tag::{self, Tag},
tls::stream,
};
pub(crate) enum StartTLSResult {
Success(stream::TLSStream),
Failure(StreamStart, StreamError),
/// TLSFailure is returned when the TLS negotiation failed, and the TCP
/// stream is to be dropped.
TLSFailure(StreamError),
}
type Result<T> = std::result::Result<T, StreamError>;
const FEATURES: &'static [Feature] = &[Feature::start_tls(true)];
struct StreamStart<'a> {
reader: Reader<BufReader<ReadHalf<'a>>>,
writer: Writer<WriteHalf<'a>>,
pub struct StreamStart {
reader: Reader<BufReader<ReadHalf<TcpStream>>>,
writer: Writer<WriteHalf<TcpStream>>,
buffer: Vec<u8>,
hostname: String,
tls_config: Arc<rustls::ServerConfig>,
}
impl<'a> StreamStart<'a> {
fn new(stream: &'a mut TcpStream, hostname: String) -> Self {
let (read, write) = stream.split();
impl StreamStart {
fn new(tcp_stream: TcpStream, hostname: String, tls_config: Arc<rustls::ServerConfig>) -> Self {
let (read, write) = tokio::io::split(tcp_stream);
let (reader, writer) = (
Reader::from_reader(BufReader::new(read)),
Writer::new(write),
@ -43,85 +51,118 @@ impl<'a> StreamStart<'a> {
reader,
writer,
hostname,
tls_config,
buffer: vec![],
}
}
async fn start_stream(mut self) {
async fn start_stream(self) {
match self.negotiate_stream().await {
Ok(_) => {}
Err(err) => {
if let Err(err2) = error(self.writer.get_mut(), err).await {
StartTLSResult::Success(tls_stream) => tls_stream.start_stream().await,
StartTLSResult::Failure(mut conn, err) => {
if let Err(err2) = error(conn.writer.get_mut(), err).await {
error!("error writing error: {err2}");
return;
} else {
info!("wrote error {err}")
}
if let Err(e) = self.writer.get_mut().write_all(b"</stream:stream>").await {
if let Err(e) = conn.writer.get_mut().write_all(b"</stream:stream>").await {
error!("writing end to stream: {e}")
}
if let Err(e) = self.writer.get_mut().shutdown().await {
if let Err(e) = conn.writer.get_mut().shutdown().await {
error!("shutting down stream: {e}")
}
}
StartTLSResult::TLSFailure(err) => {
error!("TLS negotiation failure: {err}. Dropping connection.")
}
}
}
async fn negotiate_stream(&mut self) -> Result<()> {
async fn negotiate_stream(mut self) -> StartTLSResult {
let attrs = loop {
match self.reader.read_event_into_async(&mut self.buffer).await? {
let event = match self.reader.read_event_into_async(&mut self.buffer).await {
Ok(event) => event,
Err(err) => return StartTLSResult::Failure(self, err.into()),
};
match event {
Event::Start(start) => {
if start.name().as_ref() == tag::STREAM_ELEMENT_NAME {
let attrs: StreamAttrs = start.attributes().try_into()?;
let attrs: StreamAttrs = match start.attributes().try_into() {
Ok(a) => a,
Err(err) => return StartTLSResult::Failure(self, err.into()),
};
if attrs.namespace != XMLNamespace::JabberClient {
return Err(StreamError::InvalidNamespace);
return StartTLSResult::Failure(self, StreamError::InvalidNamespace);
}
break attrs;
} else {
info!("element: {:?}", start);
}
}
Event::End(_) => return Err(StreamError::BadFormat),
Event::Eof => return Err(StreamError::BadFormat),
Event::End(_) => return StartTLSResult::Failure(self, StreamError::BadFormat),
Event::Eof => return StartTLSResult::Failure(self, StreamError::BadFormat),
_ => continue,
}
};
info!("starting negotiation with: {attrs:?}");
self.write_stream_header(StreamAttrs {
from: attrs.to.clone(),
to: attrs.from,
namespace: XMLNamespace::JabberClient,
})
.await?;
if let Err(err) = self
.write_stream_header(StreamAttrs {
from: attrs.to.clone(),
to: attrs.from,
namespace: XMLNamespace::JabberClient,
})
.await
{
return StartTLSResult::Failure(self, err);
};
if attrs.to != self.hostname {
return Err(StreamError::HostUnknown);
return StartTLSResult::Failure(self, StreamError::HostUnknown);
}
self.send_features().await?;
if let Err(err) = self.send_features().await {
return StartTLSResult::Failure(self, err);
};
loop {
match self.reader.read_event_into_async(&mut self.buffer).await? {
let event = match self.reader.read_event_into_async(&mut self.buffer).await {
Ok(event) => event,
Err(err) => return StartTLSResult::Failure(self, err.into()),
};
match event {
Event::Empty(empty) => match empty.name().as_ref() {
tag::STARTTLS => {
info!("starttls negotiation");
if let Step::Failure = negotiator::start_tls(
self.reader.get_mut(),
self.writer.get_mut(),
let hostname = self.hostname;
let tls_config = self.tls_config.clone();
return match negotiator::start_tls(
self.reader
.into_inner()
.into_inner()
.unsplit(self.writer.into_inner()),
empty,
self.tls_config.clone(),
)
.await?
.await
{
return Step::Failure.write_tag(self.writer.get_mut()).await;
TcpConnOrTLS::TLS(conn) => {
StartTLSResult::Success(stream::TLSStream::new(conn, hostname))
}
TcpConnOrTLS::TCP(conn, err) => StartTLSResult::Failure(
StreamStart::new(conn, hostname, tls_config),
err,
),
TcpConnOrTLS::Failed(err) => StartTLSResult::TLSFailure(err),
};
}
_ => return Err(StreamError::UnsupportedFeature),
_ => return StartTLSResult::Failure(self, StreamError::UnsupportedFeature),
},
Event::End(_) => return Err(StreamError::BadFormat),
Event::Eof => return Err(StreamError::BadFormat),
Event::End(_) => return StartTLSResult::Failure(self, StreamError::BadFormat),
Event::Eof => return StartTLSResult::Failure(self, StreamError::BadFormat),
_ => continue,
}
}
Err(StreamError::InternalServerError)
StartTLSResult::Failure(self, StreamError::InternalServerError)
}
async fn write_stream_header(&mut self, req: StreamAttrs) -> Result<()> {
self.writer
@ -156,9 +197,15 @@ impl<'a> StreamStart<'a> {
}
}
pub fn spawn(hostname: String, (mut stream, _): (TcpStream, SocketAddr)) {
pub fn spawn(
hostname: String,
(mut stream, _): (TcpStream, SocketAddr),
tls_config: Arc<rustls::ServerConfig>,
) {
tokio::spawn(async move {
StreamStart::new(&mut stream, hostname).start_stream().await;
StreamStart::new(stream, hostname, tls_config)
.start_stream()
.await;
});
}

View File

@ -1 +1 @@
pub mod stream;

32
salut/src/tls/stream.rs Normal file
View File

@ -0,0 +1,32 @@
use quick_xml::{Reader, Writer};
use tokio::{
io::{BufReader, ReadHalf, WriteHalf},
net::TcpStream,
};
use tokio_rustls::server::TlsStream;
pub struct TLSStream {
reader: Reader<BufReader<ReadHalf<TlsStream<TcpStream>>>>,
writer: Writer<WriteHalf<TlsStream<TcpStream>>>,
buffer: Vec<u8>,
hostname: String,
}
impl TLSStream {
pub fn new(conn: TlsStream<TcpStream>, hostname: String) -> Self {
let (read, write) = tokio::io::split(conn);
let (reader, writer) = (
Reader::from_reader(BufReader::new(read)),
Writer::new(write),
);
Self {
reader,
writer,
hostname,
buffer: vec![],
}
}
pub async fn start_stream(self) {}
}