use core::net::SocketAddr; use axum::{ extract::{ ConnectInfo, State, WebSocketUpgrade, ws::{self, Message, WebSocket}, }, response::IntoResponse, }; use axum_extra::{TypedHeader, headers}; use colored::Colorize; use tokio::sync::{broadcast::Receiver, mpsc::Sender}; use werewolves_proto::message::host::{HostMessage, ServerToHostMessage}; use crate::{AppState, LogError, XForwardedFor}; pub async fn handler( ws: WebSocketUpgrade, user_agent: Option>, x_forwarded_for: Option>, ConnectInfo(addr): ConnectInfo, State(state): State, ) -> impl IntoResponse { let who = x_forwarded_for .map(|x| x.to_string()) .unwrap_or_else(|| addr.to_string()); log::info!( "{who}{} connected.", user_agent .map(|agent| format!(" (User-Agent: {})", agent.as_str())) .unwrap_or_default(), ); // finalize the upgrade process by returning upgrade callback. // we can customize the callback by sending additional info such as address. ws.on_upgrade(move |socket| async move { Host::new( socket, state.host_send.clone(), state.host_recv.resubscribe(), ) .run() .await }) } struct Host { socket: WebSocket, host_send: Sender, server_recv: Receiver, } impl Host { pub fn new( socket: WebSocket, host_send: Sender, server_recv: Receiver, ) -> Self { Self { host_send, server_recv, socket, } } async fn on_recv( &mut self, msg: Option>, ) -> Result<(), anyhow::Error> { #[cfg(not(feature = "cbor"))] let msg: HostMessage = serde_json::from_slice( &match msg { Some(Ok(msg)) => msg, Some(Err(err)) => return Err(err.into()), None => { log::warn!("[host] no message"); return Ok(()); } } .into_data(), )?; #[cfg(feature = "cbor")] let msg: HostMessage = { let bytes = match msg { Some(Ok(msg)) => msg.into_data(), Some(Err(err)) => return Err(err.into()), None => { log::warn!("[host] no message"); return Ok(()); } }; let slice: &[u8] = &bytes; ciborium::from_reader(slice)? }; if let HostMessage::Echo(echo) = &msg { self.send_message(echo).await.log_warn(); return Ok(()); } log::debug!( "{} {}", "[host::incoming::message]".bold(), format!("{msg:?}").dimmed() ); Ok(self.host_send.send(msg).await?) } async fn send_message(&mut self, msg: &ServerToHostMessage) -> Result<(), anyhow::Error> { Ok(self .socket .send( #[cfg(not(feature = "cbor"))] ws::Message::Text(serde_json::to_string(msg)?.into()), #[cfg(feature = "cbor")] ws::Message::Binary({ let mut bytes = Vec::new(); ciborium::into_writer(msg, &mut bytes)?; bytes.into() }), ) .await?) } pub async fn run(mut self) { loop { tokio::select! { msg = self.socket.recv() => { if let Err(err) = self.on_recv(msg).await { log::error!("{} {err}", "[host::incoming]".bold()); return; } }, msg = self.server_recv.recv() => { match msg { Ok(msg) => { log::debug!("sending message to host: {}", format!("{msg:?}").dimmed()); if let Err(err) = self.send_message(&msg).await { log::error!("{} {err}", "[host::outgoing]".bold()) } }, Err(err) => { log::error!("{} {err}", "[host::mpsc]".bold()); return; } } } }; } } }