werewolves/werewolves-proto/src/cbor.rs

157 lines
4.8 KiB
Rust
Raw Normal View History

2025-11-04 22:25:50 +00:00
use axum::{
body::Bytes,
extract::{FromRequest, Request, rejection::BytesRejection},
http::{HeaderMap, HeaderValue, StatusCode, header},
response::{IntoResponse, Response},
};
use axum_extra::headers::Mime;
use bytes::{BufMut, BytesMut};
use core::fmt::Display;
use serde::{Serialize, de::DeserializeOwned};
/// Media type set on successful CBOR-encoded responses.
const CBOR_CONTENT_TYPE: &str = "application/cbor";
/// Media type set on the plain-text bodies of error responses.
const PLAIN_CONTENT_TYPE: &str = "text/plain";
/// Axum extractor and response type for CBOR-encoded bodies.
///
/// As an extractor it requires a CBOR `Content-Type` header and decodes the
/// request body into `T`. As a response it encodes `T` with ciborium and labels
/// the body `application/cbor`.
#[must_use]
pub struct Cbor<T>(pub T);

impl<T> Cbor<T> {
    /// Wraps `t` in a [`Cbor`]; usable in `const` contexts.
    pub const fn new(t: T) -> Self {
        Cbor(t)
    }
}
impl<T, S> FromRequest<S> for Cbor<T>
where
T: DeserializeOwned,
S: Send + Sync,
{
type Rejection = CborRejection;
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
if !cbor_content_type(req.headers()) {
return Err(CborRejection::MissingCborContentType);
}
let bytes = Bytes::from_request(req, state).await?;
Ok(Self(ciborium::from_reader::<T, _>(&*bytes)?))
}
}
impl<T> IntoResponse for Cbor<T>
where
T: Serialize,
{
fn into_response(self) -> axum::response::Response {
// Extracted into separate fn so it's only compiled once for all T.
fn make_response(buf: BytesMut, ser_result: Result<(), CborRejection>) -> Response {
match ser_result {
Ok(()) => (
[(
header::CONTENT_TYPE,
HeaderValue::from_static(CBOR_CONTENT_TYPE),
)],
buf.freeze(),
)
.into_response(),
Err(err) => err.into_response(),
}
}
// Use a small initial capacity of 128 bytes like serde_json::to_vec
// https://docs.rs/serde_json/1.0.82/src/serde_json/ser.rs.html#2189
let mut buf = BytesMut::with_capacity(128).writer();
let res = ciborium::into_writer(&self.0, &mut buf)
.map_err(|err| CborRejection::SerdeRejection(err.to_string()));
make_response(buf.into_inner(), res)
}
}
/// Reasons a [`Cbor`] extraction or response can fail.
///
/// Each variant is turned into a plain-text error response by this type's
/// `IntoResponse` impl.
#[derive(Debug)]
pub enum CborRejection {
    /// The request's `Content-Type` header was absent, not a valid header
    /// string, unparseable as a MIME type, or not a CBOR type. CBOR types are
    /// `application/cbor` or an `application/*+cbor` suffix type.
    MissingCborContentType,
    /// Reading the request body into `Bytes` failed; wraps axum's rejection.
    BytesRejection(BytesRejection),
    /// Free-form error message.
    /// NOTE(review): this variant is not constructed anywhere in this file.
    /// Despite the name, it is rendered as 500 Internal Server Error. Confirm
    /// the intended use.
    DeserializeRejection(String),
    /// Human-readable serde/ciborium error message, rendered as 400 Bad Request.
    SerdeRejection(String),
}
impl<T: Display> From<ciborium::de::Error<T>> for CborRejection {
fn from(value: ciborium::de::Error<T>) -> Self {
Self::SerdeRejection(match value {
ciborium::de::Error::Io(err) => format!("i/o: {err}"),
ciborium::de::Error::Syntax(offset) => format!("syntax error at {offset}"),
ciborium::de::Error::Semantic(offset, err) => format!(
"semantic parse: {err}{}",
offset
.map(|offset| format!(" at {offset}"))
.unwrap_or_default(),
),
ciborium::de::Error::RecursionLimitExceeded => {
String::from("the input caused serde to recurse too much")
}
})
}
}
impl From<BytesRejection> for CborRejection {
fn from(value: BytesRejection) -> Self {
Self::BytesRejection(value)
}
}
impl IntoResponse for CborRejection {
    /// Renders the rejection as a `text/plain` error response.
    ///
    /// Status codes:
    /// - `MissingCborContentType` → 415 Unsupported Media Type
    /// - `BytesRejection` → whatever status axum's body rejection carries
    /// - `SerdeRejection` → 400 Bad Request
    /// - `DeserializeRejection` → 500 Internal Server Error
    fn into_response(self) -> axum::response::Response {
        let (status, body) = match self {
            // RFC 9110 §15.5.16: the request content is in a media type this
            // endpoint does not accept. Axum's `Json` extractor uses 415 for the
            // equivalent case.
            CborRejection::MissingCborContentType => (
                StatusCode::UNSUPPORTED_MEDIA_TYPE,
                String::from("missing cbor content type"),
            ),
            CborRejection::BytesRejection(err) => {
                (err.status(), format!("bytes rejection: {}", err.body_text()))
            }
            CborRejection::SerdeRejection(err) => (StatusCode::BAD_REQUEST, err),
            CborRejection::DeserializeRejection(err) => (StatusCode::INTERNAL_SERVER_ERROR, err),
        };
        (
            status,
            [(
                header::CONTENT_TYPE,
                HeaderValue::from_static(PLAIN_CONTENT_TYPE),
            )],
            body,
        )
            .into_response()
    }
}
/// Returns `true` when the `Content-Type` header names a CBOR media type.
///
/// Accepted types are `application/cbor` and `application/*+cbor` structured
/// suffix types. A missing header, a non-string header value, or an
/// unparseable MIME type all yield `false`.
fn cbor_content_type(headers: &HeaderMap) -> bool {
    headers
        .get(header::CONTENT_TYPE)
        .and_then(|raw| raw.to_str().ok())
        .and_then(|text| text.parse::<Mime>().ok())
        .is_some_and(|mime| {
            let is_cbor_subtype = mime.subtype() == "cbor"
                || mime.suffix().is_some_and(|suffix| suffix == "cbor");
            mime.type_() == "application" && is_cbor_subtype
        })
}