157 lines
4.8 KiB
Rust
157 lines
4.8 KiB
Rust
|
|
use axum::{
|
||
|
|
body::Bytes,
|
||
|
|
extract::{FromRequest, Request, rejection::BytesRejection},
|
||
|
|
http::{HeaderMap, HeaderValue, StatusCode, header},
|
||
|
|
response::{IntoResponse, Response},
|
||
|
|
};
|
||
|
|
use axum_extra::headers::Mime;
|
||
|
|
use bytes::{BufMut, BytesMut};
|
||
|
|
use core::fmt::Display;
|
||
|
|
use serde::{Serialize, de::DeserializeOwned};
|
||
|
|
|
||
|
|
/// Media type sent on successful CBOR responses.
const CBOR_CONTENT_TYPE: &str = "application/cbor";

/// Media type for plain-text error bodies.
///
/// The bodies are Rust `String`s (always UTF-8) and may echo text taken from decoder
/// messages. Declare the charset explicitly, because `text/plain` without one defaults
/// to US-ASCII.
const PLAIN_CONTENT_TYPE: &str = "text/plain; charset=utf-8";
|
||
|
|
|
||
|
|
/// CBOR extractor and response wrapper.
///
/// As an extractor, it requires a CBOR `Content-Type` on the request and decodes the
/// body into `T`. As a response, it encodes `T` as CBOR with an `application/cbor`
/// content type.
#[must_use]
#[derive(Debug, Clone, Copy, Default)]
pub struct Cbor<T>(pub T);

impl<T> Cbor<T> {
    /// Wraps `t` in a [`Cbor`].
    pub const fn new(t: T) -> Self {
        Self(t)
    }
}

