werewolves/werewolves-server/src/main.rs

296 lines
8.6 KiB
Rust

mod client;
mod communication;
mod connection;
mod game;
mod host;
mod lobby;
mod runner;
mod saver;
use axum::{
Router,
http::{Request, StatusCode, header},
response::IntoResponse,
routing::{any, get},
};
use axum_extra::headers;
use communication::lobby::LobbyComms;
use connection::JoinedPlayers;
use core::{fmt::Display, net::SocketAddr, str::FromStr};
use fast_qr::convert::{Builder, Shape, svg::SvgBuilder};
use runner::IdentifiedClientMessage;
use std::{env, io::Write, path::Path};
use tokio::sync::{broadcast, mpsc};
use crate::{
communication::{Comms, connect::ConnectUpdate, host::HostComms, player::PlayerIdComms},
saver::FileSaver,
};
const DEFAULT_PORT: u16 = 8080;
const DEFAULT_HOST: &str = "127.0.0.1";
const DEFAULT_SAVE_DIR: &str = "werewolves-saves/";
const DEFAULT_QRCODE_URL: &str = "https://wolf.emilis.dev/";
#[tokio::main]
async fn main() {
// pretty_env_logger::init();
use colored::Colorize;
pretty_env_logger::formatted_builder()
.parse_default_env()
.format(|f, record| {
let time = chrono::Local::now().time().to_string().dimmed();
match record.file() {
Some(file) => {
let file = format!(
"[{file}{}]",
record
.line()
.map(|l| format!(":{l}"))
.unwrap_or_else(String::new),
)
.dimmed();
let level = match record.level() {
log::Level::Error => "[err]".red().bold(),
log::Level::Warn => "[warn]".yellow().bold(),
log::Level::Info => "[info]".white().bold(),
log::Level::Debug => "[debug]".dimmed().bold(),
log::Level::Trace => "[trace]".dimmed(),
};
let args = record.args();
let arrow = "".bold().magenta();
writeln!(
f,
"{time} {file}\n{level} {arrow} {args}",
// "⇗⇘⇗⇘⇗⇘".bold().dimmed(),
)
}
_ => writeln!(f, "{time} [{}] {}", record.level(), record.args()),
}
})
.try_init()
.unwrap();
let default_panic = std::panic::take_hook();
std::panic::set_hook(Box::new(move |info| {
default_panic(info);
std::process::exit(1);
}));
let host = env::var("HOST").unwrap_or(DEFAULT_HOST.to_string());
let port = env::var("PORT")
.map_err(|err| anyhow::anyhow!("{err}"))
.map(|port_str| {
port_str
.parse::<u16>()
.unwrap_or_else(|err| panic!("parse PORT={port_str} failed: {err}"))
})
.unwrap_or(DEFAULT_PORT);
let listen_addr =
SocketAddr::from_str(format!("{host}:{port}").as_str()).expect("invalid host/port");
let (send, recv) = broadcast::channel(100);
let (server_send, host_recv) = broadcast::channel(100);
let (host_send, server_recv) = mpsc::channel(100);
let conn_update = ConnectUpdate::new();
let joined_players = JoinedPlayers::new(conn_update.clone());
let lobby_comms = LobbyComms::new(
Comms::new(
HostComms::new(server_send, server_recv),
PlayerIdComms::new(recv),
),
conn_update,
);
let jp_clone = joined_players.clone();
let path = Path::new(option_env!("SAVE_PATH").unwrap_or(DEFAULT_SAVE_DIR));
if let Err(err) = std::fs::create_dir(path)
&& !matches!(err.kind(), std::io::ErrorKind::AlreadyExists)
{
panic!("creating save dir at [{path:?}]: {err}")
}
// Check if we can write to the path
{
let test_file_path = path.join(".test");
if let Err(err) = std::fs::File::create(&test_file_path) {
panic!("can't create files in {path:?}: {err}")
}
std::fs::remove_file(&test_file_path).log_err();
}
let saver = FileSaver::new(path.canonicalize().expect("canonicalizing path"));
tokio::spawn(async move {
crate::runner::run_game(jp_clone, lobby_comms, saver).await;
panic!("game over");
});
let state = AppState {
joined_players,
host_recv,
host_send,
send,
};
let app = Router::new()
.route("/connect/client", any(client::handler))
.route("/connect/host", any(host::handler))
.route("/qrcode", get(handle_qr_code))
.with_state(state)
.fallback(get(handle_http_static));
let listener = tokio::net::TcpListener::bind(listen_addr).await.unwrap();
log::info!("listening on {}", listener.local_addr().unwrap());
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.await
.unwrap();
}
struct AppState {
joined_players: JoinedPlayers,
send: broadcast::Sender<IdentifiedClientMessage>,
host_send: tokio::sync::mpsc::Sender<werewolves_proto::message::host::HostMessage>,
host_recv: broadcast::Receiver<werewolves_proto::message::host::ServerToHostMessage>,
}
/// Manual `Clone` for [`AppState`].
///
/// `broadcast::Receiver` does not implement `Clone`, so the host receiver is
/// duplicated with `resubscribe`. The new receiver only observes messages sent
/// after the clone is made. All other fields are cloned normally.
impl Clone for AppState {
    fn clone(&self) -> Self {
        // Destructure so that adding a field forces this impl to be revisited.
        let Self {
            joined_players,
            send,
            host_send,
            host_recv,
        } = self;
        Self {
            joined_players: joined_players.clone(),
            send: send.clone(),
            host_send: host_send.clone(),
            host_recv: host_recv.resubscribe(),
        }
    }
}
/// Handler for `/qrcode`: returns an SVG QR code for the join URL.
///
/// The URL is the compile-time `QRCODE_URL` environment variable, or
/// [`DEFAULT_QRCODE_URL`] when it was unset. Because the URL is constant, the
/// SVG is rendered once on first use and served as a `'static` byte slice for
/// every later request. If encoding fails, each request logs the error and
/// gets a `500` with an empty body.
async fn handle_qr_code() -> impl IntoResponse {
    const QRCODE: &str = const {
        match option_env!("QRCODE_URL") {
            Some(qrcode) => qrcode,
            None => DEFAULT_QRCODE_URL,
        }
    };
    const EMPTY: &[u8] = &[];
    // Rendered lazily once; the error is stored as text so it can be logged per request.
    static QR_SVG: std::sync::LazyLock<Result<String, String>> = std::sync::LazyLock::new(|| {
        fast_qr::QRBuilder::new(QRCODE)
            .build()
            .map(|qr| SvgBuilder::default().shape(Shape::Square).to_str(&qr))
            .map_err(|err| err.to_string())
    });
    match &*QR_SVG {
        Ok(svg) => (
            StatusCode::OK,
            [(header::CONTENT_TYPE, "image/svg+xml")],
            svg.as_bytes(),
        ),
        Err(err) => {
            log::error!("generating qr code from [{QRCODE}]: {err}");
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                [(header::CONTENT_TYPE, "application/octet-stream")],
                EMPTY,
            )
        }
    }
}
/// Fallback handler that serves the frontend bundle embedded in the binary.
///
/// A request path equal to an entry of `DIST_FILES` (embedded by
/// `werewolves_macros::include_dist!` from `werewolves/dist`) is served with a
/// content type chosen from its extension (`js`, `css`, `wasm`, `svg`).
/// Files with any other extension have their content type sniffed from the
/// bytes, falling back to `application/octet-stream`. Every other path gets
/// `index.html` as `text/html`, presumably so client-side routes resolve
/// (confirm).
async fn handle_http_static(req: Request<axum::body::Body>) -> impl IntoResponse {
    use mime_sniffer::MimeTypeSniffer;
    const INDEX_FILE: &[u8] = include_bytes!("../../werewolves/dist/index.html");
    let path = req.uri().path();
    werewolves_macros::include_dist!(DIST_FILES, "werewolves/dist");

    let Some(file) = DIST_FILES
        .iter()
        .find_map(|(file_path, file)| (*file_path == path).then_some(*file))
    else {
        return (
            [(header::CONTENT_TYPE, "text/html".to_string())],
            INDEX_FILE,
        );
    };

    // Text after the last '.', which matches exactly the paths `ends_with(".<ext>")` accepts.
    let extension = path.rsplit_once('.').map(|(_, ext)| ext);
    let mime = match extension {
        Some("js") => "text/javascript".to_string(),
        Some("css") => "text/css".to_string(),
        Some("wasm") => "application/wasm".to_string(),
        Some("svg") => "image/svg+xml".to_string(),
        _ => file
            .sniff_mime_type()
            .unwrap_or("application/octet-stream")
            .to_string(),
    };
    ([(header::CONTENT_TYPE, mime)], file)
}
/// The raw text of an `X-Forwarded-For` request header.
///
/// The value is stored verbatim, so a multi-hop value keeps its commas.
struct XForwardedFor(String);

