werewolves/werewolves-server/src/main.rs

268 lines
7.6 KiB
Rust

mod client;
mod communication;
mod connection;
mod game;
mod host;
mod lobby;
mod runner;
mod saver;
use axum::{
Router,
http::{Request, header},
response::IntoResponse,
routing::{any, get},
};
use axum_extra::headers;
use communication::lobby::LobbyComms;
use connection::JoinedPlayers;
use core::{fmt::Display, net::SocketAddr, str::FromStr};
use log::Record;
use runner::IdentifiedClientMessage;
use std::{env, io::Write, path::Path};
use tokio::{
sync::{broadcast, mpsc},
time::Instant,
};
use crate::{
communication::{Comms, host::HostComms, player::PlayerIdComms},
saver::FileSaver,
};
const DEFAULT_PORT: u16 = 8080;
const DEFAULT_HOST: &str = "127.0.0.1";
const DEFAULT_SAVE_DIR: &str = "werewolves-saves/";
#[tokio::main]
async fn main() {
// pretty_env_logger::init();
use colored::Colorize;
pretty_env_logger::formatted_builder()
.parse_default_env()
.format(|f, record| {
let time = chrono::Local::now().time().to_string().dimmed();
match record.file() {
Some(file) => {
let file = format!(
"[{file}{}]",
record
.line()
.map(|l| format!(":{l}"))
.unwrap_or_else(String::new),
)
.dimmed();
let level = match record.level() {
log::Level::Error => "[err]".red().bold(),
log::Level::Warn => "[warn]".yellow().bold(),
log::Level::Info => "[info]".white().bold(),
log::Level::Debug => "[debug]".dimmed().bold(),
log::Level::Trace => "[trace]".dimmed(),
};
let args = record.args();
let arrow = "".bold().magenta();
writeln!(
f,
"{time} {file}\n{level} {arrow} {args}",
// "⇗⇘⇗⇘⇗⇘".bold().dimmed(),
)
}
_ => writeln!(f, "{time} [{}] {}", record.level(), record.args()),
}
})
.try_init()
.unwrap();
let default_panic = std::panic::take_hook();
std::panic::set_hook(Box::new(move |info| {
default_panic(info);
std::process::exit(1);
}));
let host = env::var("HOST").unwrap_or(DEFAULT_HOST.to_string());
let port = env::var("PORT")
.map_err(|err| anyhow::anyhow!("{err}"))
.map(|port_str| {
port_str
.parse::<u16>()
.unwrap_or_else(|err| panic!("parse PORT={port_str} failed: {err}"))
})
.unwrap_or(DEFAULT_PORT);
let listen_addr =
SocketAddr::from_str(format!("{host}:{port}").as_str()).expect("invalid host/port");
let (send, recv) = broadcast::channel(100);
let (server_send, host_recv) = broadcast::channel(100);
let (host_send, server_recv) = mpsc::channel(100);
let (connect_send, connect_recv) = broadcast::channel(100);
let joined_players = JoinedPlayers::new(connect_send);
let lobby_comms = LobbyComms::new(
Comms::new(
HostComms::new(server_send, server_recv),
PlayerIdComms::new(joined_players.clone(), recv, connect_recv.resubscribe()),
),
connect_recv,
);
let jp_clone = joined_players.clone();
let path = Path::new(option_env!("SAVE_PATH").unwrap_or(DEFAULT_SAVE_DIR))
.canonicalize()
.expect("canonicalizing path");
if let Err(err) = std::fs::create_dir(&path)
&& !matches!(err.kind(), std::io::ErrorKind::AlreadyExists)
{
panic!("creating save dir at [{path:?}]: {err}")
}
// Check if we can write to the path
{
let test_file_path = path.join(".test");
if let Err(err) = std::fs::File::create(&test_file_path) {
panic!("can't create files in {path:?}: {err}")
}
std::fs::remove_file(&test_file_path).log_err();
}
let saver = FileSaver::new(path);
tokio::spawn(async move {
crate::runner::run_game(jp_clone, lobby_comms, saver).await;
panic!("game over");
});
let state = AppState {
joined_players,
host_recv,
host_send,
send,
};
let app = Router::new()
.route("/connect/client", any(client::handler))
.route("/connect/host", any(host::handler))
.with_state(state)
.fallback(get(handle_http_static));
let listener = tokio::net::TcpListener::bind(listen_addr).await.unwrap();
log::info!("listening on {}", listener.local_addr().unwrap());
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.await
.unwrap();
}
/// Shared state handed to the `/connect/*` handlers through the router.
///
/// `Clone` is implemented by hand because `broadcast::Receiver` cannot be cloned; each clone
/// gets its own receiver via `resubscribe` instead.
struct AppState {
    joined_players: JoinedPlayers,
    send: broadcast::Sender<IdentifiedClientMessage>,
    host_send: tokio::sync::mpsc::Sender<werewolves_proto::message::host::HostMessage>,
    host_recv: broadcast::Receiver<werewolves_proto::message::host::ServerToHostMessage>,
}