impl<T> From<T> for Cbor<T> {
    /// Wraps `inner` in a [`Cbor`]; equivalent to [`Cbor::new`].
    fn from(inner: T) -> Self {
        Self(inner)
    }
}
|
||
|
|
|
||
|
|
impl<T, S> FromRequest<S> for Cbor<T>
|
||
|
|
where
|
||
|
|
T: DeserializeOwned,
|
||
|
|
S: Send + Sync,
|
||
|
|
{
|
||
|
|
type Rejection = CborRejection;
|
||
|
|
|
||
|
|
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
|
||
|
|
if !cbor_content_type(req.headers()) {
|
||
|
|
return Err(CborRejection::MissingCborContentType);
|
||
|
|
}
|
||
|
|
|
||
|
|
let bytes = Bytes::from_request(req, state).await?;
|
||
|
|
Ok(Self(ciborium::from_reader::<T, _>(&*bytes)?))
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
impl<T> IntoResponse for Cbor<T>
|
||
|
|
where
|
||
|
|
T: Serialize,
|
||
|
|
{
|
||
|
|
fn into_response(self) -> axum::response::Response {
|
||
|
|
// Extracted into separate fn so it's only compiled once for all T.
|
||
|
|
fn make_response(buf: BytesMut, ser_result: Result<(), CborRejection>) -> Response {
|
||
|
|
match ser_result {
|
||
|
|
Ok(()) => (
|
||
|
|
[(
|
||
|
|
header::CONTENT_TYPE,
|
||
|
|
HeaderValue::from_static(CBOR_CONTENT_TYPE),
|
||
|
|
)],
|
||
|
|
buf.freeze(),
|
||
|
|
)
|
||
|
|
.into_response(),
|
||
|
|
Err(err) => err.into_response(),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Use a small initial capacity of 128 bytes like serde_json::to_vec
|
||
|
|
// https://docs.rs/serde_json/1.0.82/src/serde_json/ser.rs.html#2189
|
||
|
|
let mut buf = BytesMut::with_capacity(128).writer();
|
||
|
|
let res = ciborium::into_writer(&self.0, &mut buf)
|
||
|
|
.map_err(|err| CborRejection::SerdeRejection(err.to_string()));
|
||
|
|
make_response(buf.into_inner(), res)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug)]
|
||
|
|
pub enum CborRejection {
|
||
|
|
MissingCborContentType,
|
||
|
|
BytesRejection(BytesRejection),
|
||
|
|
DeserializeRejection(String),
|
||
|
|
SerdeRejection(String),
|
||
|
|
}
|
||
|
|
impl<T: Display> From<ciborium::de::Error<T>> for CborRejection {
|
||
|
|
fn from(value: ciborium::de::Error<T>) -> Self {
|
||
|
|
Self::SerdeRejection(match value {
|
||
|
|
ciborium::de::Error::Io(err) => format!("i/o: {err}"),
|
||
|
|
ciborium::de::Error::Syntax(offset) => format!("syntax error at {offset}"),
|
||
|
|
ciborium::de::Error::Semantic(offset, err) => format!(
|
||
|
|
"semantic parse: {err}{}",
|
||
|
|
offset
|
||
|
|
.map(|offset| format!(" at {offset}"))
|
||
|
|
.unwrap_or_default(),
|
||
|
|
),
|
||
|
|
ciborium::de::Error::RecursionLimitExceeded => {
|
||
|
|
String::from("the input caused serde to recurse too much")
|
||
|
|
}
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
impl From<BytesRejection> for CborRejection {
|
||
|
|
fn from(value: BytesRejection) -> Self {
|
||
|
|
Self::BytesRejection(value)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
impl IntoResponse for CborRejection {
|
||
|
|
fn into_response(self) -> axum::response::Response {
|
||
|
|
match self {
|
||
|
|
CborRejection::MissingCborContentType => (
|
||
|
|
StatusCode::BAD_REQUEST,
|
||
|
|
[(
|
||
|
|
header::CONTENT_TYPE,
|
||
|
|
HeaderValue::from_static(PLAIN_CONTENT_TYPE),
|
||
|
|
)],
|
||
|
|
String::from("missing cbor content type"),
|
||
|
|
),
|
||
|
|
CborRejection::BytesRejection(err) => (
|
||
|
|
err.status(),
|
||
|
|
[(
|
||
|
|
header::CONTENT_TYPE,
|
||
|
|
HeaderValue::from_static(PLAIN_CONTENT_TYPE),
|
||
|
|
)],
|
||
|
|
format!("bytes rejection: {}", err.body_text()),
|
||
|
|
),
|
||
|
|
CborRejection::SerdeRejection(err) => (
|
||
|
|
StatusCode::BAD_REQUEST,
|
||
|
|
[(
|
||
|
|
header::CONTENT_TYPE,
|
||
|
|
HeaderValue::from_static(PLAIN_CONTENT_TYPE),
|
||
|
|
)],
|
||
|
|
err,
|
||
|
|
),
|
||
|
|
CborRejection::DeserializeRejection(err) => (
|
||
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||
|
|
[(
|
||
|
|
header::CONTENT_TYPE,
|
||
|
|
HeaderValue::from_static(PLAIN_CONTENT_TYPE),
|
||
|
|
)],
|
||
|
|
err,
|
||
|
|
),
|
||
|
|
}
|
||
|
|
.into_response()
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
fn cbor_content_type(headers: &HeaderMap) -> bool {
|
||
|
|
let Some(content_type) = headers.get(header::CONTENT_TYPE) else {
|
||
|
|
return false;
|
||
|
|
};
|
||
|
|
|
||
|
|
let Ok(content_type) = content_type.to_str() else {
|
||
|
|
return false;
|
||
|
|
};
|
||
|
|
|
||
|
|
let Ok(mime) = content_type.parse::<Mime>() else {
|
||
|
|
return false;
|
||
|
|
};
|
||
|
|
|
||
|
|
mime.type_() == "application"
|
||
|
|
&& (mime.subtype() == "cbor" || mime.suffix().is_some_and(|name| name == "cbor"))
|
||
|
|
}
|