use axum::{ body::Bytes, extract::{FromRequest, Request, rejection::BytesRejection}, http::{HeaderMap, HeaderValue, StatusCode, header}, response::{IntoResponse, Response}, }; use axum_extra::headers::Mime; use bytes::{BufMut, BytesMut}; use core::fmt::Display; use serde::{Serialize, de::DeserializeOwned}; const CBOR_CONTENT_TYPE: &str = "application/cbor"; const PLAIN_CONTENT_TYPE: &str = "text/plain"; #[must_use] pub struct Cbor(pub T); impl Cbor { pub const fn new(t: T) -> Self { Self(t) } } impl FromRequest for Cbor where T: DeserializeOwned, S: Send + Sync, { type Rejection = CborRejection; async fn from_request(req: Request, state: &S) -> Result { if !cbor_content_type(req.headers()) { return Err(CborRejection::MissingCborContentType); } let bytes = Bytes::from_request(req, state).await?; Ok(Self(ciborium::from_reader::(&*bytes)?)) } } impl IntoResponse for Cbor where T: Serialize, { fn into_response(self) -> axum::response::Response { // Extracted into separate fn so it's only compiled once for all T. fn make_response(buf: BytesMut, ser_result: Result<(), CborRejection>) -> Response { match ser_result { Ok(()) => ( [( header::CONTENT_TYPE, HeaderValue::from_static(CBOR_CONTENT_TYPE), )], buf.freeze(), ) .into_response(), Err(err) => err.into_response(), } } // Use a small initial capacity of 128 bytes like serde_json::to_vec // https://docs.rs/serde_json/1.0.82/src/serde_json/ser.rs.html#2189 let mut buf = BytesMut::with_capacity(128).writer(); let res = ciborium::into_writer(&self.0, &mut buf) .map_err(|err| CborRejection::SerdeRejection(err.to_string())); make_response(buf.into_inner(), res) } } #[derive(Debug)] pub enum CborRejection { MissingCborContentType, BytesRejection(BytesRejection), DeserializeRejection(String), SerdeRejection(String), } impl From> for CborRejection { fn from(value: ciborium::de::Error) -> Self { Self::SerdeRejection(match value { ciborium::de::Error::Io(err) => format!("i/o: {err}"), ciborium::de::Error::Syntax(offset) => format!("syntax error at {offset}"), ciborium::de::Error::Semantic(offset, err) => format!( "semantic parse: {err}{}", offset .map(|offset| format!(" at {offset}")) .unwrap_or_default(), ), ciborium::de::Error::RecursionLimitExceeded => { String::from("the input caused serde to recurse too much") } }) } } impl From for CborRejection { fn from(value: BytesRejection) -> Self { Self::BytesRejection(value) } } impl IntoResponse for CborRejection { fn into_response(self) -> axum::response::Response { match self { CborRejection::MissingCborContentType => ( StatusCode::BAD_REQUEST, [( header::CONTENT_TYPE, HeaderValue::from_static(PLAIN_CONTENT_TYPE), )], String::from("missing cbor content type"), ), CborRejection::BytesRejection(err) => ( err.status(), [( header::CONTENT_TYPE, HeaderValue::from_static(PLAIN_CONTENT_TYPE), )], format!("bytes rejection: {}", err.body_text()), ), CborRejection::SerdeRejection(err) => ( StatusCode::BAD_REQUEST, [( header::CONTENT_TYPE, HeaderValue::from_static(PLAIN_CONTENT_TYPE), )], err, ), CborRejection::DeserializeRejection(err) => ( StatusCode::INTERNAL_SERVER_ERROR, [( header::CONTENT_TYPE, HeaderValue::from_static(PLAIN_CONTENT_TYPE), )], err, ), } .into_response() } } fn cbor_content_type(headers: &HeaderMap) -> bool { let Some(content_type) = headers.get(header::CONTENT_TYPE) else { return false; }; let Ok(content_type) = content_type.to_str() else { return false; }; let Ok(mime) = content_type.parse::() else { return false; }; mime.type_() == "application" && (mime.subtype() == "cbor" || mime.suffix().is_some_and(|name| name == "cbor")) }