impl Clone for AppState {
    fn clone(&self) -> Self {
        let Self {
            joined_players,
            send,
            host_send,
            host_recv,
        } = self;
        Self {
            joined_players: joined_players.clone(),
            send: send.clone(),
            host_send: host_send.clone(),
            // A resubscribed receiver starts at the channel's current tail, so it only sees
            // messages sent after this clone was made.
            host_recv: host_recv.resubscribe(),
        }
    }
}
/// Fallback handler serving the frontend bundled from `werewolves/dist`.
///
/// A request whose path exactly matches a bundled file gets that file. Every other path gets
/// `index.html` as `text/html`, so client-side routes still load the app.
///
/// Content type: well-known extensions are mapped explicitly, because browsers insist on exact
/// types for these (JS, CSS, wasm streaming, SVG in `<img>`). Everything else is sniffed from
/// the bytes, falling back to `application/octet-stream`.
async fn handle_http_static(req: Request<axum::body::Body>) -> impl IntoResponse {
    use mime_sniffer::MimeTypeSniffer;
    const INDEX_FILE: &[u8] = include_bytes!("../../werewolves/dist/index.html");
    // Byte sniffing never yields `image/svg+xml` (and may misreport JS/CSS/JSON), so these are
    // decided by extension.
    const MIME_BY_EXTENSION: &[(&str, &str)] = &[
        (".js", "text/javascript"),
        (".css", "text/css"),
        (".wasm", "application/wasm"),
        (".svg", "image/svg+xml"),
        (".json", "application/json"),
    ];
    werewolves_macros::include_dist!(DIST_FILES, "werewolves/dist");
    let path = req.uri().path();
    let Some(file) = DIST_FILES
        .iter()
        .find_map(|(file_path, file)| (*file_path == path).then_some(*file))
    else {
        return (
            [(header::CONTENT_TYPE, "text/html".to_string())],
            INDEX_FILE,
        );
    };
    let mime = MIME_BY_EXTENSION
        .iter()
        .find(|(ext, _)| path.ends_with(ext))
        .map(|(_, mime)| *mime)
        .or_else(|| file.sniff_mime_type())
        .unwrap_or("application/octet-stream")
        .to_string();
    ([(header::CONTENT_TYPE, mime)], file)
}
struct XForwardedFor(String);
impl Display for XForwardedFor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.0.as_str())
}
}
impl headers::Header for XForwardedFor {
fn name() -> &'static header::HeaderName {
static NAME: header::HeaderName = header::HeaderName::from_static("x-forwarded-for");
&NAME
}
fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
where
Self: Sized,
I: Iterator<Item = &'i header::HeaderValue>,
{
Ok(Self(
values
.next()
.and_then(|v| v.to_str().ok())
.ok_or(headers::Error::invalid())?
.to_string(),
))
}
fn encode<E: Extend<header::HeaderValue>>(&self, _: &mut E) {}
}
/// Consumes a `Result`, logging its error (if any) at a chosen level and discarding it.
///
/// Handy for best-effort calls, such as removing a temporary file, whose failure should be
/// visible in the logs but not abort the caller.
pub trait LogError {
    /// Logs the error, if there is one, at `warn` level.
    fn log_warn(self);
    /// Logs the error, if there is one, at `error` level.
    fn log_err(self);
    /// Logs the error, if there is one, at `debug` level.
    fn log_debug(self);
}

impl<T, E: Display> LogError for Result<T, E> {
    fn log_warn(self) {
        if let Some(err) = self.err() {
            log::warn!("{err}");
        }
    }

    fn log_err(self) {
        if let Some(err) = self.err() {
            log::error!("{err}");
        }
    }

    fn log_debug(self) {
        if let Some(err) = self.err() {
            log::debug!("{err}");
        }
    }
}