tls/starttls
This commit is contained in:
parent
f00fa833e0
commit
e7cf44efe1
|
@ -1322,9 +1322,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "rustls-pemfile"
|
||||
version = "1.0.2"
|
||||
version = "1.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b"
|
||||
checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2"
|
||||
dependencies = [
|
||||
"base64",
|
||||
]
|
||||
|
@ -1360,6 +1360,7 @@ dependencies = [
|
|||
"quick-xml",
|
||||
"rcgen",
|
||||
"rsdns",
|
||||
"rustls-pemfile",
|
||||
"serde",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
|
|
|
@ -21,3 +21,4 @@ instant-acme = "0.3.2"
|
|||
desec = { path = "../desec" }
|
||||
rcgen = "0.11.1"
|
||||
rsdns = { version = "0.15.0", features = ["net-tokio"] }
|
||||
rustls-pemfile = "1.0.3"
|
||||
|
|
|
@ -1,21 +1,19 @@
|
|||
use std::{
|
||||
env,
|
||||
fs::File,
|
||||
io::{prelude::Write, Read},
|
||||
net::{SocketAddr, ToSocketAddrs},
|
||||
str::FromStr,
|
||||
net::ToSocketAddrs,
|
||||
os::unix::prelude::PermissionsExt,
|
||||
time::Duration,
|
||||
vec,
|
||||
};
|
||||
|
||||
use desec::{
|
||||
dns::{RRSet, RRSetPatch, Record},
|
||||
Session,
|
||||
};
|
||||
use desec::dns::{RRSet, RRSetPatch, Record};
|
||||
use instant_acme::{
|
||||
Account, AuthorizationStatus, ChallengeType, Identifier, LetsEncrypt, NewAccount, NewOrder,
|
||||
Order, OrderStatus,
|
||||
};
|
||||
use log::{debug, error, info, warn};
|
||||
use log::{debug, error, info};
|
||||
use rcgen::{Certificate, CertificateParams, DistinguishedName};
|
||||
use rsdns::{
|
||||
clients::{tokio::Client, ClientConfig},
|
||||
|
@ -23,18 +21,21 @@ use rsdns::{
|
|||
records::data,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio_rustls::rustls;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum CertStore {
|
||||
Provision,
|
||||
Existing(CertificatePEM),
|
||||
Existing(CertPaths),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Config {
|
||||
pub domain: String,
|
||||
pub subdomain: Option<String>,
|
||||
pub port: u16,
|
||||
pub insecure_port: u16,
|
||||
pub tls_port: u16,
|
||||
pub cert_store: CertStore,
|
||||
pub desec_cfg: DesecConfig,
|
||||
#[serde(skip)]
|
||||
|
@ -44,18 +45,56 @@ pub struct Config {
|
|||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DesecConfig {
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
#[serde(skip)]
|
||||
pub password: Option<String>,
|
||||
pub name_servers: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CertificatePEM {
|
||||
pub cert_chain_pem: String,
|
||||
pub private_key_pem: String,
|
||||
pub struct CertPaths {
|
||||
pub cert_chain_path: String,
|
||||
pub private_key_path: String,
|
||||
}
|
||||
|
||||
impl CertPaths {
|
||||
pub fn into_config(&self) -> Result<rustls::ServerConfig, anyhow::Error> {
|
||||
let mut buf = std::io::BufReader::new(File::open(&self.cert_chain_path)?);
|
||||
let certs = rustls_pemfile::certs(&mut buf)?
|
||||
.into_iter()
|
||||
.map(tokio_rustls::rustls::Certificate)
|
||||
.collect();
|
||||
|
||||
let mut reader = std::io::BufReader::new(File::open(&self.private_key_path)?);
|
||||
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut reader)?;
|
||||
let key = match keys.len() {
|
||||
0 => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"No PKCS8-encoded private key found in {}",
|
||||
&self.private_key_path,
|
||||
))
|
||||
}
|
||||
1 => rustls::PrivateKey(keys.remove(0)),
|
||||
_ => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"More than one PKCS8-encoded private key found in {}",
|
||||
&self.private_key_path,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
Ok(rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key)?)
|
||||
}
|
||||
}
|
||||
|
||||
const CONFIG_PATHS: [&str; 3] = [DEFAULT_PATH, "/etc/salut.toml", "/usr/local/etc/salut.toml"];
|
||||
const CERT_PATHS: [&str; 3] = ["/etc/salut/", "/usr/local/etc/salut/", "./"];
|
||||
const CERT_FILENAME: &str = "salut.cert";
|
||||
const CERT_PRIVKEY_FILENAME: &str = "salut.pk";
|
||||
const DNS_QUERY_WAIT: Duration = Duration::from_millis(250);
|
||||
const DNS_PASSWORD_VAR: &str = "DESEC_PASSWORD";
|
||||
pub const DEFAULT_PATH: &str = "salut.toml";
|
||||
|
||||
impl Default for Config {
|
||||
|
@ -63,11 +102,12 @@ impl Default for Config {
|
|||
Self {
|
||||
domain: String::new(),
|
||||
subdomain: Some(String::new()),
|
||||
port: 5222,
|
||||
insecure_port: 5222,
|
||||
tls_port: 5223,
|
||||
cert_store: CertStore::Provision,
|
||||
desec_cfg: DesecConfig {
|
||||
username: String::new(),
|
||||
password: String::new(),
|
||||
password: None,
|
||||
name_servers: vec!["ns1.desec.io".into(), "ns2.desec.org".into()],
|
||||
},
|
||||
original_path: DEFAULT_PATH,
|
||||
|
@ -83,6 +123,9 @@ impl Config {
|
|||
file.read_to_string(&mut cfg)?;
|
||||
let mut cfg: Self = toml::from_str(&cfg)?;
|
||||
cfg.original_path = path;
|
||||
if let Ok(pass) = env::var(DNS_PASSWORD_VAR) {
|
||||
cfg.desec_cfg.password = Some(pass);
|
||||
}
|
||||
return Ok(cfg);
|
||||
};
|
||||
}
|
||||
|
@ -111,7 +154,7 @@ const ACME_PREFIX: &str = "_acme-challenge";
|
|||
|
||||
impl Config {
|
||||
// Returns existing certificate or provisions a new one via DNS challenge using DeSEC
|
||||
pub async fn certificate(&self) -> Result<CertificatePEM, anyhow::Error> {
|
||||
pub async fn certificate(&self) -> Result<CertPaths, anyhow::Error> {
|
||||
let desec_cfg = match self.cert_store.clone() {
|
||||
CertStore::Provision => self.desec_cfg.clone(),
|
||||
CertStore::Existing(existing) => return Ok(existing),
|
||||
|
@ -141,7 +184,11 @@ impl Config {
|
|||
assert!(matches!(state.status, OrderStatus::Pending));
|
||||
|
||||
debug!("logging into desec as <{}>", &desec_cfg.username);
|
||||
let dns = desec::Session::login(&desec_cfg.username, &desec_cfg.password).await?;
|
||||
let desec_pass = match desec_cfg.password {
|
||||
Some(pass) => pass,
|
||||
None => panic!("need desec password for provisioning the TLS certificate, please set {DNS_PASSWORD_VAR}"),
|
||||
};
|
||||
let dns = desec::Session::login(&desec_cfg.username, &desec_pass).await?;
|
||||
debug!("querying existing TXT records");
|
||||
let existing_records: Vec<RRSet> = dns
|
||||
.get_rrsets(&self.domain, Some(vec![Record::TXT]))
|
||||
|
@ -258,16 +305,12 @@ impl Config {
|
|||
}
|
||||
};
|
||||
|
||||
let cert = CertificatePEM {
|
||||
cert_chain_pem,
|
||||
private_key_pem: cert.serialize_private_key_pem(),
|
||||
};
|
||||
|
||||
let mut new_cfg = self.clone();
|
||||
new_cfg.cert_store = CertStore::Existing(cert.clone());
|
||||
let paths = save_certs_pam(cert_chain_pem, cert.serialize_private_key_pem()).await?;
|
||||
new_cfg.cert_store = CertStore::Existing(paths.clone());
|
||||
new_cfg.save(self.original_path)?;
|
||||
|
||||
Ok(cert)
|
||||
Ok(paths)
|
||||
}
|
||||
|
||||
async fn wait_challenges(
|
||||
|
@ -362,3 +405,36 @@ impl Config {
|
|||
}
|
||||
}
|
||||
}
|
||||
async fn save_certs_pam(
|
||||
cert_chain: String,
|
||||
private_key: String,
|
||||
) -> Result<CertPaths, anyhow::Error> {
|
||||
for try_path in CERT_PATHS {
|
||||
if let Err(_) = std::fs::create_dir_all(try_path) {
|
||||
continue;
|
||||
}
|
||||
let cert_chain_path = format!("{try_path}{CERT_FILENAME}");
|
||||
let mut cert_file = match tokio::fs::File::create(&cert_chain_path).await {
|
||||
Ok(f) => f,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let private_key_path = format!("{try_path}{CERT_PRIVKEY_FILENAME}");
|
||||
let mut pk_file = tokio::fs::File::create(&private_key_path).await?;
|
||||
pk_file.metadata().await?.permissions().set_mode(600);
|
||||
cert_file.metadata().await?.permissions().set_mode(600);
|
||||
|
||||
cert_file.write_all(cert_chain.as_bytes()).await?;
|
||||
pk_file.write_all(private_key.as_bytes()).await?;
|
||||
cert_file.flush().await?;
|
||||
pk_file.flush().await?;
|
||||
|
||||
return Ok(CertPaths {
|
||||
cert_chain_path,
|
||||
private_key_path,
|
||||
});
|
||||
}
|
||||
|
||||
Err(anyhow::anyhow!(
|
||||
"could not create/save cert files in {CERT_PATHS:?}"
|
||||
))
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use enum_display::EnumDisplay;
|
||||
use std::string::FromUtf8Error;
|
||||
use tokio_rustls::rustls;
|
||||
|
||||
use log::error;
|
||||
use quick_xml::events::attributes::AttrError;
|
||||
|
@ -126,6 +127,37 @@ pub enum StreamError {
|
|||
/// the server.
|
||||
UnsupportedVersion,
|
||||
}
|
||||
impl From<rustls::Error> for StreamError {
|
||||
fn from(value: rustls::Error) -> Self {
|
||||
match value {
|
||||
rustls::Error::InappropriateMessage {
|
||||
expect_types,
|
||||
got_type,
|
||||
} => Self::BadFormat,
|
||||
rustls::Error::InappropriateHandshakeMessage {
|
||||
expect_types,
|
||||
got_type,
|
||||
} => Self::BadFormat,
|
||||
rustls::Error::InvalidMessage(_) => Self::BadFormat,
|
||||
rustls::Error::NoCertificatesPresented => Self::InternalServerError,
|
||||
rustls::Error::UnsupportedNameType => Self::BadFormat,
|
||||
rustls::Error::DecryptError => Self::InternalServerError,
|
||||
rustls::Error::EncryptError => Self::InternalServerError,
|
||||
rustls::Error::PeerIncompatible(_) => Self::BadFormat,
|
||||
rustls::Error::PeerMisbehaved(_) => Self::BadFormat,
|
||||
rustls::Error::AlertReceived(_) => Self::BadFormat,
|
||||
rustls::Error::InvalidCertificate(_) => Self::InternalServerError,
|
||||
_ => Self::BadFormat,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for StreamError {
|
||||
fn from(value: std::io::Error) -> Self {
|
||||
error!("io error: {value}");
|
||||
Self::InternalServerError
|
||||
}
|
||||
}
|
||||
|
||||
impl From<FromUtf8Error> for StreamError {
|
||||
fn from(_: FromUtf8Error) -> Self {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use std::process;
|
||||
use std::{process, sync::Arc};
|
||||
|
||||
use log::{error, info};
|
||||
|
||||
|
@ -33,10 +33,27 @@ async fn main() -> Result<(), anyhow::Error> {
|
|||
}
|
||||
};
|
||||
info!("checking for certificates");
|
||||
let certs = cfg.certificate().await.expect("getting certificates");
|
||||
let certs = Arc::new(
|
||||
cfg.certificate()
|
||||
.await
|
||||
.expect("getting certificates")
|
||||
.into_config()?,
|
||||
);
|
||||
|
||||
let host = cfg.hostname();
|
||||
info!("listening on {host}:{}!", cfg.port);
|
||||
server::listen(host, cfg.port).await.unwrap();
|
||||
info!(
|
||||
"listening on {host} {} (plain) and {} (tls)!",
|
||||
cfg.insecure_port, cfg.tls_port
|
||||
);
|
||||
let cert_clone = certs.clone();
|
||||
let host_clone = host.clone();
|
||||
tokio::spawn(async move {
|
||||
server::listen_tls(host_clone, cfg.tls_port, cert_clone)
|
||||
.await
|
||||
.expect("TLS listener")
|
||||
});
|
||||
server::listen_starttls(host, cfg.insecure_port, certs)
|
||||
.await
|
||||
.expect("STARTTLS listener");
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -1,14 +1,18 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use quick_xml::{
|
||||
events::{BytesStart, Event},
|
||||
Writer,
|
||||
};
|
||||
use tokio::io::{AsyncBufRead, AsyncWrite};
|
||||
use tokio_rustls::rustls;
|
||||
use tokio::{io::AsyncWrite, net::TcpStream};
|
||||
use tokio_rustls::{rustls, server::TlsStream, TlsAcceptor};
|
||||
|
||||
use crate::{
|
||||
error::StreamError,
|
||||
streamstart::{StartTLSResult, StreamStart},
|
||||
tag::{self, Tag},
|
||||
tls::stream::{self, TLSStream},
|
||||
};
|
||||
|
||||
pub enum Step {
|
||||
|
@ -36,15 +40,20 @@ impl Tag for Step {
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn start_tls<R, W>(
|
||||
reader: R,
|
||||
writer: W,
|
||||
pub(crate) enum TcpConnOrTLS {
|
||||
TCP(TcpStream, StreamError),
|
||||
/// Failed is returned when there's a failure starting the TLS stream
|
||||
/// thus the TCP stream is consumed.
|
||||
Failed(StreamError),
|
||||
TLS(TlsStream<TcpStream>),
|
||||
}
|
||||
|
||||
pub(crate) async fn start_tls(
|
||||
conn: TcpStream,
|
||||
start_tls_event: BytesStart<'_>,
|
||||
) -> Result<Step, StreamError>
|
||||
where
|
||||
R: AsyncBufRead + Unpin,
|
||||
W: AsyncWrite + Unpin + Send,
|
||||
{
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
) -> TcpConnOrTLS {
|
||||
let mut conn = conn;
|
||||
match start_tls_event.try_get_attribute("xmlns") {
|
||||
Ok(namespace) => {
|
||||
if &namespace
|
||||
|
@ -52,28 +61,26 @@ where
|
|||
.unwrap_or_default()
|
||||
!= tag::TLS_NAMESPACE
|
||||
{
|
||||
return Ok(Step::Failure);
|
||||
if let Err(err) = Step::Failure.write_tag(&mut conn).await {
|
||||
return TcpConnOrTLS::TCP(conn, err);
|
||||
}
|
||||
return TcpConnOrTLS::TCP(conn, StreamError::InvalidNamespace);
|
||||
}
|
||||
}
|
||||
Err(_) => return Ok(Step::Failure),
|
||||
Err(_) => {
|
||||
if let Err(err) = Step::Failure.write_tag(&mut conn).await {
|
||||
return TcpConnOrTLS::TCP(conn, err);
|
||||
}
|
||||
return TcpConnOrTLS::TCP(conn, StreamError::BadFormat);
|
||||
}
|
||||
}
|
||||
|
||||
// let config = rustls::ServerConfig::builder()
|
||||
// .with_safe_defaults()
|
||||
// .with_no_client_auth()
|
||||
// .with_single_cert(certs, keys.remove(0))
|
||||
// .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
|
||||
|
||||
Step::Proceed.write_tag(writer).await?;
|
||||
|
||||
// match TlsConnector::builder(). {
|
||||
// Ok(conn) => conn.,
|
||||
// Err(err) => {
|
||||
// error!("getting a tls connector: {err}");
|
||||
// return Ok(Step::Failure);
|
||||
// }
|
||||
// }
|
||||
|
||||
std::thread::sleep(std::time::Duration::from_secs(3));
|
||||
todo!()
|
||||
if let Err(err) = Step::Proceed.write_tag(&mut conn).await {
|
||||
return TcpConnOrTLS::TCP(conn, err);
|
||||
};
|
||||
let acceptor = TlsAcceptor::from(tls_config);
|
||||
match acceptor.accept(conn).await {
|
||||
Ok(tls_conn) => TcpConnOrTLS::TLS(tls_conn),
|
||||
Err(err) => TcpConnOrTLS::Failed(err.into()),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,15 +1,53 @@
|
|||
use log::info;
|
||||
use std::sync::Arc;
|
||||
|
||||
use log::{error, info};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_rustls::{rustls, TlsAcceptor};
|
||||
|
||||
use crate::streamstart;
|
||||
use crate::{streamstart, tls::stream};
|
||||
|
||||
pub async fn listen(hostname: String, port: u16) -> Result<(), anyhow::Error> {
|
||||
let listener = TcpListener::bind(("0.0.0.0", port)).await?;
|
||||
pub async fn listen_starttls(
|
||||
hostname: String,
|
||||
insecure_port: u16,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let listener = TcpListener::bind(("0.0.0.0", insecure_port)).await?;
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok(conn) => {
|
||||
info!("opening connection from {}", conn.1);
|
||||
streamstart::spawn(hostname.clone(), conn);
|
||||
info!("opening starttls connection from {}", conn.1);
|
||||
streamstart::spawn(hostname.clone(), conn, tls_config.clone());
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("listening: {e}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn listen_tls(
|
||||
hostname: String,
|
||||
tls_port: u16,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let listener = TcpListener::bind(("0.0.0.0", tls_port)).await?;
|
||||
let acceptor = TlsAcceptor::from(tls_config);
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok(conn) => {
|
||||
info!("opening TLS connection from {}", conn.1);
|
||||
match acceptor.accept(conn.0).await {
|
||||
Ok(conn) => {
|
||||
let cloned_hostname = hostname.clone();
|
||||
tokio::spawn(async move {
|
||||
stream::TLSStream::new(conn, cloned_hostname)
|
||||
.start_stream()
|
||||
.await
|
||||
});
|
||||
}
|
||||
Err(err) => error!("TLS accept error for {}: {err}", conn.1),
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("listening: {e}");
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use std::net::SocketAddr;
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
use log::{error, info};
|
||||
use quick_xml::{
|
||||
|
@ -6,34 +6,42 @@ use quick_xml::{
|
|||
Reader, Writer,
|
||||
};
|
||||
use tokio::{
|
||||
io::{AsyncWrite, AsyncWriteExt, BufReader},
|
||||
net::{
|
||||
tcp::{ReadHalf, WriteHalf},
|
||||
TcpStream,
|
||||
},
|
||||
io::{AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf},
|
||||
net::TcpStream,
|
||||
};
|
||||
use tokio_rustls::rustls;
|
||||
|
||||
use crate::{
|
||||
error::StreamError,
|
||||
feature::Feature,
|
||||
negotiator::{self, Step},
|
||||
negotiator::{self, Step, TcpConnOrTLS},
|
||||
tag::{self, Tag},
|
||||
tls::stream,
|
||||
};
|
||||
|
||||
pub(crate) enum StartTLSResult {
|
||||
Success(stream::TLSStream),
|
||||
Failure(StreamStart, StreamError),
|
||||
/// TLSFailure is returned when the TLS negotiation failed, and the TCP
|
||||
/// stream is to be dropped.
|
||||
TLSFailure(StreamError),
|
||||
}
|
||||
|
||||
type Result<T> = std::result::Result<T, StreamError>;
|
||||
|
||||
const FEATURES: &'static [Feature] = &[Feature::start_tls(true)];
|
||||
|
||||
struct StreamStart<'a> {
|
||||
reader: Reader<BufReader<ReadHalf<'a>>>,
|
||||
writer: Writer<WriteHalf<'a>>,
|
||||
pub struct StreamStart {
|
||||
reader: Reader<BufReader<ReadHalf<TcpStream>>>,
|
||||
writer: Writer<WriteHalf<TcpStream>>,
|
||||
buffer: Vec<u8>,
|
||||
hostname: String,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
}
|
||||
|
||||
impl<'a> StreamStart<'a> {
|
||||
fn new(stream: &'a mut TcpStream, hostname: String) -> Self {
|
||||
let (read, write) = stream.split();
|
||||
impl StreamStart {
|
||||
fn new(tcp_stream: TcpStream, hostname: String, tls_config: Arc<rustls::ServerConfig>) -> Self {
|
||||
let (read, write) = tokio::io::split(tcp_stream);
|
||||
let (reader, writer) = (
|
||||
Reader::from_reader(BufReader::new(read)),
|
||||
Writer::new(write),
|
||||
|
@ -43,85 +51,118 @@ impl<'a> StreamStart<'a> {
|
|||
reader,
|
||||
writer,
|
||||
hostname,
|
||||
tls_config,
|
||||
buffer: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
async fn start_stream(mut self) {
|
||||
async fn start_stream(self) {
|
||||
match self.negotiate_stream().await {
|
||||
Ok(_) => {}
|
||||
Err(err) => {
|
||||
if let Err(err2) = error(self.writer.get_mut(), err).await {
|
||||
StartTLSResult::Success(tls_stream) => tls_stream.start_stream().await,
|
||||
StartTLSResult::Failure(mut conn, err) => {
|
||||
if let Err(err2) = error(conn.writer.get_mut(), err).await {
|
||||
error!("error writing error: {err2}");
|
||||
return;
|
||||
} else {
|
||||
info!("wrote error {err}")
|
||||
}
|
||||
|
||||
if let Err(e) = self.writer.get_mut().write_all(b"</stream:stream>").await {
|
||||
if let Err(e) = conn.writer.get_mut().write_all(b"</stream:stream>").await {
|
||||
error!("writing end to stream: {e}")
|
||||
}
|
||||
if let Err(e) = self.writer.get_mut().shutdown().await {
|
||||
if let Err(e) = conn.writer.get_mut().shutdown().await {
|
||||
error!("shutting down stream: {e}")
|
||||
}
|
||||
}
|
||||
StartTLSResult::TLSFailure(err) => {
|
||||
error!("TLS negotiation failure: {err}. Dropping connection.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn negotiate_stream(&mut self) -> Result<()> {
|
||||
async fn negotiate_stream(mut self) -> StartTLSResult {
|
||||
let attrs = loop {
|
||||
match self.reader.read_event_into_async(&mut self.buffer).await? {
|
||||
let event = match self.reader.read_event_into_async(&mut self.buffer).await {
|
||||
Ok(event) => event,
|
||||
Err(err) => return StartTLSResult::Failure(self, err.into()),
|
||||
};
|
||||
match event {
|
||||
Event::Start(start) => {
|
||||
if start.name().as_ref() == tag::STREAM_ELEMENT_NAME {
|
||||
let attrs: StreamAttrs = start.attributes().try_into()?;
|
||||
let attrs: StreamAttrs = match start.attributes().try_into() {
|
||||
Ok(a) => a,
|
||||
Err(err) => return StartTLSResult::Failure(self, err.into()),
|
||||
};
|
||||
if attrs.namespace != XMLNamespace::JabberClient {
|
||||
return Err(StreamError::InvalidNamespace);
|
||||
return StartTLSResult::Failure(self, StreamError::InvalidNamespace);
|
||||
}
|
||||
break attrs;
|
||||
} else {
|
||||
info!("element: {:?}", start);
|
||||
}
|
||||
}
|
||||
Event::End(_) => return Err(StreamError::BadFormat),
|
||||
Event::Eof => return Err(StreamError::BadFormat),
|
||||
Event::End(_) => return StartTLSResult::Failure(self, StreamError::BadFormat),
|
||||
Event::Eof => return StartTLSResult::Failure(self, StreamError::BadFormat),
|
||||
_ => continue,
|
||||
}
|
||||
};
|
||||
info!("starting negotiation with: {attrs:?}");
|
||||
self.write_stream_header(StreamAttrs {
|
||||
from: attrs.to.clone(),
|
||||
to: attrs.from,
|
||||
namespace: XMLNamespace::JabberClient,
|
||||
})
|
||||
.await?;
|
||||
if let Err(err) = self
|
||||
.write_stream_header(StreamAttrs {
|
||||
from: attrs.to.clone(),
|
||||
to: attrs.from,
|
||||
namespace: XMLNamespace::JabberClient,
|
||||
})
|
||||
.await
|
||||
{
|
||||
return StartTLSResult::Failure(self, err);
|
||||
};
|
||||
if attrs.to != self.hostname {
|
||||
return Err(StreamError::HostUnknown);
|
||||
return StartTLSResult::Failure(self, StreamError::HostUnknown);
|
||||
}
|
||||
self.send_features().await?;
|
||||
if let Err(err) = self.send_features().await {
|
||||
return StartTLSResult::Failure(self, err);
|
||||
};
|
||||
loop {
|
||||
match self.reader.read_event_into_async(&mut self.buffer).await? {
|
||||
let event = match self.reader.read_event_into_async(&mut self.buffer).await {
|
||||
Ok(event) => event,
|
||||
Err(err) => return StartTLSResult::Failure(self, err.into()),
|
||||
};
|
||||
match event {
|
||||
Event::Empty(empty) => match empty.name().as_ref() {
|
||||
tag::STARTTLS => {
|
||||
info!("starttls negotiation");
|
||||
if let Step::Failure = negotiator::start_tls(
|
||||
self.reader.get_mut(),
|
||||
self.writer.get_mut(),
|
||||
let hostname = self.hostname;
|
||||
let tls_config = self.tls_config.clone();
|
||||
return match negotiator::start_tls(
|
||||
self.reader
|
||||
.into_inner()
|
||||
.into_inner()
|
||||
.unsplit(self.writer.into_inner()),
|
||||
empty,
|
||||
self.tls_config.clone(),
|
||||
)
|
||||
.await?
|
||||
.await
|
||||
{
|
||||
return Step::Failure.write_tag(self.writer.get_mut()).await;
|
||||
TcpConnOrTLS::TLS(conn) => {
|
||||
StartTLSResult::Success(stream::TLSStream::new(conn, hostname))
|
||||
}
|
||||
TcpConnOrTLS::TCP(conn, err) => StartTLSResult::Failure(
|
||||
StreamStart::new(conn, hostname, tls_config),
|
||||
err,
|
||||
),
|
||||
TcpConnOrTLS::Failed(err) => StartTLSResult::TLSFailure(err),
|
||||
};
|
||||
}
|
||||
_ => return Err(StreamError::UnsupportedFeature),
|
||||
_ => return StartTLSResult::Failure(self, StreamError::UnsupportedFeature),
|
||||
},
|
||||
Event::End(_) => return Err(StreamError::BadFormat),
|
||||
Event::Eof => return Err(StreamError::BadFormat),
|
||||
Event::End(_) => return StartTLSResult::Failure(self, StreamError::BadFormat),
|
||||
Event::Eof => return StartTLSResult::Failure(self, StreamError::BadFormat),
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
|
||||
Err(StreamError::InternalServerError)
|
||||
StartTLSResult::Failure(self, StreamError::InternalServerError)
|
||||
}
|
||||
async fn write_stream_header(&mut self, req: StreamAttrs) -> Result<()> {
|
||||
self.writer
|
||||
|
@ -156,9 +197,15 @@ impl<'a> StreamStart<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn spawn(hostname: String, (mut stream, _): (TcpStream, SocketAddr)) {
|
||||
pub fn spawn(
|
||||
hostname: String,
|
||||
(mut stream, _): (TcpStream, SocketAddr),
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
) {
|
||||
tokio::spawn(async move {
|
||||
StreamStart::new(&mut stream, hostname).start_stream().await;
|
||||
StreamStart::new(stream, hostname, tls_config)
|
||||
.start_stream()
|
||||
.await;
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
@ -1 +1 @@
|
|||
|
||||
pub mod stream;
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
use quick_xml::{Reader, Writer};
|
||||
use tokio::{
|
||||
io::{BufReader, ReadHalf, WriteHalf},
|
||||
net::TcpStream,
|
||||
};
|
||||
use tokio_rustls::server::TlsStream;
|
||||
|
||||
pub struct TLSStream {
|
||||
reader: Reader<BufReader<ReadHalf<TlsStream<TcpStream>>>>,
|
||||
writer: Writer<WriteHalf<TlsStream<TcpStream>>>,
|
||||
buffer: Vec<u8>,
|
||||
hostname: String,
|
||||
}
|
||||
|
||||
impl TLSStream {
|
||||
pub fn new(conn: TlsStream<TcpStream>, hostname: String) -> Self {
|
||||
let (read, write) = tokio::io::split(conn);
|
||||
let (reader, writer) = (
|
||||
Reader::from_reader(BufReader::new(read)),
|
||||
Writer::new(write),
|
||||
);
|
||||
|
||||
Self {
|
||||
reader,
|
||||
writer,
|
||||
hostname,
|
||||
buffer: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start_stream(self) {}
|
||||
}
|
Loading…
Reference in New Issue