wip persistence
This commit is contained in:
parent
15a6454ae2
commit
e91a019872
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,57 @@
|
|||
drop table if exists users cascade;
|
||||
create table users (
|
||||
id uuid not null default gen_random_uuid() primary key,
|
||||
name text,
|
||||
username text not null,
|
||||
password_hash text not null,
|
||||
|
||||
created_at timestamp with time zone not null,
|
||||
updated_at timestamp with time zone not null,
|
||||
|
||||
check (created_at <= updated_at)
|
||||
);
|
||||
drop index if exists users_username_idx;
|
||||
create index users_username_idx on users (username);
|
||||
drop index if exists users_username_unique;
|
||||
create unique index users_username_unique on users (lower(username));
|
||||
|
||||
drop table if exists login_tokens cascade;
|
||||
create table login_tokens (
|
||||
token text not null primary key,
|
||||
user_id uuid not null references users(id),
|
||||
created_at timestamp with time zone not null,
|
||||
expires_at timestamp with time zone not null,
|
||||
|
||||
check (created_at < expires_at)
|
||||
);
|
||||
|
||||
drop type if exists game_outcome cascade;
|
||||
create type game_outcome as enum (
|
||||
'village_victory',
|
||||
'wolves_victory'
|
||||
);
|
||||
|
||||
drop table if exists games cascade;
|
||||
create table games (
|
||||
id uuid not null primary key,
|
||||
outcome game_outcome,
|
||||
state json not null,
|
||||
story json not null,
|
||||
|
||||
started_at timestamp with time zone not null,
|
||||
updated_at timestamp with time zone not null default now()
|
||||
);
|
||||
|
||||
drop table if exists players;
|
||||
create table players (
|
||||
id uuid not null primary key,
|
||||
user_id uuid references users(id)
|
||||
);
|
||||
|
||||
drop table if exists game_players;
|
||||
create table game_players (
|
||||
game_id uuid not null references games(id),
|
||||
player_id uuid not null references players(id),
|
||||
|
||||
primary key (game_id, player_id)
|
||||
);
|
||||
|
|
@ -11,8 +11,25 @@ serde = { version = "1.0", features = ["derive"] }
|
|||
uuid = { version = "1.17", features = ["v4", "serde"] }
|
||||
rand = { version = "0.9" }
|
||||
werewolves-macros = { path = "../werewolves-macros" }
|
||||
axum = { version = "*", optional = true }
|
||||
argon2 = { version = "*", optional = true }
|
||||
sqlx = { version = "*", optional = true }
|
||||
ciborium = { version = "*", optional = true }
|
||||
bytes = { version = "1.10.1", features = ["serde"], optional = true }
|
||||
axum-extra = { version = "*", optional = true }
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
|
||||
[dev-dependencies]
|
||||
pretty_assertions = { version = "1" }
|
||||
pretty_env_logger = { version = "0.5" }
|
||||
colored = { version = "3.0" }
|
||||
|
||||
[features]
|
||||
server = [
|
||||
"dep:axum",
|
||||
"dep:sqlx",
|
||||
"dep:argon2",
|
||||
"dep:ciborium",
|
||||
"dep:bytes",
|
||||
"dep:axum-extra",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,156 @@
|
|||
use axum::{
|
||||
body::Bytes,
|
||||
extract::{FromRequest, Request, rejection::BytesRejection},
|
||||
http::{HeaderMap, HeaderValue, StatusCode, header},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use axum_extra::headers::Mime;
|
||||
use bytes::{BufMut, BytesMut};
|
||||
use core::fmt::Display;
|
||||
use serde::{Serialize, de::DeserializeOwned};
|
||||
|
||||
const CBOR_CONTENT_TYPE: &str = "application/cbor";
|
||||
const PLAIN_CONTENT_TYPE: &str = "text/plain";
|
||||
|
||||
#[must_use]
|
||||
pub struct Cbor<T>(pub T);
|
||||
|
||||
impl<T> Cbor<T> {
|
||||
pub const fn new(t: T) -> Self {
|
||||
Self(t)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, S> FromRequest<S> for Cbor<T>
|
||||
where
|
||||
T: DeserializeOwned,
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = CborRejection;
|
||||
|
||||
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
|
||||
if !cbor_content_type(req.headers()) {
|
||||
return Err(CborRejection::MissingCborContentType);
|
||||
}
|
||||
|
||||
let bytes = Bytes::from_request(req, state).await?;
|
||||
Ok(Self(ciborium::from_reader::<T, _>(&*bytes)?))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> IntoResponse for Cbor<T>
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
// Extracted into separate fn so it's only compiled once for all T.
|
||||
fn make_response(buf: BytesMut, ser_result: Result<(), CborRejection>) -> Response {
|
||||
match ser_result {
|
||||
Ok(()) => (
|
||||
[(
|
||||
header::CONTENT_TYPE,
|
||||
HeaderValue::from_static(CBOR_CONTENT_TYPE),
|
||||
)],
|
||||
buf.freeze(),
|
||||
)
|
||||
.into_response(),
|
||||
Err(err) => err.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
// Use a small initial capacity of 128 bytes like serde_json::to_vec
|
||||
// https://docs.rs/serde_json/1.0.82/src/serde_json/ser.rs.html#2189
|
||||
let mut buf = BytesMut::with_capacity(128).writer();
|
||||
let res = ciborium::into_writer(&self.0, &mut buf)
|
||||
.map_err(|err| CborRejection::SerdeRejection(err.to_string()));
|
||||
make_response(buf.into_inner(), res)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum CborRejection {
|
||||
MissingCborContentType,
|
||||
BytesRejection(BytesRejection),
|
||||
DeserializeRejection(String),
|
||||
SerdeRejection(String),
|
||||
}
|
||||
impl<T: Display> From<ciborium::de::Error<T>> for CborRejection {
|
||||
fn from(value: ciborium::de::Error<T>) -> Self {
|
||||
Self::SerdeRejection(match value {
|
||||
ciborium::de::Error::Io(err) => format!("i/o: {err}"),
|
||||
ciborium::de::Error::Syntax(offset) => format!("syntax error at {offset}"),
|
||||
ciborium::de::Error::Semantic(offset, err) => format!(
|
||||
"semantic parse: {err}{}",
|
||||
offset
|
||||
.map(|offset| format!(" at {offset}"))
|
||||
.unwrap_or_default(),
|
||||
),
|
||||
ciborium::de::Error::RecursionLimitExceeded => {
|
||||
String::from("the input caused serde to recurse too much")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl From<BytesRejection> for CborRejection {
|
||||
fn from(value: BytesRejection) -> Self {
|
||||
Self::BytesRejection(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for CborRejection {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
match self {
|
||||
CborRejection::MissingCborContentType => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
[(
|
||||
header::CONTENT_TYPE,
|
||||
HeaderValue::from_static(PLAIN_CONTENT_TYPE),
|
||||
)],
|
||||
String::from("missing cbor content type"),
|
||||
),
|
||||
CborRejection::BytesRejection(err) => (
|
||||
err.status(),
|
||||
[(
|
||||
header::CONTENT_TYPE,
|
||||
HeaderValue::from_static(PLAIN_CONTENT_TYPE),
|
||||
)],
|
||||
format!("bytes rejection: {}", err.body_text()),
|
||||
),
|
||||
CborRejection::SerdeRejection(err) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
[(
|
||||
header::CONTENT_TYPE,
|
||||
HeaderValue::from_static(PLAIN_CONTENT_TYPE),
|
||||
)],
|
||||
err,
|
||||
),
|
||||
CborRejection::DeserializeRejection(err) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
[(
|
||||
header::CONTENT_TYPE,
|
||||
HeaderValue::from_static(PLAIN_CONTENT_TYPE),
|
||||
)],
|
||||
err,
|
||||
),
|
||||
}
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
fn cbor_content_type(headers: &HeaderMap) -> bool {
|
||||
let Some(content_type) = headers.get(header::CONTENT_TYPE) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
let Ok(content_type) = content_type.to_str() else {
|
||||
return false;
|
||||
};
|
||||
|
||||
let Ok(mime) = content_type.parse::<Mime>() else {
|
||||
return false;
|
||||
};
|
||||
|
||||
mime.type_() == "application"
|
||||
&& (mime.subtype() == "cbor" || mime.suffix().is_some_and(|name| name == "cbor"))
|
||||
}
|
||||
|
|
@ -1,3 +1,6 @@
|
|||
#[cfg(feature = "server")]
|
||||
use core::fmt::Display;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
|
|
@ -81,4 +84,157 @@ pub enum GameError {
|
|||
MissingTime(GameTime),
|
||||
#[error("no previous during day")]
|
||||
NoPreviousDuringDay,
|
||||
#[error("server error: {0}")]
|
||||
ServerError(#[from] ServerError),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Error, Serialize, Deserialize)]
|
||||
pub enum DatabaseError {
|
||||
#[error("user already exists")]
|
||||
UserAlreadyExists,
|
||||
#[error("password hashing error: {0}")]
|
||||
PasswordHashError(String),
|
||||
#[error("sqlx error: {0}")]
|
||||
SqlxError(String),
|
||||
#[error("not found")]
|
||||
NotFound,
|
||||
#[error("serde_json: {0}")]
|
||||
SerdeJson(String),
|
||||
}
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
impl From<serde_json::Error> for DatabaseError {
|
||||
fn from(value: serde_json::Error) -> Self {
|
||||
Self::SerdeJson(value.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
impl axum::response::IntoResponse for DatabaseError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
use axum::http::StatusCode;
|
||||
|
||||
use crate::cbor::Cbor;
|
||||
|
||||
(
|
||||
match self {
|
||||
DatabaseError::UserAlreadyExists => StatusCode::BAD_REQUEST,
|
||||
DatabaseError::NotFound => StatusCode::NOT_FOUND,
|
||||
_ => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
},
|
||||
Cbor(self),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
impl From<sqlx::Error> for DatabaseError {
|
||||
fn from(err: sqlx::Error) -> Self {
|
||||
match err {
|
||||
sqlx::Error::RowNotFound => Self::NotFound,
|
||||
_ => Self::SqlxError(err.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
impl From<argon2::password_hash::Error> for DatabaseError {
|
||||
fn from(err: argon2::password_hash::Error) -> Self {
|
||||
Self::PasswordHashError(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Error, Serialize, Deserialize)]
|
||||
pub enum ServerError {
|
||||
#[error("database error: {0}")]
|
||||
DatabaseError(DatabaseError),
|
||||
#[error("invalid credentials")]
|
||||
InvalidCredentials,
|
||||
#[error("token expired")]
|
||||
ExpiredToken,
|
||||
#[error("internal server error: {0}")]
|
||||
InternalServerError(String),
|
||||
#[error("connection error")]
|
||||
ConnectionError,
|
||||
#[error("invalid request: {0}")]
|
||||
InvalidRequest(String),
|
||||
#[error("not found")]
|
||||
NotFound,
|
||||
}
|
||||
|
||||
impl<I: Into<DatabaseError>> From<I> for ServerError {
|
||||
fn from(value: I) -> Self {
|
||||
let database_err: DatabaseError = value.into();
|
||||
if let DatabaseError::NotFound = &database_err {
|
||||
return Self::NotFound;
|
||||
}
|
||||
Self::DatabaseError(database_err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<GameError> for ServerError {
|
||||
fn from(value: GameError) -> Self {
|
||||
Self::InvalidRequest(value.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
impl<T: Display> From<ciborium::de::Error<T>> for ServerError {
|
||||
fn from(_: ciborium::de::Error<T>) -> Self {
|
||||
Self::InvalidRequest(String::from("could not decode request"))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
impl axum::response::IntoResponse for ServerError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
use axum::http::StatusCode;
|
||||
|
||||
use crate::cbor::Cbor;
|
||||
|
||||
match self {
|
||||
ServerError::ExpiredToken => {
|
||||
(StatusCode::UNAUTHORIZED, Cbor(ServerError::ExpiredToken)).into_response()
|
||||
}
|
||||
ServerError::NotFound | ServerError::DatabaseError(DatabaseError::NotFound) => {
|
||||
(StatusCode::NOT_FOUND, Cbor(ServerError::NotFound)).into_response()
|
||||
}
|
||||
ServerError::DatabaseError(DatabaseError::UserAlreadyExists) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Cbor(ServerError::InvalidRequest(String::from("username taken"))),
|
||||
)
|
||||
.into_response(),
|
||||
ServerError::DatabaseError(err) => {
|
||||
use uuid::Uuid;
|
||||
|
||||
let error_id = Uuid::new_v4();
|
||||
log::error!("database error[{error_id}]: {err}");
|
||||
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Cbor(ServerError::InternalServerError(format!(
|
||||
"internal server error. error id: {error_id}"
|
||||
))),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
ServerError::InvalidCredentials => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Cbor(ServerError::InvalidCredentials),
|
||||
)
|
||||
.into_response(),
|
||||
ServerError::InternalServerError(_) => {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Cbor(self)).into_response()
|
||||
}
|
||||
ServerError::ConnectionError => {
|
||||
(StatusCode::BAD_REQUEST, Cbor(ServerError::ConnectionError)).into_response()
|
||||
}
|
||||
ServerError::InvalidRequest(reason) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Cbor(ServerError::InvalidRequest(reason)),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ use core::{
|
|||
ops::{Deref, Range, RangeBounds},
|
||||
};
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use rand::{Rng, seq::SliceRandom};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
|
|
@ -20,6 +21,7 @@ use crate::{
|
|||
night::{Night, ServerAction},
|
||||
story::{DayDetail, GameActions, GameStory, NightDetails},
|
||||
},
|
||||
id::GameId,
|
||||
message::{
|
||||
CharacterState, Identification,
|
||||
host::{HostDayMessage, HostGameMessage, HostNightMessage, ServerToHostMessage},
|
||||
|
|
@ -36,6 +38,8 @@ type Result<T> = core::result::Result<T, GameError>;
|
|||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Game {
|
||||
id: GameId,
|
||||
started_at: DateTime<Utc>,
|
||||
history: GameStory,
|
||||
state: GameState,
|
||||
}
|
||||
|
|
@ -44,12 +48,36 @@ impl Game {
|
|||
pub fn new(players: &[Identification], settings: GameSettings) -> Result<Self> {
|
||||
let village = Village::new(players, settings)?;
|
||||
Ok(Self {
|
||||
id: GameId::new(),
|
||||
started_at: Utc::now(),
|
||||
history: GameStory::new(village.clone()),
|
||||
state: GameState::Night {
|
||||
night: Night::new(village)?,
|
||||
},
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "server")]
|
||||
pub const fn new_from_parts(
|
||||
id: GameId,
|
||||
started_at: DateTime<Utc>,
|
||||
history: GameStory,
|
||||
state: GameState,
|
||||
) -> Self {
|
||||
Self {
|
||||
id,
|
||||
started_at,
|
||||
history,
|
||||
state,
|
||||
}
|
||||
}
|
||||
|
||||
pub const fn game_id(&self) -> GameId {
|
||||
self.id
|
||||
}
|
||||
|
||||
pub const fn started_at(&self) -> DateTime<Utc> {
|
||||
self.started_at
|
||||
}
|
||||
|
||||
pub const fn village(&self) -> &Village {
|
||||
match &self.state {
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
crate::id_impl!(GameId);
|
||||
|
|
@ -1,13 +1,109 @@
|
|||
#![allow(clippy::new_without_default)]
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
pub mod cbor;
|
||||
pub mod character;
|
||||
pub mod diedto;
|
||||
pub mod error;
|
||||
pub mod game;
|
||||
#[cfg(test)]
|
||||
mod game_test;
|
||||
pub mod id;
|
||||
pub mod limited;
|
||||
pub mod message;
|
||||
pub mod modifier;
|
||||
pub mod nonzero;
|
||||
pub mod player;
|
||||
pub mod role;
|
||||
pub mod token;
|
||||
pub mod user;
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! id_impl {
|
||||
($name:ident) => {
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
|
||||
pub struct $name(uuid::Uuid);
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
impl sqlx::TypeInfo for $name {
|
||||
fn is_null(&self) -> bool {
|
||||
self.0 == uuid::Uuid::nil()
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"uuid"
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
impl sqlx::Type<sqlx::Postgres> for $name {
|
||||
fn type_info() -> <sqlx::Postgres as sqlx::Database>::TypeInfo {
|
||||
<uuid::Uuid as sqlx::Type<sqlx::Postgres>>::type_info()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
impl<'q> sqlx::Encode<'q, sqlx::Postgres> for $name {
|
||||
fn encode_by_ref(
|
||||
&self,
|
||||
buf: &mut <sqlx::Postgres as sqlx::Database>::ArgumentBuffer<'q>,
|
||||
) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
|
||||
self.0.encode_by_ref(buf)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
impl<'r> sqlx::Decode<'r, sqlx::Postgres> for $name {
|
||||
fn decode(
|
||||
value: <sqlx::Postgres as sqlx::Database>::ValueRef<'r>,
|
||||
) -> Result<Self, sqlx::error::BoxDynError> {
|
||||
Ok(Self(uuid::Uuid::decode(value)?))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<uuid::Uuid> for $name {
|
||||
fn from(value: uuid::Uuid) -> Self {
|
||||
Self::from_uuid(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<$name> for uuid::Uuid {
|
||||
fn from(value: $name) -> Self {
|
||||
value.into_uuid()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for $name {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl $name {
|
||||
pub fn new() -> Self {
|
||||
Self(uuid::Uuid::new_v4())
|
||||
}
|
||||
|
||||
pub const fn from_uuid(uuid: uuid::Uuid) -> Self {
|
||||
Self(uuid)
|
||||
}
|
||||
|
||||
pub const fn into_uuid(self) -> uuid::Uuid {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl core::fmt::Display for $name {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl core::str::FromStr for $name {
|
||||
type Err = uuid::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
Ok(Self(uuid::Uuid::from_str(s)?))
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,129 @@
|
|||
use core::{
|
||||
fmt::Display,
|
||||
ops::{Deref, RangeInclusive},
|
||||
};
|
||||
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
pub struct FixedLenString<const LEN: usize>(String);
|
||||
|
||||
impl<const LEN: usize> Display for FixedLenString<LEN> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LEN: usize> FixedLenString<LEN> {
|
||||
pub fn new(s: String) -> Option<Self> {
|
||||
(s.chars().take(LEN + 1).count() == LEN).then_some(Self(s))
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LEN: usize> Deref for FixedLenString<LEN> {
|
||||
type Target = String;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de, const LEN: usize> Deserialize<'de> for FixedLenString<LEN> {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
struct ExpectedLen(usize);
|
||||
impl serde::de::Expected for ExpectedLen {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(f, "a string exactly {} characters long", self.0)
|
||||
}
|
||||
}
|
||||
<String as Deserialize>::deserialize(deserializer).and_then(|s| {
|
||||
let char_count = s.chars().take(LEN.saturating_add(1)).count();
|
||||
if char_count != LEN {
|
||||
Err(serde::de::Error::invalid_length(
|
||||
char_count,
|
||||
&ExpectedLen(LEN),
|
||||
))
|
||||
} else {
|
||||
Ok(Self(s))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LEN: usize> Serialize for FixedLenString<LEN> {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serializer.serialize_str(self.0.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
pub struct ClampedString<const MIN: usize, const MAX: usize>(String);
|
||||
|
||||
impl<const MIN: usize, const MAX: usize> ClampedString<MIN, MAX> {
|
||||
pub fn new(s: String) -> Result<Self, RangeInclusive<usize>> {
|
||||
let str_len = s.chars().take(MAX.saturating_add(1)).count();
|
||||
(str_len >= MIN && str_len <= MAX)
|
||||
.then_some(Self(s))
|
||||
.ok_or(MIN..=MAX)
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> String {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<const MIN: usize, const MAX: usize> Display for ClampedString<MIN, MAX> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const MIN: usize, const MAX: usize> Deref for ClampedString<MIN, MAX> {
|
||||
type Target = String;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de, const MIN: usize, const MAX: usize> Deserialize<'de> for ClampedString<MIN, MAX> {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
struct ExpectedLen(usize, usize);
|
||||
impl serde::de::Expected for ExpectedLen {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"a string between {} and {} characters long",
|
||||
self.0, self.1
|
||||
)
|
||||
}
|
||||
}
|
||||
<String as Deserialize>::deserialize(deserializer).and_then(|s| {
|
||||
let char_count = s.chars().take(MAX.saturating_add(1)).count();
|
||||
if char_count < MIN || char_count > MAX {
|
||||
Err(serde::de::Error::invalid_length(
|
||||
char_count,
|
||||
&ExpectedLen(MIN, MAX),
|
||||
))
|
||||
} else {
|
||||
Ok(Self(s))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<const MIN: usize, const MAX: usize> Serialize for ClampedString<MIN, MAX> {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serializer.serialize_str(self.0.as_str())
|
||||
}
|
||||
}
|
||||
|
|
@ -17,6 +17,12 @@ impl PlayerId {
|
|||
pub const fn from_u128(v: u128) -> Self {
|
||||
Self(uuid::Uuid::from_u128(v))
|
||||
}
|
||||
pub const fn from_uuid(v: uuid::Uuid) -> Self {
|
||||
Self(v)
|
||||
}
|
||||
pub const fn into_uuid(self) -> uuid::Uuid {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for PlayerId {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,40 @@
|
|||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{limited::FixedLenString, user::Username};
|
||||
|
||||
pub const TOKEN_LEN: usize = 0x20;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct Token {
|
||||
pub token: FixedLenString<TOKEN_LEN>,
|
||||
pub username: Username,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub expires_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl Token {
|
||||
pub fn login_token(&self) -> TokenLogin {
|
||||
TokenLogin(self.token.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct TokenLogin(pub FixedLenString<TOKEN_LEN>);
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
impl axum_extra::headers::authorization::Credentials for TokenLogin {
|
||||
const SCHEME: &'static str = "Bearer";
|
||||
|
||||
fn decode(value: &axum::http::HeaderValue) -> Option<Self> {
|
||||
value
|
||||
.to_str()
|
||||
.ok()
|
||||
.and_then(|v| FixedLenString::new(v.strip_prefix("Bearer ").unwrap_or(v).to_string()))
|
||||
.map(Self)
|
||||
}
|
||||
|
||||
fn encode(&self) -> axum::http::HeaderValue {
|
||||
axum::http::HeaderValue::from_str(self.0.as_str()).expect("bearer token encode")
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::limited::ClampedString;
|
||||
|
||||
pub type Username = ClampedString<1, 0x40>;
|
||||
pub type Password = ClampedString<6, 0x100>;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct UserLogin {
|
||||
pub username: Username,
|
||||
pub password: Password,
|
||||
}
|
||||
|
||||
crate::id_impl!(UserId);
|
||||
|
|
@ -11,12 +11,12 @@ pretty_env_logger = { version = "0.5" }
|
|||
# env_logger = { version = "0.11" }
|
||||
futures = "0.3.31"
|
||||
anyhow = { version = "1" }
|
||||
werewolves-proto = { path = "../werewolves-proto" }
|
||||
werewolves-proto = { path = "../werewolves-proto", features = ["server"] }
|
||||
werewolves-macros = { path = "../werewolves-macros" }
|
||||
mime-sniffer = { version = "0.1" }
|
||||
chrono = { version = "0.4" }
|
||||
atom_syndication = { version = "0.12" }
|
||||
axum-extra = { version = "0.10", features = ["typed-header"] }
|
||||
axum-extra = { version = "0.12", features = ["typed-header"] }
|
||||
rand = { version = "0.9" }
|
||||
serde_json = { version = "1.0" }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
|
|
@ -24,8 +24,27 @@ thiserror = { version = "2" }
|
|||
ciborium = { version = "0.2", optional = true }
|
||||
colored = { version = "3.0" }
|
||||
fast_qr = { version = "0.13", features = ["svg"] }
|
||||
ron = "0.8"
|
||||
ron = "0.11"
|
||||
bytes = { version = "1.10" }
|
||||
sqlx = { version = "0.8", features = [
|
||||
"runtime-tokio",
|
||||
"postgres",
|
||||
"derive",
|
||||
"macros",
|
||||
"uuid",
|
||||
"chrono",
|
||||
] }
|
||||
argon2 = { version = "0.5" }
|
||||
tower-http = { version = "0.6", features = ["cors"] }
|
||||
tower = { version = "0.5.2", features = [
|
||||
"limit",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"buffer",
|
||||
"timeout",
|
||||
] }
|
||||
|
||||
|
||||
[features]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,150 @@
|
|||
use sqlx::{Pool, Postgres, query};
|
||||
use werewolves_proto::{
|
||||
error::DatabaseError,
|
||||
game::{Game, GameOver, story::GameStory},
|
||||
id::GameId,
|
||||
player::PlayerId,
|
||||
user::UserId,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GameDatabase {
|
||||
pub(super) pool: Pool<Postgres>,
|
||||
}
|
||||
|
||||
impl GameDatabase {
|
||||
pub async fn new_game(&self, game: &Game) -> Result<(), DatabaseError> {
|
||||
let state = serde_json::to_value(game.game_state())?;
|
||||
let story = serde_json::to_value(&game.story())?;
|
||||
|
||||
let mut tx = self.pool.begin().await?;
|
||||
query!(
|
||||
r#" insert into
|
||||
games (id, started_at, state, story)
|
||||
values
|
||||
($1, $2, $3, $4)"#,
|
||||
game.game_id().into_uuid(),
|
||||
game.started_at(),
|
||||
state,
|
||||
story,
|
||||
)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
let player_ids = game
|
||||
.village()
|
||||
.characters()
|
||||
.into_iter()
|
||||
.map(|c| c.player_id().into_uuid())
|
||||
.collect::<Box<[_]>>();
|
||||
let user_ids = query!(
|
||||
r#" select
|
||||
id, user_id
|
||||
from
|
||||
players
|
||||
where
|
||||
id = any($1::uuid[])"#,
|
||||
&*player_ids,
|
||||
)
|
||||
.fetch_all(&mut *tx)
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|r| (PlayerId::from_uuid(r.id), r.user_id.map(UserId::from_uuid)))
|
||||
.unzip::<PlayerId, Option<UserId>, Vec<PlayerId>, Vec<Option<UserId>>>();
|
||||
|
||||
let game_id = game.game_id().into_uuid();
|
||||
let game_ids = (0..player_ids.len()).map(|_| game_id).collect::<Box<[_]>>();
|
||||
|
||||
query!(
|
||||
r#" with
|
||||
game_ids as (select row_number() over(), * from unnest($1::uuid[]) as game_id),
|
||||
player_ids as (select row_number() over(), * from unnest($2::uuid[]) as player_id)
|
||||
insert into
|
||||
game_players
|
||||
select
|
||||
game_ids.game_id, player_ids.player_id
|
||||
from
|
||||
game_ids
|
||||
join
|
||||
player_ids on game_ids.row_number = player_ids.row_number
|
||||
"#,
|
||||
&*game_ids,
|
||||
&*player_ids,
|
||||
)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
tx.commit().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn update_game(&self, game: &Game) -> Result<(), DatabaseError> {
|
||||
let state = serde_json::to_value(game.game_state())?;
|
||||
let story = serde_json::to_value(game.story())?;
|
||||
|
||||
query!(
|
||||
r#" update
|
||||
games
|
||||
set
|
||||
story = $2,
|
||||
state = $3,
|
||||
outcome = $4,
|
||||
updated_at = now()
|
||||
where
|
||||
id = $1"#,
|
||||
game.game_id().into_uuid(),
|
||||
story,
|
||||
state,
|
||||
game.game_over().map(Self::outcome_to_db_outcome) as _
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_active_game(&self) -> Result<Game, DatabaseError> {
|
||||
let game = query!(
|
||||
r#" select
|
||||
id, state, story, started_at
|
||||
from
|
||||
games
|
||||
where
|
||||
outcome is null"#
|
||||
)
|
||||
.fetch_one(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(Game::new_from_parts(
|
||||
GameId::from_uuid(game.id),
|
||||
game.started_at,
|
||||
serde_json::from_value(game.story)?,
|
||||
serde_json::from_value(game.state)?,
|
||||
))
|
||||
}
|
||||
|
||||
pub async fn get_game_story(&self, id: GameId) -> Result<GameStory, DatabaseError> {
|
||||
let game = query!(
|
||||
r#" select
|
||||
story
|
||||
from
|
||||
games
|
||||
where
|
||||
id = $1
|
||||
and
|
||||
outcome is not null"#,
|
||||
id.into_uuid(),
|
||||
)
|
||||
.fetch_one(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(serde_json::from_value(game.story)?)
|
||||
}
|
||||
|
||||
fn outcome_to_db_outcome(outcome: GameOver) -> &'static str {
|
||||
match outcome {
|
||||
GameOver::VillageWins => "village_victory",
|
||||
GameOver::WolvesWin => "wolves_victory",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
pub mod game;
|
||||
pub mod user;
|
||||
use sqlx::{Pool, Postgres};
|
||||
use werewolves_proto::error::DatabaseError;
|
||||
|
||||
use crate::db::{game::GameDatabase, user::UserDatabase};
|
||||
|
||||
type Result<T> = core::result::Result<T, DatabaseError>;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Database {
|
||||
pool: Pool<Postgres>,
|
||||
}
|
||||
|
||||
impl Database {
|
||||
pub const fn new(pool: Pool<Postgres>) -> Self {
|
||||
Self { pool }
|
||||
}
|
||||
|
||||
pub fn user(&self) -> UserDatabase {
|
||||
UserDatabase {
|
||||
pool: self.pool.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn game(&self) -> GameDatabase {
|
||||
GameDatabase {
|
||||
pool: self.pool.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn migrate(&self) {
|
||||
log::info!("running migrations");
|
||||
sqlx::migrate!("../migrations")
|
||||
.run(&self.pool)
|
||||
.await
|
||||
.expect("run migrations");
|
||||
log::info!("migrations done");
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,214 @@
|
|||
use super::Result;
|
||||
use argon2::{
|
||||
Argon2, PasswordHash, PasswordVerifier,
|
||||
password_hash::{PasswordHasher, SaltString, rand_core::OsRng},
|
||||
};
|
||||
use chrono::{TimeDelta, Utc};
|
||||
|
||||
use rand::distr::SampleString;
|
||||
use sqlx::{Decode, Encode, Pool, Postgres, prelude::FromRow, query, query_as};
|
||||
use werewolves_proto::{
|
||||
error::{DatabaseError, ServerError},
|
||||
token,
|
||||
user::UserId,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct UserDatabase {
|
||||
pub(super) pool: Pool<Postgres>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, FromRow)]
|
||||
pub struct LoginToken {
|
||||
pub token: String,
|
||||
pub user_id: UserId,
|
||||
|
||||
pub created_at: chrono::DateTime<Utc>,
|
||||
pub expires_at: chrono::DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl LoginToken {
|
||||
const TOKEN_LONGEVITY: TimeDelta = TimeDelta::days(30);
|
||||
|
||||
pub fn new(user_id: UserId) -> Self {
|
||||
let created_at = Utc::now();
|
||||
let expires_at = created_at
|
||||
.checked_add_signed(Self::TOKEN_LONGEVITY)
|
||||
.unwrap_or_else(|| {
|
||||
panic!(
|
||||
"could not add {} time to {created_at}",
|
||||
Self::TOKEN_LONGEVITY
|
||||
)
|
||||
});
|
||||
|
||||
let token = rand::distr::Alphanumeric.sample_string(&mut rand::rng(), token::TOKEN_LEN);
|
||||
|
||||
Self {
|
||||
token,
|
||||
user_id,
|
||||
created_at,
|
||||
expires_at,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum GetUserBy<'a> {
|
||||
Username(&'a str),
|
||||
Id(UserId),
|
||||
}
|
||||
|
||||
impl UserDatabase {
|
||||
pub async fn create(&self, username: &str, password: &str) -> Result<User> {
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
let argon2 = Argon2::default();
|
||||
let password_hash = argon2
|
||||
.hash_password(password.as_bytes(), &salt)?
|
||||
.to_string();
|
||||
|
||||
let now = chrono::offset::Utc::now();
|
||||
|
||||
let user = User {
|
||||
id: UserId::new(),
|
||||
username: username.into(),
|
||||
password_hash,
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
};
|
||||
|
||||
query!(
|
||||
r#"insert into users
|
||||
(id, username, password_hash, created_at, updated_at)
|
||||
values
|
||||
($1, $2, $3, $4, $5)"#,
|
||||
user.id.into_uuid(),
|
||||
user.username,
|
||||
user.password_hash,
|
||||
user.created_at,
|
||||
user.updated_at
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(|err| {
|
||||
if let sqlx::Error::Database(db_err) = &err
|
||||
&& let Some(constraint) = db_err.constraint()
|
||||
&& constraint == "users_username_unique"
|
||||
{
|
||||
DatabaseError::UserAlreadyExists
|
||||
} else {
|
||||
err.into()
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(user)
|
||||
}
|
||||
|
||||
pub async fn get_user(&self, get_user_by: GetUserBy<'_>) -> Result<User> {
|
||||
Ok(match get_user_by {
|
||||
GetUserBy::Username(username) => {
|
||||
query_as!(
|
||||
User,
|
||||
r#"
|
||||
select
|
||||
id, username, password_hash,
|
||||
created_at, updated_at
|
||||
from
|
||||
users
|
||||
where
|
||||
username = $1"#,
|
||||
username
|
||||
)
|
||||
.fetch_one(&self.pool)
|
||||
.await?
|
||||
}
|
||||
GetUserBy::Id(id) => {
|
||||
query_as!(
|
||||
User,
|
||||
r#"
|
||||
select
|
||||
id, username, password_hash,
|
||||
created_at, updated_at
|
||||
from
|
||||
users
|
||||
where
|
||||
id = $1"#,
|
||||
id.into_uuid()
|
||||
)
|
||||
.fetch_one(&self.pool)
|
||||
.await?
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn login(
|
||||
&self,
|
||||
username: &str,
|
||||
password: &str,
|
||||
) -> core::result::Result<LoginToken, ServerError> {
|
||||
let user = self.get_user(GetUserBy::Username(username)).await?;
|
||||
|
||||
let parsed_hash = PasswordHash::new(&user.password_hash).map_err(DatabaseError::from)?;
|
||||
Argon2::default()
|
||||
.verify_password(password.as_bytes(), &parsed_hash)
|
||||
.map_err(|err| match err {
|
||||
argon2::password_hash::Error::Password => ServerError::InvalidCredentials,
|
||||
err => ServerError::DatabaseError(err.into()),
|
||||
})?;
|
||||
|
||||
let token = LoginToken::new(user.id);
|
||||
|
||||
query!(
|
||||
r#" insert into login_tokens
|
||||
(token, user_id, created_at, expires_at)
|
||||
values
|
||||
($1, $2, $3, $4)"#,
|
||||
token.token,
|
||||
token.user_id.into_uuid(),
|
||||
token.created_at,
|
||||
token.expires_at
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(Into::<DatabaseError>::into)?;
|
||||
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
pub async fn check_token(&self, token: &str) -> core::result::Result<User, ServerError> {
|
||||
let token = query_as!(
|
||||
LoginToken,
|
||||
r#" select
|
||||
token, user_id, created_at, expires_at
|
||||
from
|
||||
login_tokens
|
||||
where
|
||||
token = $1
|
||||
and
|
||||
expires_at > now()
|
||||
"#,
|
||||
token
|
||||
)
|
||||
.fetch_one(&self.pool)
|
||||
.await
|
||||
.map_err(Into::<DatabaseError>::into)
|
||||
.map_err(|err| match err {
|
||||
DatabaseError::NotFound => ServerError::ExpiredToken,
|
||||
_ => err.into(),
|
||||
})?;
|
||||
|
||||
if Utc::now() >= token.expires_at {
|
||||
return Err(ServerError::ExpiredToken);
|
||||
}
|
||||
|
||||
Ok(self.get_user(GetUserBy::Id(token.user_id)).await?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, FromRow, Encode, Decode)]
|
||||
pub struct User {
|
||||
pub id: UserId,
|
||||
pub username: String,
|
||||
pub password_hash: String,
|
||||
|
||||
pub created_at: chrono::DateTime<Utc>,
|
||||
pub updated_at: chrono::DateTime<Utc>,
|
||||
}
|
||||
|
|
@ -1,37 +1,56 @@
|
|||
mod client;
|
||||
mod communication;
|
||||
mod connection;
|
||||
mod db;
|
||||
mod game;
|
||||
mod host;
|
||||
mod lobby;
|
||||
mod runner;
|
||||
mod saver;
|
||||
// mod saver;
|
||||
|
||||
use axum::{
|
||||
Router,
|
||||
BoxError, Router,
|
||||
error_handling::HandleErrorLayer,
|
||||
extract::{Path, State},
|
||||
http::{Request, StatusCode, header},
|
||||
response::IntoResponse,
|
||||
routing::{any, get},
|
||||
routing::{any, get, post, put},
|
||||
};
|
||||
use axum_extra::{
|
||||
TypedHeader,
|
||||
headers::{self, Authorization},
|
||||
};
|
||||
use axum_extra::headers;
|
||||
use communication::lobby::LobbyComms;
|
||||
use connection::JoinedPlayers;
|
||||
use core::{fmt::Display, net::SocketAddr, str::FromStr};
|
||||
use core::{fmt::Display, net::SocketAddr, str::FromStr, time::Duration};
|
||||
use fast_qr::convert::{Builder, Shape, svg::SvgBuilder};
|
||||
use runner::IdentifiedClientMessage;
|
||||
use std::{env, io::Write, path::Path};
|
||||
use sqlx::postgres::PgPoolOptions;
|
||||
use std::{env, io::Write};
|
||||
use tokio::sync::{broadcast, mpsc};
|
||||
use tower::{ServiceBuilder, buffer::BufferLayer, limit::RateLimitLayer};
|
||||
use werewolves_proto::{
|
||||
cbor::Cbor,
|
||||
error::ServerError,
|
||||
id::GameId,
|
||||
limited::FixedLenString,
|
||||
token::{Token, TokenLogin},
|
||||
user::UserLogin,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
communication::{Comms, connect::ConnectUpdate, host::HostComms, player::PlayerIdComms},
|
||||
saver::FileSaver,
|
||||
db::Database,
|
||||
// saver::FileSaver,
|
||||
};
|
||||
|
||||
const DEFAULT_PORT: u16 = 8080;
|
||||
const DEFAULT_HOST: &str = "127.0.0.1";
|
||||
const DEFAULT_SAVE_DIR: &str = "werewolves-saves/";
|
||||
const DEFAULT_QRCODE_URL: &str = "https://wolf.emilis.dev/";
|
||||
|
||||
const DEFAULT_MAX_PG_CONNECTIONS: u32 = 30;
|
||||
const DEFAULT_PG_CONN_STRING: &str = "postgres:///ww?host=/var/run/postgresql";
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
// pretty_env_logger::init();
|
||||
|
|
@ -106,39 +125,60 @@ async fn main() {
|
|||
|
||||
let jp_clone = joined_players.clone();
|
||||
|
||||
let path = Path::new(option_env!("SAVE_PATH").unwrap_or(DEFAULT_SAVE_DIR));
|
||||
let pg_pool = PgPoolOptions::new()
|
||||
.max_connections(
|
||||
std::env::var("MAX_DB_CONNECTIONS")
|
||||
.ok()
|
||||
.and_then(|val| u32::from_str(&val).ok())
|
||||
.unwrap_or(DEFAULT_MAX_PG_CONNECTIONS),
|
||||
)
|
||||
.connect(
|
||||
std::env::var("PG_CONN_STRING")
|
||||
.unwrap_or_else(|_| String::from(DEFAULT_PG_CONN_STRING))
|
||||
.as_str(),
|
||||
)
|
||||
.await
|
||||
.expect("could not init db");
|
||||
let db = Database::new(pg_pool);
|
||||
db.migrate().await;
|
||||
|
||||
if let Err(err) = std::fs::create_dir(path)
|
||||
&& !matches!(err.kind(), std::io::ErrorKind::AlreadyExists)
|
||||
{
|
||||
panic!("creating save dir at [{path:?}]: {err}")
|
||||
}
|
||||
|
||||
// Check if we can write to the path
|
||||
{
|
||||
let test_file_path = path.join(".test");
|
||||
if let Err(err) = std::fs::File::create(&test_file_path) {
|
||||
panic!("can't create files in {path:?}: {err}")
|
||||
// let saver = FileSaver::new(path.canonicalize().expect("canonicalizing path"));
|
||||
tokio::spawn({
|
||||
let db = db.clone();
|
||||
async move {
|
||||
crate::runner::run_game(jp_clone, lobby_comms, db).await;
|
||||
panic!("game over");
|
||||
}
|
||||
std::fs::remove_file(&test_file_path).log_err();
|
||||
}
|
||||
|
||||
let saver = FileSaver::new(path.canonicalize().expect("canonicalizing path"));
|
||||
tokio::spawn(async move {
|
||||
crate::runner::run_game(jp_clone, lobby_comms, saver).await;
|
||||
panic!("game over");
|
||||
});
|
||||
let state = AppState {
|
||||
joined_players,
|
||||
host_recv,
|
||||
host_send,
|
||||
send,
|
||||
db,
|
||||
};
|
||||
|
||||
let app = Router::new()
|
||||
.route("/connect/client", any(client::handler))
|
||||
.route("/connect/host", any(host::handler))
|
||||
.route("/qrcode", get(handle_qr_code))
|
||||
.route("/s/users", put(signup))
|
||||
.route("/s/tokens", post(signin))
|
||||
.route(
|
||||
"/s/tokens/check",
|
||||
get(check_token).layer(
|
||||
ServiceBuilder::new()
|
||||
.layer(HandleErrorLayer::new(|err: BoxError| async move {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Unhandled error: {}", err),
|
||||
)
|
||||
}))
|
||||
.layer(BufferLayer::new(0x100))
|
||||
.layer(RateLimitLayer::new(100, Duration::from_secs(10))),
|
||||
),
|
||||
)
|
||||
.route("/s/games/{id}", get(get_game_by_id))
|
||||
.with_state(state)
|
||||
.fallback(get(handle_http_static));
|
||||
let listener = tokio::net::TcpListener::bind(listen_addr).await.unwrap();
|
||||
|
|
@ -151,15 +191,61 @@ async fn main() {
|
|||
.unwrap();
|
||||
}
|
||||
|
||||
async fn get_game_by_id(
|
||||
State(AppState { db, .. }): State<AppState>,
|
||||
Path(game_id): Path<GameId>,
|
||||
) -> Result<impl IntoResponse, ServerError> {
|
||||
let story = db.game().get_game_story(game_id).await?;
|
||||
Ok(Cbor(story))
|
||||
}
|
||||
|
||||
async fn check_token(
|
||||
State(AppState { db, .. }): State<AppState>,
|
||||
TypedHeader(Authorization(login)): TypedHeader<Authorization<TokenLogin>>,
|
||||
) -> Result<impl IntoResponse, ServerError> {
|
||||
db.user().check_token(&login.0).await?;
|
||||
Ok(StatusCode::OK)
|
||||
}
|
||||
|
||||
async fn signin(
|
||||
State(AppState { db, .. }): State<AppState>,
|
||||
Cbor(UserLogin { username, password }): Cbor<UserLogin>,
|
||||
) -> Result<impl IntoResponse, ServerError> {
|
||||
let token = db.user().login(&username, &password).await?;
|
||||
|
||||
Ok(Cbor(Token {
|
||||
username,
|
||||
token: FixedLenString::new(token.token.clone()).ok_or_else(|| {
|
||||
ServerError::InternalServerError(format!(
|
||||
"could not get a fixed len string for token [{}]",
|
||||
token.token
|
||||
))
|
||||
})?,
|
||||
created_at: token.created_at,
|
||||
expires_at: token.expires_at,
|
||||
})
|
||||
.into_response())
|
||||
}
|
||||
|
||||
async fn signup(
|
||||
State(AppState { db, .. }): State<AppState>,
|
||||
Cbor(UserLogin { username, password }): Cbor<UserLogin>,
|
||||
) -> Result<impl IntoResponse, ServerError> {
|
||||
db.user().create(&username, &password).await?;
|
||||
Ok(StatusCode::CREATED)
|
||||
}
|
||||
|
||||
struct AppState {
|
||||
joined_players: JoinedPlayers,
|
||||
send: broadcast::Sender<IdentifiedClientMessage>,
|
||||
host_send: tokio::sync::mpsc::Sender<werewolves_proto::message::host::HostMessage>,
|
||||
host_recv: broadcast::Receiver<werewolves_proto::message::host::ServerToHostMessage>,
|
||||
db: Database,
|
||||
}
|
||||
impl Clone for AppState {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
db: self.db.clone(),
|
||||
joined_players: self.joined_players.clone(),
|
||||
send: self.send.clone(),
|
||||
host_send: self.host_send.clone(),
|
||||
|
|
|
|||
|
|
@ -2,7 +2,11 @@ use core::{num::NonZeroU8, time::Duration};
|
|||
use std::sync::Arc;
|
||||
|
||||
use werewolves_proto::{
|
||||
message::{ClientMessage, Identification, host::HostMessage},
|
||||
error::{GameError, ServerError},
|
||||
message::{
|
||||
ClientMessage, Identification,
|
||||
host::{HostMessage, ServerToHostMessage},
|
||||
},
|
||||
player::PlayerId,
|
||||
};
|
||||
|
||||
|
|
@ -10,9 +14,9 @@ use crate::{
|
|||
LogError,
|
||||
communication::lobby::LobbyComms,
|
||||
connection::JoinedPlayers,
|
||||
db::Database,
|
||||
game::{GameEnd, GameRunner},
|
||||
lobby::{Lobby, LobbyPlayers},
|
||||
saver::Saver,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
|
|
@ -21,7 +25,7 @@ pub struct IdentifiedClientMessage {
|
|||
pub message: ClientMessage,
|
||||
}
|
||||
|
||||
pub async fn run_game(joined_players: JoinedPlayers, comms: LobbyComms, mut saver: impl Saver) {
|
||||
pub async fn run_game(joined_players: JoinedPlayers, comms: LobbyComms, db: Database) {
|
||||
let mut lobby = Lobby::new(joined_players, comms);
|
||||
if let Some(dummies) = option_env!("DUMMY_PLAYERS").and_then(|p| p.parse::<NonZeroU8>().ok()) {
|
||||
log::info!("creating {dummies} dummy players");
|
||||
|
|
@ -32,36 +36,37 @@ pub async fn run_game(joined_players: JoinedPlayers, comms: LobbyComms, mut save
|
|||
loop {
|
||||
match &mut state {
|
||||
RunningState::Lobby(lobby) => {
|
||||
if let Some(game) = lobby.next().await {
|
||||
if let Some(mut game) = lobby.next().await {
|
||||
if let Err(err) = db.game().new_game(game.proto_game()).await {
|
||||
log::error!("saving new game: {err}; reverting to lobby");
|
||||
game.comms()
|
||||
.host()
|
||||
.send(ServerToHostMessage::Error(GameError::ServerError(
|
||||
ServerError::DatabaseError(err),
|
||||
)))
|
||||
.log_err();
|
||||
|
||||
state = RunningState::Lobby(game.into_lobby());
|
||||
continue;
|
||||
}
|
||||
state = RunningState::Game(game)
|
||||
}
|
||||
}
|
||||
RunningState::Game(game) => {
|
||||
if let Some(result) = game.next().await {
|
||||
match saver.save(game.proto_game()) {
|
||||
Ok(path) => {
|
||||
log::info!("saved game to {path}");
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("saving game: {err}");
|
||||
let game_clone = game.proto_game().clone();
|
||||
let mut saver_clone = saver.clone();
|
||||
tokio::spawn(async move {
|
||||
let started = chrono::Utc::now();
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(30)).await;
|
||||
match saver_clone.save(&game_clone) {
|
||||
Ok(path) => {
|
||||
log::info!("saved game from {started} to {path}");
|
||||
return;
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("saving game from {started}: {err}")
|
||||
}
|
||||
}
|
||||
if let Err(err) = db.game().update_game(game.proto_game()).await {
|
||||
log::error!("saving game ({}): {err}", game.proto_game().game_id());
|
||||
let game_clone = game.proto_game().clone();
|
||||
let db_clone = db.game();
|
||||
tokio::spawn(async move {
|
||||
let started = chrono::Utc::now();
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(30)).await;
|
||||
if let Err(err) = db_clone.update_game(&game_clone).await {
|
||||
log::error!("saving game from {started}: {err}")
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
state = match state {
|
||||
RunningState::Game(game) => {
|
||||
|
|
@ -69,6 +74,10 @@ pub async fn run_game(joined_players: JoinedPlayers, comms: LobbyComms, mut save
|
|||
}
|
||||
_ => unsafe { core::hint::unreachable_unchecked() },
|
||||
};
|
||||
} else {
|
||||
if let Err(err) = db.game().update_game(game.proto_game()).await {
|
||||
log::error!("updating game ({}): {err}", game.proto_game().game_id());
|
||||
}
|
||||
}
|
||||
}
|
||||
RunningState::GameOver(end) => {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
[build]
|
||||
target = "index.html" # The index HTML file to drive the bundling process.
|
||||
html_output = "index.html" # The name of the output HTML file.
|
||||
release = true # Build in release mode.
|
||||
release = false # Build in release mode.
|
||||
# release = true # Build in release mode.
|
||||
dist = "dist" # The output dir for all final assets.
|
||||
public_url = "/" # The public URL from which assets are to be served.
|
||||
filehash = true # Whether to include hash values in the output file names.
|
||||
|
|
@ -9,6 +10,6 @@ inject_scripts = true # Whether to inject scripts (and module preloads) in
|
|||
offline = false # Run without network access
|
||||
frozen = false # Require Cargo.lock and cache are up to date
|
||||
locked = false # Require Cargo.lock is up to date
|
||||
# minify = "on_release" # Control minification: can be one of: never, on_release, always
|
||||
minify = "always" # Control minification: can be one of: never, on_release, always
|
||||
minify = "on_release" # Control minification: can be one of: never, on_release, always
|
||||
# minify = "always" # Control minification: can be one of: never, on_release, always
|
||||
no_sri = false # Allow disabling sub-resource integrity (SRI)
|
||||
|
|
|
|||
Loading…
Reference in New Issue