150 lines
4.5 KiB
Rust
150 lines
4.5 KiB
Rust
use core::net::SocketAddr;
use core::ops::ControlFlow;

use axum::{
    extract::{
        ConnectInfo, State, WebSocketUpgrade,
        ws::{self, Message, WebSocket},
    },
    response::IntoResponse,
};
use axum_extra::{TypedHeader, headers};
use colored::Colorize;
use tokio::sync::{broadcast::Receiver, mpsc::Sender};
use werewolves_proto::message::host::{HostMessage, ServerToHostMessage};

use crate::{AppState, LogError, XForwardedFor};
|
|
|
|
pub async fn handler(
|
|
ws: WebSocketUpgrade,
|
|
user_agent: Option<TypedHeader<headers::UserAgent>>,
|
|
x_forwarded_for: Option<TypedHeader<XForwardedFor>>,
|
|
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
|
State(state): State<AppState>,
|
|
) -> impl IntoResponse {
|
|
let who = x_forwarded_for
|
|
.map(|x| x.to_string())
|
|
.unwrap_or_else(|| addr.to_string());
|
|
log::info!(
|
|
"{who}{} connected.",
|
|
user_agent
|
|
.map(|agent| format!(" (User-Agent: {})", agent.as_str()))
|
|
.unwrap_or_default(),
|
|
);
|
|
|
|
// finalize the upgrade process by returning upgrade callback.
|
|
// we can customize the callback by sending additional info such as address.
|
|
ws.on_upgrade(move |socket| async move {
|
|
Host::new(
|
|
socket,
|
|
state.host_send.clone(),
|
|
state.host_recv.resubscribe(),
|
|
)
|
|
.run()
|
|
.await
|
|
})
|
|
}
|
|
|
|
struct Host {
|
|
socket: WebSocket,
|
|
host_send: Sender<HostMessage>,
|
|
server_recv: Receiver<ServerToHostMessage>,
|
|
}
|
|
|
|
impl Host {
|
|
pub fn new(
|
|
socket: WebSocket,
|
|
host_send: Sender<HostMessage>,
|
|
server_recv: Receiver<ServerToHostMessage>,
|
|
) -> Self {
|
|
Self {
|
|
host_send,
|
|
server_recv,
|
|
socket,
|
|
}
|
|
}
|
|
|
|
async fn on_recv(
|
|
&mut self,
|
|
msg: Option<Result<Message, axum::Error>>,
|
|
) -> Result<(), anyhow::Error> {
|
|
#[cfg(not(feature = "cbor"))]
|
|
let msg: HostMessage = serde_json::from_slice(
|
|
&match msg {
|
|
Some(Ok(msg)) => msg,
|
|
Some(Err(err)) => return Err(err.into()),
|
|
None => {
|
|
log::warn!("[host] no message");
|
|
return Ok(());
|
|
}
|
|
}
|
|
.into_data(),
|
|
)?;
|
|
#[cfg(feature = "cbor")]
|
|
let msg: HostMessage = {
|
|
let bytes = match msg {
|
|
Some(Ok(msg)) => msg.into_data(),
|
|
Some(Err(err)) => return Err(err.into()),
|
|
None => {
|
|
log::warn!("[host] no message");
|
|
return Ok(());
|
|
}
|
|
};
|
|
let slice: &[u8] = &bytes;
|
|
ciborium::from_reader(slice)?
|
|
};
|
|
if let HostMessage::Echo(echo) = &msg {
|
|
self.send_message(echo).await.log_warn();
|
|
return Ok(());
|
|
}
|
|
log::debug!(
|
|
"{} {}",
|
|
"[host::incoming::message]".bold(),
|
|
format!("{msg:?}").dimmed()
|
|
);
|
|
Ok(self.host_send.send(msg).await?)
|
|
}
|
|
|
|
async fn send_message(&mut self, msg: &ServerToHostMessage) -> Result<(), anyhow::Error> {
|
|
Ok(self
|
|
.socket
|
|
.send(
|
|
#[cfg(not(feature = "cbor"))]
|
|
ws::Message::Text(serde_json::to_string(msg)?.into()),
|
|
#[cfg(feature = "cbor")]
|
|
ws::Message::Binary({
|
|
let mut bytes = Vec::new();
|
|
ciborium::into_writer(msg, &mut bytes)?;
|
|
bytes.into()
|
|
}),
|
|
)
|
|
.await?)
|
|
}
|
|
|
|
pub async fn run(mut self) {
|
|
loop {
|
|
tokio::select! {
|
|
msg = self.socket.recv() => {
|
|
if let Err(err) = self.on_recv(msg).await {
|
|
log::error!("{} {err}", "[host::incoming]".bold());
|
|
return;
|
|
}
|
|
},
|
|
msg = self.server_recv.recv() => {
|
|
match msg {
|
|
Ok(msg) => {
|
|
log::debug!("sending message to host: {}", format!("{msg:?}").dimmed());
|
|
if let Err(err) = self.send_message(&msg).await {
|
|
log::error!("{} {err}", "[host::outgoing]".bold())
|
|
}
|
|
},
|
|
Err(err) => {
|
|
log::error!("{} {err}", "[host::mpsc]".bold());
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
};
|
|
}
|
|
}
|
|
}
|