tls/starttls

This commit is contained in:
emilis 2023-07-01 09:15:51 +01:00
parent f00fa833e0
commit e7cf44efe1
10 changed files with 363 additions and 112 deletions

5
Cargo.lock generated
View File

@ -1322,9 +1322,9 @@ dependencies = [
[[package]] [[package]]
name = "rustls-pemfile" name = "rustls-pemfile"
version = "1.0.2" version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2"
dependencies = [ dependencies = [
"base64", "base64",
] ]
@ -1360,6 +1360,7 @@ dependencies = [
"quick-xml", "quick-xml",
"rcgen", "rcgen",
"rsdns", "rsdns",
"rustls-pemfile",
"serde", "serde",
"tokio", "tokio",
"tokio-rustls", "tokio-rustls",

View File

@ -21,3 +21,4 @@ instant-acme = "0.3.2"
desec = { path = "../desec" } desec = { path = "../desec" }
rcgen = "0.11.1" rcgen = "0.11.1"
rsdns = { version = "0.15.0", features = ["net-tokio"] } rsdns = { version = "0.15.0", features = ["net-tokio"] }
rustls-pemfile = "1.0.3"

View File

@ -1,21 +1,19 @@
use std::{ use std::{
env,
fs::File, fs::File,
io::{prelude::Write, Read}, io::{prelude::Write, Read},
net::{SocketAddr, ToSocketAddrs}, net::ToSocketAddrs,
str::FromStr, os::unix::prelude::PermissionsExt,
time::Duration, time::Duration,
vec, vec,
}; };
use desec::{ use desec::dns::{RRSet, RRSetPatch, Record};
dns::{RRSet, RRSetPatch, Record},
Session,
};
use instant_acme::{ use instant_acme::{
Account, AuthorizationStatus, ChallengeType, Identifier, LetsEncrypt, NewAccount, NewOrder, Account, AuthorizationStatus, ChallengeType, Identifier, LetsEncrypt, NewAccount, NewOrder,
Order, OrderStatus, Order, OrderStatus,
}; };
use log::{debug, error, info, warn}; use log::{debug, error, info};
use rcgen::{Certificate, CertificateParams, DistinguishedName}; use rcgen::{Certificate, CertificateParams, DistinguishedName};
use rsdns::{ use rsdns::{
clients::{tokio::Client, ClientConfig}, clients::{tokio::Client, ClientConfig},
@ -23,18 +21,21 @@ use rsdns::{
records::data, records::data,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::io::AsyncWriteExt;
use tokio_rustls::rustls;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CertStore { pub enum CertStore {
Provision, Provision,
Existing(CertificatePEM), Existing(CertPaths),
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config { pub struct Config {
pub domain: String, pub domain: String,
pub subdomain: Option<String>, pub subdomain: Option<String>,
pub port: u16, pub insecure_port: u16,
pub tls_port: u16,
pub cert_store: CertStore, pub cert_store: CertStore,
pub desec_cfg: DesecConfig, pub desec_cfg: DesecConfig,
#[serde(skip)] #[serde(skip)]
@ -44,18 +45,56 @@ pub struct Config {
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DesecConfig { pub struct DesecConfig {
pub username: String, pub username: String,
pub password: String, #[serde(skip)]
pub password: Option<String>,
pub name_servers: Vec<String>, pub name_servers: Vec<String>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CertificatePEM { pub struct CertPaths {
pub cert_chain_pem: String, pub cert_chain_path: String,
pub private_key_pem: String, pub private_key_path: String,
}
impl CertPaths {
pub fn into_config(&self) -> Result<rustls::ServerConfig, anyhow::Error> {
let mut buf = std::io::BufReader::new(File::open(&self.cert_chain_path)?);
let certs = rustls_pemfile::certs(&mut buf)?
.into_iter()
.map(tokio_rustls::rustls::Certificate)
.collect();
let mut reader = std::io::BufReader::new(File::open(&self.private_key_path)?);
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut reader)?;
let key = match keys.len() {
0 => {
return Err(anyhow::anyhow!(
"No PKCS8-encoded private key found in {}",
&self.private_key_path,
))
}
1 => rustls::PrivateKey(keys.remove(0)),
_ => {
return Err(anyhow::anyhow!(
"More than one PKCS8-encoded private key found in {}",
&self.private_key_path,
))
}
};
Ok(rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs, key)?)
}
} }
const CONFIG_PATHS: [&str; 3] = [DEFAULT_PATH, "/etc/salut.toml", "/usr/local/etc/salut.toml"]; const CONFIG_PATHS: [&str; 3] = [DEFAULT_PATH, "/etc/salut.toml", "/usr/local/etc/salut.toml"];
const CERT_PATHS: [&str; 3] = ["/etc/salut/", "/usr/local/etc/salut/", "./"];
const CERT_FILENAME: &str = "salut.cert";
const CERT_PRIVKEY_FILENAME: &str = "salut.pk";
const DNS_QUERY_WAIT: Duration = Duration::from_millis(250); const DNS_QUERY_WAIT: Duration = Duration::from_millis(250);
const DNS_PASSWORD_VAR: &str = "DESEC_PASSWORD";
pub const DEFAULT_PATH: &str = "salut.toml"; pub const DEFAULT_PATH: &str = "salut.toml";
impl Default for Config { impl Default for Config {
@ -63,11 +102,12 @@ impl Default for Config {
Self { Self {
domain: String::new(), domain: String::new(),
subdomain: Some(String::new()), subdomain: Some(String::new()),
port: 5222, insecure_port: 5222,
tls_port: 5223,
cert_store: CertStore::Provision, cert_store: CertStore::Provision,
desec_cfg: DesecConfig { desec_cfg: DesecConfig {
username: String::new(), username: String::new(),
password: String::new(), password: None,
name_servers: vec!["ns1.desec.io".into(), "ns2.desec.org".into()], name_servers: vec!["ns1.desec.io".into(), "ns2.desec.org".into()],
}, },
original_path: DEFAULT_PATH, original_path: DEFAULT_PATH,
@ -83,6 +123,9 @@ impl Config {
file.read_to_string(&mut cfg)?; file.read_to_string(&mut cfg)?;
let mut cfg: Self = toml::from_str(&cfg)?; let mut cfg: Self = toml::from_str(&cfg)?;
cfg.original_path = path; cfg.original_path = path;
if let Ok(pass) = env::var(DNS_PASSWORD_VAR) {
cfg.desec_cfg.password = Some(pass);
}
return Ok(cfg); return Ok(cfg);
}; };
} }
@ -111,7 +154,7 @@ const ACME_PREFIX: &str = "_acme-challenge";
impl Config { impl Config {
// Returns existing certificate or provisions a new one via DNS challenge using DeSEC // Returns existing certificate or provisions a new one via DNS challenge using DeSEC
pub async fn certificate(&self) -> Result<CertificatePEM, anyhow::Error> { pub async fn certificate(&self) -> Result<CertPaths, anyhow::Error> {
let desec_cfg = match self.cert_store.clone() { let desec_cfg = match self.cert_store.clone() {
CertStore::Provision => self.desec_cfg.clone(), CertStore::Provision => self.desec_cfg.clone(),
CertStore::Existing(existing) => return Ok(existing), CertStore::Existing(existing) => return Ok(existing),
@ -141,7 +184,11 @@ impl Config {
assert!(matches!(state.status, OrderStatus::Pending)); assert!(matches!(state.status, OrderStatus::Pending));
debug!("logging into desec as <{}>", &desec_cfg.username); debug!("logging into desec as <{}>", &desec_cfg.username);
let dns = desec::Session::login(&desec_cfg.username, &desec_cfg.password).await?; let desec_pass = match desec_cfg.password {
Some(pass) => pass,
None => panic!("need desec password for provisioning the TLS certificate, please set {DNS_PASSWORD_VAR}"),
};
let dns = desec::Session::login(&desec_cfg.username, &desec_pass).await?;
debug!("querying existing TXT records"); debug!("querying existing TXT records");
let existing_records: Vec<RRSet> = dns let existing_records: Vec<RRSet> = dns
.get_rrsets(&self.domain, Some(vec![Record::TXT])) .get_rrsets(&self.domain, Some(vec![Record::TXT]))
@ -258,16 +305,12 @@ impl Config {
} }
}; };
let cert = CertificatePEM {
cert_chain_pem,
private_key_pem: cert.serialize_private_key_pem(),
};
let mut new_cfg = self.clone(); let mut new_cfg = self.clone();
new_cfg.cert_store = CertStore::Existing(cert.clone()); let paths = save_certs_pam(cert_chain_pem, cert.serialize_private_key_pem()).await?;
new_cfg.cert_store = CertStore::Existing(paths.clone());
new_cfg.save(self.original_path)?; new_cfg.save(self.original_path)?;
Ok(cert) Ok(paths)
} }
async fn wait_challenges( async fn wait_challenges(
@ -362,3 +405,36 @@ impl Config {
} }
} }
} }
async fn save_certs_pam(
cert_chain: String,
private_key: String,
) -> Result<CertPaths, anyhow::Error> {
for try_path in CERT_PATHS {
if let Err(_) = std::fs::create_dir_all(try_path) {
continue;
}
let cert_chain_path = format!("{try_path}{CERT_FILENAME}");
let mut cert_file = match tokio::fs::File::create(&cert_chain_path).await {
Ok(f) => f,
Err(_) => continue,
};
let private_key_path = format!("{try_path}{CERT_PRIVKEY_FILENAME}");
let mut pk_file = tokio::fs::File::create(&private_key_path).await?;
pk_file.metadata().await?.permissions().set_mode(600);
cert_file.metadata().await?.permissions().set_mode(600);
cert_file.write_all(cert_chain.as_bytes()).await?;
pk_file.write_all(private_key.as_bytes()).await?;
cert_file.flush().await?;
pk_file.flush().await?;
return Ok(CertPaths {
cert_chain_path,
private_key_path,
});
}
Err(anyhow::anyhow!(
"could not create/save cert files in {CERT_PATHS:?}"
))
}

View File

@ -1,5 +1,6 @@
use enum_display::EnumDisplay; use enum_display::EnumDisplay;
use std::string::FromUtf8Error; use std::string::FromUtf8Error;
use tokio_rustls::rustls;
use log::error; use log::error;
use quick_xml::events::attributes::AttrError; use quick_xml::events::attributes::AttrError;
@ -126,6 +127,37 @@ pub enum StreamError {
/// the server. /// the server.
UnsupportedVersion, UnsupportedVersion,
} }
impl From<rustls::Error> for StreamError {
fn from(value: rustls::Error) -> Self {
match value {
rustls::Error::InappropriateMessage {
expect_types,
got_type,
} => Self::BadFormat,
rustls::Error::InappropriateHandshakeMessage {
expect_types,
got_type,
} => Self::BadFormat,
rustls::Error::InvalidMessage(_) => Self::BadFormat,
rustls::Error::NoCertificatesPresented => Self::InternalServerError,
rustls::Error::UnsupportedNameType => Self::BadFormat,
rustls::Error::DecryptError => Self::InternalServerError,
rustls::Error::EncryptError => Self::InternalServerError,
rustls::Error::PeerIncompatible(_) => Self::BadFormat,
rustls::Error::PeerMisbehaved(_) => Self::BadFormat,
rustls::Error::AlertReceived(_) => Self::BadFormat,
rustls::Error::InvalidCertificate(_) => Self::InternalServerError,
_ => Self::BadFormat,
}
}
}
impl From<std::io::Error> for StreamError {
fn from(value: std::io::Error) -> Self {
error!("io error: {value}");
Self::InternalServerError
}
}
impl From<FromUtf8Error> for StreamError { impl From<FromUtf8Error> for StreamError {
fn from(_: FromUtf8Error) -> Self { fn from(_: FromUtf8Error) -> Self {

View File

@ -1,4 +1,4 @@
use std::process; use std::{process, sync::Arc};
use log::{error, info}; use log::{error, info};
@ -33,10 +33,27 @@ async fn main() -> Result<(), anyhow::Error> {
} }
}; };
info!("checking for certificates"); info!("checking for certificates");
let certs = cfg.certificate().await.expect("getting certificates"); let certs = Arc::new(
cfg.certificate()
.await
.expect("getting certificates")
.into_config()?,
);
let host = cfg.hostname(); let host = cfg.hostname();
info!("listening on {host}:{}!", cfg.port); info!(
server::listen(host, cfg.port).await.unwrap(); "listening on {host} {} (plain) and {} (tls)!",
cfg.insecure_port, cfg.tls_port
);
let cert_clone = certs.clone();
let host_clone = host.clone();
tokio::spawn(async move {
server::listen_tls(host_clone, cfg.tls_port, cert_clone)
.await
.expect("TLS listener")
});
server::listen_starttls(host, cfg.insecure_port, certs)
.await
.expect("STARTTLS listener");
Ok(()) Ok(())
} }

View File

@ -1,14 +1,18 @@
use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use quick_xml::{ use quick_xml::{
events::{BytesStart, Event}, events::{BytesStart, Event},
Writer, Writer,
}; };
use tokio::io::{AsyncBufRead, AsyncWrite}; use tokio::{io::AsyncWrite, net::TcpStream};
use tokio_rustls::rustls; use tokio_rustls::{rustls, server::TlsStream, TlsAcceptor};
use crate::{ use crate::{
error::StreamError, error::StreamError,
streamstart::{StartTLSResult, StreamStart},
tag::{self, Tag}, tag::{self, Tag},
tls::stream::{self, TLSStream},
}; };
pub enum Step { pub enum Step {
@ -36,15 +40,20 @@ impl Tag for Step {
} }
} }
pub async fn start_tls<R, W>( pub(crate) enum TcpConnOrTLS {
reader: R, TCP(TcpStream, StreamError),
writer: W, /// Failed is returned when there's a failure starting the TLS stream
/// thus the TCP stream is consumed.
Failed(StreamError),
TLS(TlsStream<TcpStream>),
}
pub(crate) async fn start_tls(
conn: TcpStream,
start_tls_event: BytesStart<'_>, start_tls_event: BytesStart<'_>,
) -> Result<Step, StreamError> tls_config: Arc<rustls::ServerConfig>,
where ) -> TcpConnOrTLS {
R: AsyncBufRead + Unpin, let mut conn = conn;
W: AsyncWrite + Unpin + Send,
{
match start_tls_event.try_get_attribute("xmlns") { match start_tls_event.try_get_attribute("xmlns") {
Ok(namespace) => { Ok(namespace) => {
if &namespace if &namespace
@ -52,28 +61,26 @@ where
.unwrap_or_default() .unwrap_or_default()
!= tag::TLS_NAMESPACE != tag::TLS_NAMESPACE
{ {
return Ok(Step::Failure); if let Err(err) = Step::Failure.write_tag(&mut conn).await {
return TcpConnOrTLS::TCP(conn, err);
}
return TcpConnOrTLS::TCP(conn, StreamError::InvalidNamespace);
} }
} }
Err(_) => return Ok(Step::Failure), Err(_) => {
if let Err(err) = Step::Failure.write_tag(&mut conn).await {
return TcpConnOrTLS::TCP(conn, err);
}
return TcpConnOrTLS::TCP(conn, StreamError::BadFormat);
}
} }
// let config = rustls::ServerConfig::builder() if let Err(err) = Step::Proceed.write_tag(&mut conn).await {
// .with_safe_defaults() return TcpConnOrTLS::TCP(conn, err);
// .with_no_client_auth() };
// .with_single_cert(certs, keys.remove(0)) let acceptor = TlsAcceptor::from(tls_config);
// .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?; match acceptor.accept(conn).await {
Ok(tls_conn) => TcpConnOrTLS::TLS(tls_conn),
Step::Proceed.write_tag(writer).await?; Err(err) => TcpConnOrTLS::Failed(err.into()),
}
// match TlsConnector::builder(). {
// Ok(conn) => conn.,
// Err(err) => {
// error!("getting a tls connector: {err}");
// return Ok(Step::Failure);
// }
// }
std::thread::sleep(std::time::Duration::from_secs(3));
todo!()
} }

View File

@ -1,15 +1,53 @@
use log::info; use std::sync::Arc;
use log::{error, info};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio_rustls::{rustls, TlsAcceptor};
use crate::streamstart; use crate::{streamstart, tls::stream};
pub async fn listen(hostname: String, port: u16) -> Result<(), anyhow::Error> { pub async fn listen_starttls(
let listener = TcpListener::bind(("0.0.0.0", port)).await?; hostname: String,
insecure_port: u16,
tls_config: Arc<rustls::ServerConfig>,
) -> Result<(), anyhow::Error> {
let listener = TcpListener::bind(("0.0.0.0", insecure_port)).await?;
loop { loop {
match listener.accept().await { match listener.accept().await {
Ok(conn) => { Ok(conn) => {
info!("opening connection from {}", conn.1); info!("opening starttls connection from {}", conn.1);
streamstart::spawn(hostname.clone(), conn); streamstart::spawn(hostname.clone(), conn, tls_config.clone());
}
Err(e) => {
eprintln!("listening: {e}");
continue;
}
}
}
}
pub async fn listen_tls(
hostname: String,
tls_port: u16,
tls_config: Arc<rustls::ServerConfig>,
) -> Result<(), anyhow::Error> {
let listener = TcpListener::bind(("0.0.0.0", tls_port)).await?;
let acceptor = TlsAcceptor::from(tls_config);
loop {
match listener.accept().await {
Ok(conn) => {
info!("opening TLS connection from {}", conn.1);
match acceptor.accept(conn.0).await {
Ok(conn) => {
let cloned_hostname = hostname.clone();
tokio::spawn(async move {
stream::TLSStream::new(conn, cloned_hostname)
.start_stream()
.await
});
}
Err(err) => error!("TLS accept error for {}: {err}", conn.1),
}
} }
Err(e) => { Err(e) => {
eprintln!("listening: {e}"); eprintln!("listening: {e}");

View File

@ -1,4 +1,4 @@
use std::net::SocketAddr; use std::{net::SocketAddr, sync::Arc};
use log::{error, info}; use log::{error, info};
use quick_xml::{ use quick_xml::{
@ -6,34 +6,42 @@ use quick_xml::{
Reader, Writer, Reader, Writer,
}; };
use tokio::{ use tokio::{
io::{AsyncWrite, AsyncWriteExt, BufReader}, io::{AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf},
net::{ net::TcpStream,
tcp::{ReadHalf, WriteHalf},
TcpStream,
},
}; };
use tokio_rustls::rustls;
use crate::{ use crate::{
error::StreamError, error::StreamError,
feature::Feature, feature::Feature,
negotiator::{self, Step}, negotiator::{self, Step, TcpConnOrTLS},
tag::{self, Tag}, tag::{self, Tag},
tls::stream,
}; };
pub(crate) enum StartTLSResult {
Success(stream::TLSStream),
Failure(StreamStart, StreamError),
/// TLSFailure is returned when the TLS negotiation failed, and the TCP
/// stream is to be dropped.
TLSFailure(StreamError),
}
type Result<T> = std::result::Result<T, StreamError>; type Result<T> = std::result::Result<T, StreamError>;
const FEATURES: &'static [Feature] = &[Feature::start_tls(true)]; const FEATURES: &'static [Feature] = &[Feature::start_tls(true)];
struct StreamStart<'a> { pub struct StreamStart {
reader: Reader<BufReader<ReadHalf<'a>>>, reader: Reader<BufReader<ReadHalf<TcpStream>>>,
writer: Writer<WriteHalf<'a>>, writer: Writer<WriteHalf<TcpStream>>,
buffer: Vec<u8>, buffer: Vec<u8>,
hostname: String, hostname: String,
tls_config: Arc<rustls::ServerConfig>,
} }
impl<'a> StreamStart<'a> { impl StreamStart {
fn new(stream: &'a mut TcpStream, hostname: String) -> Self { fn new(tcp_stream: TcpStream, hostname: String, tls_config: Arc<rustls::ServerConfig>) -> Self {
let (read, write) = stream.split(); let (read, write) = tokio::io::split(tcp_stream);
let (reader, writer) = ( let (reader, writer) = (
Reader::from_reader(BufReader::new(read)), Reader::from_reader(BufReader::new(read)),
Writer::new(write), Writer::new(write),
@ -43,85 +51,118 @@ impl<'a> StreamStart<'a> {
reader, reader,
writer, writer,
hostname, hostname,
tls_config,
buffer: vec![], buffer: vec![],
} }
} }
async fn start_stream(mut self) { async fn start_stream(self) {
match self.negotiate_stream().await { match self.negotiate_stream().await {
Ok(_) => {} StartTLSResult::Success(tls_stream) => tls_stream.start_stream().await,
Err(err) => { StartTLSResult::Failure(mut conn, err) => {
if let Err(err2) = error(self.writer.get_mut(), err).await { if let Err(err2) = error(conn.writer.get_mut(), err).await {
error!("error writing error: {err2}"); error!("error writing error: {err2}");
return; return;
} else { } else {
info!("wrote error {err}") info!("wrote error {err}")
} }
if let Err(e) = self.writer.get_mut().write_all(b"</stream:stream>").await { if let Err(e) = conn.writer.get_mut().write_all(b"</stream:stream>").await {
error!("writing end to stream: {e}") error!("writing end to stream: {e}")
} }
if let Err(e) = self.writer.get_mut().shutdown().await { if let Err(e) = conn.writer.get_mut().shutdown().await {
error!("shutting down stream: {e}") error!("shutting down stream: {e}")
} }
} }
StartTLSResult::TLSFailure(err) => {
error!("TLS negotiation failure: {err}. Dropping connection.")
}
} }
} }
async fn negotiate_stream(&mut self) -> Result<()> { async fn negotiate_stream(mut self) -> StartTLSResult {
let attrs = loop { let attrs = loop {
match self.reader.read_event_into_async(&mut self.buffer).await? { let event = match self.reader.read_event_into_async(&mut self.buffer).await {
Ok(event) => event,
Err(err) => return StartTLSResult::Failure(self, err.into()),
};
match event {
Event::Start(start) => { Event::Start(start) => {
if start.name().as_ref() == tag::STREAM_ELEMENT_NAME { if start.name().as_ref() == tag::STREAM_ELEMENT_NAME {
let attrs: StreamAttrs = start.attributes().try_into()?; let attrs: StreamAttrs = match start.attributes().try_into() {
Ok(a) => a,
Err(err) => return StartTLSResult::Failure(self, err.into()),
};
if attrs.namespace != XMLNamespace::JabberClient { if attrs.namespace != XMLNamespace::JabberClient {
return Err(StreamError::InvalidNamespace); return StartTLSResult::Failure(self, StreamError::InvalidNamespace);
} }
break attrs; break attrs;
} else { } else {
info!("element: {:?}", start); info!("element: {:?}", start);
} }
} }
Event::End(_) => return Err(StreamError::BadFormat), Event::End(_) => return StartTLSResult::Failure(self, StreamError::BadFormat),
Event::Eof => return Err(StreamError::BadFormat), Event::Eof => return StartTLSResult::Failure(self, StreamError::BadFormat),
_ => continue, _ => continue,
} }
}; };
info!("starting negotiation with: {attrs:?}"); info!("starting negotiation with: {attrs:?}");
self.write_stream_header(StreamAttrs { if let Err(err) = self
.write_stream_header(StreamAttrs {
from: attrs.to.clone(), from: attrs.to.clone(),
to: attrs.from, to: attrs.from,
namespace: XMLNamespace::JabberClient, namespace: XMLNamespace::JabberClient,
}) })
.await?; .await
{
return StartTLSResult::Failure(self, err);
};
if attrs.to != self.hostname { if attrs.to != self.hostname {
return Err(StreamError::HostUnknown); return StartTLSResult::Failure(self, StreamError::HostUnknown);
} }
self.send_features().await?; if let Err(err) = self.send_features().await {
return StartTLSResult::Failure(self, err);
};
loop { loop {
match self.reader.read_event_into_async(&mut self.buffer).await? { let event = match self.reader.read_event_into_async(&mut self.buffer).await {
Ok(event) => event,
Err(err) => return StartTLSResult::Failure(self, err.into()),
};
match event {
Event::Empty(empty) => match empty.name().as_ref() { Event::Empty(empty) => match empty.name().as_ref() {
tag::STARTTLS => { tag::STARTTLS => {
info!("starttls negotiation"); info!("starttls negotiation");
if let Step::Failure = negotiator::start_tls( let hostname = self.hostname;
self.reader.get_mut(), let tls_config = self.tls_config.clone();
self.writer.get_mut(), return match negotiator::start_tls(
self.reader
.into_inner()
.into_inner()
.unsplit(self.writer.into_inner()),
empty, empty,
self.tls_config.clone(),
) )
.await? .await
{ {
return Step::Failure.write_tag(self.writer.get_mut()).await; TcpConnOrTLS::TLS(conn) => {
StartTLSResult::Success(stream::TLSStream::new(conn, hostname))
}
TcpConnOrTLS::TCP(conn, err) => StartTLSResult::Failure(
StreamStart::new(conn, hostname, tls_config),
err,
),
TcpConnOrTLS::Failed(err) => StartTLSResult::TLSFailure(err),
}; };
} }
_ => return Err(StreamError::UnsupportedFeature), _ => return StartTLSResult::Failure(self, StreamError::UnsupportedFeature),
}, },
Event::End(_) => return Err(StreamError::BadFormat), Event::End(_) => return StartTLSResult::Failure(self, StreamError::BadFormat),
Event::Eof => return Err(StreamError::BadFormat), Event::Eof => return StartTLSResult::Failure(self, StreamError::BadFormat),
_ => continue, _ => continue,
} }
} }
Err(StreamError::InternalServerError) StartTLSResult::Failure(self, StreamError::InternalServerError)
} }
async fn write_stream_header(&mut self, req: StreamAttrs) -> Result<()> { async fn write_stream_header(&mut self, req: StreamAttrs) -> Result<()> {
self.writer self.writer
@ -156,9 +197,15 @@ impl<'a> StreamStart<'a> {
} }
} }
pub fn spawn(hostname: String, (mut stream, _): (TcpStream, SocketAddr)) { pub fn spawn(
hostname: String,
(mut stream, _): (TcpStream, SocketAddr),
tls_config: Arc<rustls::ServerConfig>,
) {
tokio::spawn(async move { tokio::spawn(async move {
StreamStart::new(&mut stream, hostname).start_stream().await; StreamStart::new(stream, hostname, tls_config)
.start_stream()
.await;
}); });
} }

View File

@ -1 +1 @@
pub mod stream;

32
salut/src/tls/stream.rs Normal file
View File

@ -0,0 +1,32 @@
use quick_xml::{Reader, Writer};
use tokio::{
io::{BufReader, ReadHalf, WriteHalf},
net::TcpStream,
};
use tokio_rustls::server::TlsStream;
pub struct TLSStream {
reader: Reader<BufReader<ReadHalf<TlsStream<TcpStream>>>>,
writer: Writer<WriteHalf<TlsStream<TcpStream>>>,
buffer: Vec<u8>,
hostname: String,
}
impl TLSStream {
pub fn new(conn: TlsStream<TcpStream>, hostname: String) -> Self {
let (read, write) = tokio::io::split(conn);
let (reader, writer) = (
Reader::from_reader(BufReader::new(read)),
Writer::new(write),
);
Self {
reader,
writer,
hostname,
buffer: vec![],
}
}
pub async fn start_stream(self) {}
}