werewolves/werewolves-server/src/host.rs

150 lines
4.5 KiB
Rust

use core::net::SocketAddr;
use axum::{
extract::{
ConnectInfo, State, WebSocketUpgrade,
ws::{self, Message, WebSocket},
},
response::IntoResponse,
};
use axum_extra::{TypedHeader, headers};
use colored::Colorize;
use tokio::sync::{broadcast::Receiver, mpsc::Sender};
use werewolves_proto::message::host::{HostMessage, ServerToHostMessage};
use crate::{AppState, LogError, XForwardedFor};
/// Axum handler for the host websocket endpoint.
///
/// Logs who connected: the `X-Forwarded-For` value when present, otherwise the socket address,
/// plus the `User-Agent` when one was sent. It then upgrades the connection and runs a [`Host`]
/// session wired to the host channels held in [`AppState`].
pub async fn handler(
    ws: WebSocketUpgrade,
    user_agent: Option<TypedHeader<headers::UserAgent>>,
    x_forwarded_for: Option<TypedHeader<XForwardedFor>>,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    State(state): State<AppState>,
) -> impl IntoResponse {
    // Behind a reverse proxy the socket address is the proxy's, so prefer the forwarded one.
    let peer = match x_forwarded_for {
        Some(forwarded) => forwarded.to_string(),
        None => addr.to_string(),
    };
    log::info!(
        "{peer}{} connected.",
        match user_agent {
            Some(agent) => format!(" (User-Agent: {})", agent.as_str()),
            None => String::new(),
        },
    );
    // The session (and its broadcast subscription) is only created once the upgrade completes.
    ws.on_upgrade(move |socket| async move {
        let host_send = state.host_send.clone();
        let server_recv = state.host_recv.resubscribe();
        Host::new(socket, host_send, server_recv).run().await;
    })
}
struct Host {
socket: WebSocket,
host_send: Sender<HostMessage>,
server_recv: Receiver<ServerToHostMessage>,
}
impl Host {
pub fn new(
socket: WebSocket,
host_send: Sender<HostMessage>,
server_recv: Receiver<ServerToHostMessage>,
) -> Self {
Self {
host_send,
server_recv,
socket,
}
}
async fn on_recv(
&mut self,
msg: Option<Result<Message, axum::Error>>,
) -> Result<(), anyhow::Error> {
#[cfg(not(feature = "cbor"))]
let msg: HostMessage = serde_json::from_slice(
&match msg {
Some(Ok(msg)) => msg,
Some(Err(err)) => return Err(err.into()),
None => {
log::warn!("[host] no message");
return Ok(());
}
}
.into_data(),
)?;
#[cfg(feature = "cbor")]
let msg: HostMessage = {
let bytes = match msg {
Some(Ok(msg)) => msg.into_data(),
Some(Err(err)) => return Err(err.into()),
None => {
log::warn!("[host] no message");
return Ok(());
}
};
let slice: &[u8] = &bytes;
ciborium::from_reader(slice)?
};
if let HostMessage::Echo(echo) = &msg {
self.send_message(echo).await.log_warn();
return Ok(());
}
log::debug!(
"{} {}",
"[host::incoming::message]".bold(),
format!("{msg:?}").dimmed()
);
Ok(self.host_send.send(msg).await?)
}
async fn send_message(&mut self, msg: &ServerToHostMessage) -> Result<(), anyhow::Error> {
Ok(self
.socket
.send(
#[cfg(not(feature = "cbor"))]
ws::Message::Text(serde_json::to_string(msg)?.into()),
#[cfg(feature = "cbor")]
ws::Message::Binary({
let mut bytes = Vec::new();
ciborium::into_writer(msg, &mut bytes)?;
bytes.into()
}),
)
.await?)
}
pub async fn run(mut self) {
loop {
tokio::select! {
msg = self.socket.recv() => {
if let Err(err) = self.on_recv(msg).await {
log::error!("{} {err}", "[host::incoming]".bold());
return;
}
},
msg = self.server_recv.recv() => {
match msg {
Ok(msg) => {
log::debug!("sending message to host: {}", format!("{msg:?}").dimmed());
if let Err(err) = self.send_message(&msg).await {
log::error!("{} {err}", "[host::outgoing]".bold())
}
},
Err(err) => {
log::error!("{} {err}", "[host::mpsc]".bold());
return;
}
}
}
};
}
}
}