impl Display for XForwardedFor {
    /// Writes the stored header text unchanged.
    ///
    /// `write_str` is used, so width/fill/alignment flags are ignored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(raw) = self;
        f.write_str(raw)
    }
}
impl headers::Header for XForwardedFor {
fn name() -> &'static header::HeaderName {
static NAME: header::HeaderName = header::HeaderName::from_static("x-forwarded-for");
&NAME
}
fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
where
Self: Sized,
I: Iterator<Item = &'i header::HeaderValue>,
{
Ok(Self(
values
.next()
.and_then(|v| v.to_str().ok())
.ok_or(headers::Error::invalid())?
.to_string(),
))
}
fn encode<E: Extend<header::HeaderValue>>(&self, _: &mut E) {}
}
/// Fire-and-forget error reporting.
///
/// Each method consumes the value. The blanket impl for `Result<T, E: Display>`
/// logs the `Err` payload's `Display` text at the named level and drops `Ok`
/// values silently.
pub trait LogError {
    /// Consume `self`, logging any error at `error` level.
    fn log_err(self);
    /// Consume `self`, logging any error at `warn` level.
    fn log_warn(self);
    /// Consume `self`, logging any error at `debug` level.
    fn log_debug(self);
}
/// Blanket [`LogError`] for any `Result` whose error type is `Display`.
///
/// `Ok` values are discarded without logging. An `Err` is logged with its
/// `Display` text and nothing else.
impl<T, E: Display> LogError for Result<T, E> {
    fn log_warn(self) {
        let Err(err) = self else { return };
        log::warn!("{err}");
    }

    fn log_err(self) {
        let Err(err) = self else { return };
        log::error!("{err}");
    }

    fn log_debug(self) {
        let Err(err) = self else { return };
        log::debug!("{err}");
    }
}