werewolves/werewolves-server/src/host.rs

159 lines
5.0 KiB
Rust

// Copyright (C) 2025 Emilis Bliūdžius
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
use core::net::SocketAddr;
use axum::{
extract::{
ConnectInfo, State, WebSocketUpgrade,
ws::{self, Message, WebSocket},
},
response::IntoResponse,
};
use axum_extra::{TypedHeader, headers};
use colored::Colorize;
use tokio::sync::{broadcast::Receiver, mpsc::Sender};
use werewolves_proto::message::host::{HostMessage, ServerToHostMessage};
use crate::{AppState, LogError, XForwardedFor};
/// Axum handler for the host websocket endpoint.
///
/// Logs who connected and then upgrades the connection. "Who" is the
/// `X-Forwarded-For` value when that header is present, otherwise the peer
/// socket address. The `User-Agent` is appended to the log line when sent.
/// Once the upgrade completes, a [`Host`] session is run on the socket.
pub async fn handler(
    ws: WebSocketUpgrade,
    user_agent: Option<TypedHeader<headers::UserAgent>>,
    x_forwarded_for: Option<TypedHeader<XForwardedFor>>,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    State(state): State<AppState>,
) -> impl IntoResponse {
    // A proxy-supplied client address takes precedence over the direct peer.
    let who = x_forwarded_for.map_or_else(|| addr.to_string(), |header| header.to_string());
    let agent_suffix = match user_agent {
        Some(agent) => format!(" (User-Agent: {})", agent.as_str()),
        None => String::new(),
    };
    log::info!("{who}{} connected.", agent_suffix);

    // The channel handles are taken inside the callback, so the broadcast
    // subscription is created only once the upgrade has actually happened.
    ws.on_upgrade(move |socket| async move {
        let host_send = state.host_send.clone();
        let server_recv = state.host_recv.resubscribe();
        Host::new(socket, host_send, server_recv).run().await
    })
}
/// One connected host websocket session.
///
/// Relays messages in both directions between the host client's websocket
/// and the server-side channels.
struct Host {
    /// The upgraded websocket connection to the host client.
    socket: WebSocket,
    /// Channel on which messages decoded from the host client are forwarded.
    host_send: Sender<HostMessage>,
    /// Broadcast receiver for server messages that are relayed to the host client.
    server_recv: Receiver<ServerToHostMessage>,
}
impl Host {
pub fn new(
socket: WebSocket,
host_send: Sender<HostMessage>,
server_recv: Receiver<ServerToHostMessage>,
) -> Self {
Self {
host_send,
server_recv,
socket,
}
}
async fn on_recv(
&mut self,
msg: Option<Result<Message, axum::Error>>,
) -> Result<(), anyhow::Error> {
#[cfg(not(feature = "cbor"))]
let msg: HostMessage = serde_json::from_slice(
&match msg {
Some(Ok(msg)) => msg,
Some(Err(err)) => return Err(err.into()),
None => {
log::warn!("[host] no message");
return Ok(());
}
}
.into_data(),
)?;
#[cfg(feature = "cbor")]
let msg: HostMessage = {
let bytes = match msg {
Some(Ok(msg)) => msg.into_data(),
Some(Err(err)) => return Err(err.into()),
None => {
log::warn!("[host] no message");
return Ok(());
}
};
let slice: &[u8] = &bytes;
ciborium::from_reader(slice)?
};
log::debug!(
"{} {}",
"[host::incoming::message]".bold(),
format!("{msg:?}").dimmed()
);
Ok(self.host_send.send(msg).await?)
}
async fn send_message(&mut self, msg: &ServerToHostMessage) -> Result<(), anyhow::Error> {
Ok(self
.socket
.send(
#[cfg(not(feature = "cbor"))]
ws::Message::Text(serde_json::to_string(msg)?.into()),
#[cfg(feature = "cbor")]
ws::Message::Binary({
let mut bytes = Vec::new();
ciborium::into_writer(msg, &mut bytes)?;
bytes.into()
}),
)
.await?)
}
pub async fn run(mut self) {
loop {
tokio::select! {
msg = self.socket.recv() => {
if let Err(err) = self.on_recv(msg).await {
log::error!("{} {err}", "[host::incoming]".bold());
return;
}
},
msg = self.server_recv.recv() => {
match msg {
Ok(msg) => {
if let Err(err) = self.send_message(&msg).await {
log::error!("{} {err}", "[host::outgoing]".bold())
}
},
Err(err) => {
log::error!("{} {err}", "[host::mpsc]".bold());
return;
}
}
}
};
}
}
}