Compare commits

...

1 Commits

Author SHA1 Message Date
emilis e91a019872
wip persistence 2025-11-05 12:25:50 +00:00
19 changed files with 2152 additions and 93 deletions

904
Cargo.lock generated

File diff suppressed because it is too large Load Diff

57
migrations/1_init.sql Normal file
View File

@ -0,0 +1,57 @@
drop table if exists users cascade;
create table users (
id uuid not null default gen_random_uuid() primary key,
name text,
username text not null,
password_hash text not null,
created_at timestamp with time zone not null,
updated_at timestamp with time zone not null,
check (created_at <= updated_at)
);
drop index if exists users_username_idx;
create index users_username_idx on users (username);
drop index if exists users_username_unique;
create unique index users_username_unique on users (lower(username));
drop table if exists login_tokens cascade;
create table login_tokens (
token text not null primary key,
user_id uuid not null references users(id),
created_at timestamp with time zone not null,
expires_at timestamp with time zone not null,
check (created_at < expires_at)
);
drop type if exists game_outcome cascade;
create type game_outcome as enum (
'village_victory',
'wolves_victory'
);
drop table if exists games cascade;
create table games (
id uuid not null primary key,
outcome game_outcome,
state json not null,
story json not null,
started_at timestamp with time zone not null,
updated_at timestamp with time zone not null default now()
);
drop table if exists players;
create table players (
id uuid not null primary key,
user_id uuid references users(id)
);
drop table if exists game_players;
create table game_players (
game_id uuid not null references games(id),
player_id uuid not null references players(id),
primary key (game_id, player_id)
);

View File

@ -11,8 +11,25 @@ serde = { version = "1.0", features = ["derive"] }
uuid = { version = "1.17", features = ["v4", "serde"] }
rand = { version = "0.9" }
werewolves-macros = { path = "../werewolves-macros" }
axum = { version = "*", optional = true }
argon2 = { version = "*", optional = true }
sqlx = { version = "*", optional = true }
ciborium = { version = "*", optional = true }
bytes = { version = "1.10.1", features = ["serde"], optional = true }
axum-extra = { version = "*", optional = true }
chrono = { version = "0.4", features = ["serde"] }
[dev-dependencies]
pretty_assertions = { version = "1" }
pretty_env_logger = { version = "0.5" }
colored = { version = "3.0" }
[features]
server = [
"dep:axum",
"dep:sqlx",
"dep:argon2",
"dep:ciborium",
"dep:bytes",
"dep:axum-extra",
]

View File

@ -0,0 +1,156 @@
use axum::{
body::Bytes,
extract::{FromRequest, Request, rejection::BytesRejection},
http::{HeaderMap, HeaderValue, StatusCode, header},
response::{IntoResponse, Response},
};
use axum_extra::headers::Mime;
use bytes::{BufMut, BytesMut};
use core::fmt::Display;
use serde::{Serialize, de::DeserializeOwned};
const CBOR_CONTENT_TYPE: &str = "application/cbor";
const PLAIN_CONTENT_TYPE: &str = "text/plain";
#[must_use]
pub struct Cbor<T>(pub T);
impl<T> Cbor<T> {
pub const fn new(t: T) -> Self {
Self(t)
}
}
impl<T, S> FromRequest<S> for Cbor<T>
where
T: DeserializeOwned,
S: Send + Sync,
{
type Rejection = CborRejection;
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
if !cbor_content_type(req.headers()) {
return Err(CborRejection::MissingCborContentType);
}
let bytes = Bytes::from_request(req, state).await?;
Ok(Self(ciborium::from_reader::<T, _>(&*bytes)?))
}
}
impl<T> IntoResponse for Cbor<T>
where
T: Serialize,
{
fn into_response(self) -> axum::response::Response {
// Extracted into separate fn so it's only compiled once for all T.
fn make_response(buf: BytesMut, ser_result: Result<(), CborRejection>) -> Response {
match ser_result {
Ok(()) => (
[(
header::CONTENT_TYPE,
HeaderValue::from_static(CBOR_CONTENT_TYPE),
)],
buf.freeze(),
)
.into_response(),
Err(err) => err.into_response(),
}
}
// Use a small initial capacity of 128 bytes like serde_json::to_vec
// https://docs.rs/serde_json/1.0.82/src/serde_json/ser.rs.html#2189
let mut buf = BytesMut::with_capacity(128).writer();
let res = ciborium::into_writer(&self.0, &mut buf)
.map_err(|err| CborRejection::SerdeRejection(err.to_string()));
make_response(buf.into_inner(), res)
}
}
#[derive(Debug)]
pub enum CborRejection {
MissingCborContentType,
BytesRejection(BytesRejection),
DeserializeRejection(String),
SerdeRejection(String),
}
impl<T: Display> From<ciborium::de::Error<T>> for CborRejection {
fn from(value: ciborium::de::Error<T>) -> Self {
Self::SerdeRejection(match value {
ciborium::de::Error::Io(err) => format!("i/o: {err}"),
ciborium::de::Error::Syntax(offset) => format!("syntax error at {offset}"),
ciborium::de::Error::Semantic(offset, err) => format!(
"semantic parse: {err}{}",
offset
.map(|offset| format!(" at {offset}"))
.unwrap_or_default(),
),
ciborium::de::Error::RecursionLimitExceeded => {
String::from("the input caused serde to recurse too much")
}
})
}
}
impl From<BytesRejection> for CborRejection {
fn from(value: BytesRejection) -> Self {
Self::BytesRejection(value)
}
}
impl IntoResponse for CborRejection {
fn into_response(self) -> axum::response::Response {
match self {
CborRejection::MissingCborContentType => (
StatusCode::BAD_REQUEST,
[(
header::CONTENT_TYPE,
HeaderValue::from_static(PLAIN_CONTENT_TYPE),
)],
String::from("missing cbor content type"),
),
CborRejection::BytesRejection(err) => (
err.status(),
[(
header::CONTENT_TYPE,
HeaderValue::from_static(PLAIN_CONTENT_TYPE),
)],
format!("bytes rejection: {}", err.body_text()),
),
CborRejection::SerdeRejection(err) => (
StatusCode::BAD_REQUEST,
[(
header::CONTENT_TYPE,
HeaderValue::from_static(PLAIN_CONTENT_TYPE),
)],
err,
),
CborRejection::DeserializeRejection(err) => (
StatusCode::INTERNAL_SERVER_ERROR,
[(
header::CONTENT_TYPE,
HeaderValue::from_static(PLAIN_CONTENT_TYPE),
)],
err,
),
}
.into_response()
}
}
fn cbor_content_type(headers: &HeaderMap) -> bool {
let Some(content_type) = headers.get(header::CONTENT_TYPE) else {
return false;
};
let Ok(content_type) = content_type.to_str() else {
return false;
};
let Ok(mime) = content_type.parse::<Mime>() else {
return false;
};
mime.type_() == "application"
&& (mime.subtype() == "cbor" || mime.suffix().is_some_and(|name| name == "cbor"))
}

View File

@ -1,3 +1,6 @@
#[cfg(feature = "server")]
use core::fmt::Display;
use serde::{Deserialize, Serialize};
use thiserror::Error;
@ -81,4 +84,157 @@ pub enum GameError {
MissingTime(GameTime),
#[error("no previous during day")]
NoPreviousDuringDay,
#[error("server error: {0}")]
ServerError(#[from] ServerError),
}
#[derive(Debug, Clone, PartialEq, Error, Serialize, Deserialize)]
pub enum DatabaseError {
#[error("user already exists")]
UserAlreadyExists,
#[error("password hashing error: {0}")]
PasswordHashError(String),
#[error("sqlx error: {0}")]
SqlxError(String),
#[error("not found")]
NotFound,
#[error("serde_json: {0}")]
SerdeJson(String),
}
#[cfg(feature = "server")]
impl From<serde_json::Error> for DatabaseError {
fn from(value: serde_json::Error) -> Self {
Self::SerdeJson(value.to_string())
}
}
#[cfg(feature = "server")]
impl axum::response::IntoResponse for DatabaseError {
fn into_response(self) -> axum::response::Response {
use axum::http::StatusCode;
use crate::cbor::Cbor;
(
match self {
DatabaseError::UserAlreadyExists => StatusCode::BAD_REQUEST,
DatabaseError::NotFound => StatusCode::NOT_FOUND,
_ => StatusCode::INTERNAL_SERVER_ERROR,
},
Cbor(self),
)
.into_response()
}
}
#[cfg(feature = "server")]
impl From<sqlx::Error> for DatabaseError {
fn from(err: sqlx::Error) -> Self {
match err {
sqlx::Error::RowNotFound => Self::NotFound,
_ => Self::SqlxError(err.to_string()),
}
}
}
#[cfg(feature = "server")]
impl From<argon2::password_hash::Error> for DatabaseError {
fn from(err: argon2::password_hash::Error) -> Self {
Self::PasswordHashError(err.to_string())
}
}
#[derive(Debug, Clone, PartialEq, Error, Serialize, Deserialize)]
pub enum ServerError {
#[error("database error: {0}")]
DatabaseError(DatabaseError),
#[error("invalid credentials")]
InvalidCredentials,
#[error("token expired")]
ExpiredToken,
#[error("internal server error: {0}")]
InternalServerError(String),
#[error("connection error")]
ConnectionError,
#[error("invalid request: {0}")]
InvalidRequest(String),
#[error("not found")]
NotFound,
}
impl<I: Into<DatabaseError>> From<I> for ServerError {
fn from(value: I) -> Self {
let database_err: DatabaseError = value.into();
if let DatabaseError::NotFound = &database_err {
return Self::NotFound;
}
Self::DatabaseError(database_err)
}
}
impl From<GameError> for ServerError {
fn from(value: GameError) -> Self {
Self::InvalidRequest(value.to_string())
}
}
#[cfg(feature = "server")]
impl<T: Display> From<ciborium::de::Error<T>> for ServerError {
fn from(_: ciborium::de::Error<T>) -> Self {
Self::InvalidRequest(String::from("could not decode request"))
}
}
#[cfg(feature = "server")]
impl axum::response::IntoResponse for ServerError {
fn into_response(self) -> axum::response::Response {
use axum::http::StatusCode;
use crate::cbor::Cbor;
match self {
ServerError::ExpiredToken => {
(StatusCode::UNAUTHORIZED, Cbor(ServerError::ExpiredToken)).into_response()
}
ServerError::NotFound | ServerError::DatabaseError(DatabaseError::NotFound) => {
(StatusCode::NOT_FOUND, Cbor(ServerError::NotFound)).into_response()
}
ServerError::DatabaseError(DatabaseError::UserAlreadyExists) => (
StatusCode::BAD_REQUEST,
Cbor(ServerError::InvalidRequest(String::from("username taken"))),
)
.into_response(),
ServerError::DatabaseError(err) => {
use uuid::Uuid;
let error_id = Uuid::new_v4();
log::error!("database error[{error_id}]: {err}");
(
StatusCode::INTERNAL_SERVER_ERROR,
Cbor(ServerError::InternalServerError(format!(
"internal server error. error id: {error_id}"
))),
)
.into_response()
}
ServerError::InvalidCredentials => (
StatusCode::UNAUTHORIZED,
Cbor(ServerError::InvalidCredentials),
)
.into_response(),
ServerError::InternalServerError(_) => {
(StatusCode::INTERNAL_SERVER_ERROR, Cbor(self)).into_response()
}
ServerError::ConnectionError => {
(StatusCode::BAD_REQUEST, Cbor(ServerError::ConnectionError)).into_response()
}
ServerError::InvalidRequest(reason) => (
StatusCode::BAD_REQUEST,
Cbor(ServerError::InvalidRequest(reason)),
)
.into_response(),
}
}
}

View File

@ -10,6 +10,7 @@ use core::{
ops::{Deref, Range, RangeBounds},
};
use chrono::{DateTime, Utc};
use rand::{Rng, seq::SliceRandom};
use serde::{Deserialize, Serialize};
@ -20,6 +21,7 @@ use crate::{
night::{Night, ServerAction},
story::{DayDetail, GameActions, GameStory, NightDetails},
},
id::GameId,
message::{
CharacterState, Identification,
host::{HostDayMessage, HostGameMessage, HostNightMessage, ServerToHostMessage},
@ -36,6 +38,8 @@ type Result<T> = core::result::Result<T, GameError>;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Game {
id: GameId,
started_at: DateTime<Utc>,
history: GameStory,
state: GameState,
}
@ -44,12 +48,36 @@ impl Game {
pub fn new(players: &[Identification], settings: GameSettings) -> Result<Self> {
let village = Village::new(players, settings)?;
Ok(Self {
id: GameId::new(),
started_at: Utc::now(),
history: GameStory::new(village.clone()),
state: GameState::Night {
night: Night::new(village)?,
},
})
}
#[cfg(feature = "server")]
pub const fn new_from_parts(
id: GameId,
started_at: DateTime<Utc>,
history: GameStory,
state: GameState,
) -> Self {
Self {
id,
started_at,
history,
state,
}
}
pub const fn game_id(&self) -> GameId {
self.id
}
pub const fn started_at(&self) -> DateTime<Utc> {
self.started_at
}
pub const fn village(&self) -> &Village {
match &self.state {

View File

@ -0,0 +1 @@
crate::id_impl!(GameId);

View File

@ -1,13 +1,109 @@
#![allow(clippy::new_without_default)]
#[cfg(feature = "server")]
pub mod cbor;
pub mod character;
pub mod diedto;
pub mod error;
pub mod game;
#[cfg(test)]
mod game_test;
pub mod id;
pub mod limited;
pub mod message;
pub mod modifier;
pub mod nonzero;
pub mod player;
pub mod role;
pub mod token;
pub mod user;
#[macro_export]
macro_rules! id_impl {
($name:ident) => {
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct $name(uuid::Uuid);
#[cfg(feature = "server")]
impl sqlx::TypeInfo for $name {
fn is_null(&self) -> bool {
self.0 == uuid::Uuid::nil()
}
fn name(&self) -> &str {
"uuid"
}
}
#[cfg(feature = "server")]
impl sqlx::Type<sqlx::Postgres> for $name {
fn type_info() -> <sqlx::Postgres as sqlx::Database>::TypeInfo {
<uuid::Uuid as sqlx::Type<sqlx::Postgres>>::type_info()
}
}
#[cfg(feature = "server")]
impl<'q> sqlx::Encode<'q, sqlx::Postgres> for $name {
fn encode_by_ref(
&self,
buf: &mut <sqlx::Postgres as sqlx::Database>::ArgumentBuffer<'q>,
) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
self.0.encode_by_ref(buf)
}
}
#[cfg(feature = "server")]
impl<'r> sqlx::Decode<'r, sqlx::Postgres> for $name {
fn decode(
value: <sqlx::Postgres as sqlx::Database>::ValueRef<'r>,
) -> Result<Self, sqlx::error::BoxDynError> {
Ok(Self(uuid::Uuid::decode(value)?))
}
}
impl From<uuid::Uuid> for $name {
fn from(value: uuid::Uuid) -> Self {
Self::from_uuid(value)
}
}
impl From<$name> for uuid::Uuid {
fn from(value: $name) -> Self {
value.into_uuid()
}
}
impl Default for $name {
fn default() -> Self {
Self::new()
}
}
impl $name {
pub fn new() -> Self {
Self(uuid::Uuid::new_v4())
}
pub const fn from_uuid(uuid: uuid::Uuid) -> Self {
Self(uuid)
}
pub const fn into_uuid(self) -> uuid::Uuid {
self.0
}
}
impl core::fmt::Display for $name {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl core::str::FromStr for $name {
type Err = uuid::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(Self(uuid::Uuid::from_str(s)?))
}
}
};
}

View File

@ -0,0 +1,129 @@
use core::{
fmt::Display,
ops::{Deref, RangeInclusive},
};
use serde::{Deserialize, Deserializer, Serialize};
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct FixedLenString<const LEN: usize>(String);
impl<const LEN: usize> Display for FixedLenString<LEN> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl<const LEN: usize> FixedLenString<LEN> {
pub fn new(s: String) -> Option<Self> {
(s.chars().take(LEN + 1).count() == LEN).then_some(Self(s))
}
}
impl<const LEN: usize> Deref for FixedLenString<LEN> {
type Target = String;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<'de, const LEN: usize> Deserialize<'de> for FixedLenString<LEN> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ExpectedLen(usize);
impl serde::de::Expected for ExpectedLen {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "a string exactly {} characters long", self.0)
}
}
<String as Deserialize>::deserialize(deserializer).and_then(|s| {
let char_count = s.chars().take(LEN.saturating_add(1)).count();
if char_count != LEN {
Err(serde::de::Error::invalid_length(
char_count,
&ExpectedLen(LEN),
))
} else {
Ok(Self(s))
}
})
}
}
impl<const LEN: usize> Serialize for FixedLenString<LEN> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(self.0.as_str())
}
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ClampedString<const MIN: usize, const MAX: usize>(String);
impl<const MIN: usize, const MAX: usize> ClampedString<MIN, MAX> {
pub fn new(s: String) -> Result<Self, RangeInclusive<usize>> {
let str_len = s.chars().take(MAX.saturating_add(1)).count();
(str_len >= MIN && str_len <= MAX)
.then_some(Self(s))
.ok_or(MIN..=MAX)
}
pub fn into_inner(self) -> String {
self.0
}
}
impl<const MIN: usize, const MAX: usize> Display for ClampedString<MIN, MAX> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl<const MIN: usize, const MAX: usize> Deref for ClampedString<MIN, MAX> {
type Target = String;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<'de, const MIN: usize, const MAX: usize> Deserialize<'de> for ClampedString<MIN, MAX> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ExpectedLen(usize, usize);
impl serde::de::Expected for ExpectedLen {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"a string between {} and {} characters long",
self.0, self.1
)
}
}
<String as Deserialize>::deserialize(deserializer).and_then(|s| {
let char_count = s.chars().take(MAX.saturating_add(1)).count();
if char_count < MIN || char_count > MAX {
Err(serde::de::Error::invalid_length(
char_count,
&ExpectedLen(MIN, MAX),
))
} else {
Ok(Self(s))
}
})
}
}
impl<const MIN: usize, const MAX: usize> Serialize for ClampedString<MIN, MAX> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(self.0.as_str())
}
}

View File

@ -17,6 +17,12 @@ impl PlayerId {
pub const fn from_u128(v: u128) -> Self {
Self(uuid::Uuid::from_u128(v))
}
pub const fn from_uuid(v: uuid::Uuid) -> Self {
Self(v)
}
pub const fn into_uuid(self) -> uuid::Uuid {
self.0
}
}
impl Display for PlayerId {

View File

@ -0,0 +1,40 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use crate::{limited::FixedLenString, user::Username};
pub const TOKEN_LEN: usize = 0x20;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Token {
pub token: FixedLenString<TOKEN_LEN>,
pub username: Username,
pub created_at: DateTime<Utc>,
pub expires_at: DateTime<Utc>,
}
impl Token {
pub fn login_token(&self) -> TokenLogin {
TokenLogin(self.token.clone())
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TokenLogin(pub FixedLenString<TOKEN_LEN>);
#[cfg(feature = "server")]
impl axum_extra::headers::authorization::Credentials for TokenLogin {
const SCHEME: &'static str = "Bearer";
fn decode(value: &axum::http::HeaderValue) -> Option<Self> {
value
.to_str()
.ok()
.and_then(|v| FixedLenString::new(v.strip_prefix("Bearer ").unwrap_or(v).to_string()))
.map(Self)
}
fn encode(&self) -> axum::http::HeaderValue {
axum::http::HeaderValue::from_str(self.0.as_str()).expect("bearer token encode")
}
}

View File

@ -0,0 +1,14 @@
use serde::{Deserialize, Serialize};
use crate::limited::ClampedString;
pub type Username = ClampedString<1, 0x40>;
pub type Password = ClampedString<6, 0x100>;
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct UserLogin {
pub username: Username,
pub password: Password,
}
crate::id_impl!(UserId);

View File

@ -11,12 +11,12 @@ pretty_env_logger = { version = "0.5" }
# env_logger = { version = "0.11" }
futures = "0.3.31"
anyhow = { version = "1" }
werewolves-proto = { path = "../werewolves-proto" }
werewolves-proto = { path = "../werewolves-proto", features = ["server"] }
werewolves-macros = { path = "../werewolves-macros" }
mime-sniffer = { version = "0.1" }
chrono = { version = "0.4" }
atom_syndication = { version = "0.12" }
axum-extra = { version = "0.10", features = ["typed-header"] }
axum-extra = { version = "0.12", features = ["typed-header"] }
rand = { version = "0.9" }
serde_json = { version = "1.0" }
serde = { version = "1.0", features = ["derive"] }
@ -24,8 +24,27 @@ thiserror = { version = "2" }
ciborium = { version = "0.2", optional = true }
colored = { version = "3.0" }
fast_qr = { version = "0.13", features = ["svg"] }
ron = "0.8"
ron = "0.11"
bytes = { version = "1.10" }
sqlx = { version = "0.8", features = [
"runtime-tokio",
"postgres",
"derive",
"macros",
"uuid",
"chrono",
] }
argon2 = { version = "0.5" }
tower-http = { version = "0.6", features = ["cors"] }
tower = { version = "0.5.2", features = [
"limit",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
"buffer",
"timeout",
] }
[features]

View File

@ -0,0 +1,150 @@
use sqlx::{Pool, Postgres, query};
use werewolves_proto::{
error::DatabaseError,
game::{Game, GameOver, story::GameStory},
id::GameId,
player::PlayerId,
user::UserId,
};
#[derive(Debug, Clone)]
pub struct GameDatabase {
pub(super) pool: Pool<Postgres>,
}
impl GameDatabase {
pub async fn new_game(&self, game: &Game) -> Result<(), DatabaseError> {
let state = serde_json::to_value(game.game_state())?;
let story = serde_json::to_value(&game.story())?;
let mut tx = self.pool.begin().await?;
query!(
r#" insert into
games (id, started_at, state, story)
values
($1, $2, $3, $4)"#,
game.game_id().into_uuid(),
game.started_at(),
state,
story,
)
.execute(&mut *tx)
.await?;
let player_ids = game
.village()
.characters()
.into_iter()
.map(|c| c.player_id().into_uuid())
.collect::<Box<[_]>>();
let user_ids = query!(
r#" select
id, user_id
from
players
where
id = any($1::uuid[])"#,
&*player_ids,
)
.fetch_all(&mut *tx)
.await?
.into_iter()
.map(|r| (PlayerId::from_uuid(r.id), r.user_id.map(UserId::from_uuid)))
.unzip::<PlayerId, Option<UserId>, Vec<PlayerId>, Vec<Option<UserId>>>();
let game_id = game.game_id().into_uuid();
let game_ids = (0..player_ids.len()).map(|_| game_id).collect::<Box<[_]>>();
query!(
r#" with
game_ids as (select row_number() over(), * from unnest($1::uuid[]) as game_id),
player_ids as (select row_number() over(), * from unnest($2::uuid[]) as player_id)
insert into
game_players
select
game_ids.game_id, player_ids.player_id
from
game_ids
join
player_ids on game_ids.row_number = player_ids.row_number
"#,
&*game_ids,
&*player_ids,
)
.execute(&mut *tx)
.await?;
tx.commit().await?;
Ok(())
}
pub async fn update_game(&self, game: &Game) -> Result<(), DatabaseError> {
let state = serde_json::to_value(game.game_state())?;
let story = serde_json::to_value(game.story())?;
query!(
r#" update
games
set
story = $2,
state = $3,
outcome = $4,
updated_at = now()
where
id = $1"#,
game.game_id().into_uuid(),
story,
state,
game.game_over().map(Self::outcome_to_db_outcome) as _
)
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn get_active_game(&self) -> Result<Game, DatabaseError> {
let game = query!(
r#" select
id, state, story, started_at
from
games
where
outcome is null"#
)
.fetch_one(&self.pool)
.await?;
Ok(Game::new_from_parts(
GameId::from_uuid(game.id),
game.started_at,
serde_json::from_value(game.story)?,
serde_json::from_value(game.state)?,
))
}
pub async fn get_game_story(&self, id: GameId) -> Result<GameStory, DatabaseError> {
let game = query!(
r#" select
story
from
games
where
id = $1
and
outcome is not null"#,
id.into_uuid(),
)
.fetch_one(&self.pool)
.await?;
Ok(serde_json::from_value(game.story)?)
}
fn outcome_to_db_outcome(outcome: GameOver) -> &'static str {
match outcome {
GameOver::VillageWins => "village_victory",
GameOver::WolvesWin => "wolves_victory",
}
}
}

View File

@ -0,0 +1,40 @@
pub mod game;
pub mod user;
use sqlx::{Pool, Postgres};
use werewolves_proto::error::DatabaseError;
use crate::db::{game::GameDatabase, user::UserDatabase};
type Result<T> = core::result::Result<T, DatabaseError>;
#[derive(Debug, Clone)]
pub struct Database {
pool: Pool<Postgres>,
}
impl Database {
pub const fn new(pool: Pool<Postgres>) -> Self {
Self { pool }
}
pub fn user(&self) -> UserDatabase {
UserDatabase {
pool: self.pool.clone(),
}
}
pub fn game(&self) -> GameDatabase {
GameDatabase {
pool: self.pool.clone(),
}
}
pub async fn migrate(&self) {
log::info!("running migrations");
sqlx::migrate!("../migrations")
.run(&self.pool)
.await
.expect("run migrations");
log::info!("migrations done");
}
}

View File

@ -0,0 +1,214 @@
use super::Result;
use argon2::{
Argon2, PasswordHash, PasswordVerifier,
password_hash::{PasswordHasher, SaltString, rand_core::OsRng},
};
use chrono::{TimeDelta, Utc};
use rand::distr::SampleString;
use sqlx::{Decode, Encode, Pool, Postgres, prelude::FromRow, query, query_as};
use werewolves_proto::{
error::{DatabaseError, ServerError},
token,
user::UserId,
};
#[derive(Debug, Clone)]
pub struct UserDatabase {
pub(super) pool: Pool<Postgres>,
}
#[derive(Debug, Clone, FromRow)]
pub struct LoginToken {
pub token: String,
pub user_id: UserId,
pub created_at: chrono::DateTime<Utc>,
pub expires_at: chrono::DateTime<Utc>,
}
impl LoginToken {
const TOKEN_LONGEVITY: TimeDelta = TimeDelta::days(30);
pub fn new(user_id: UserId) -> Self {
let created_at = Utc::now();
let expires_at = created_at
.checked_add_signed(Self::TOKEN_LONGEVITY)
.unwrap_or_else(|| {
panic!(
"could not add {} time to {created_at}",
Self::TOKEN_LONGEVITY
)
});
let token = rand::distr::Alphanumeric.sample_string(&mut rand::rng(), token::TOKEN_LEN);
Self {
token,
user_id,
created_at,
expires_at,
}
}
}
pub enum GetUserBy<'a> {
Username(&'a str),
Id(UserId),
}
impl UserDatabase {
pub async fn create(&self, username: &str, password: &str) -> Result<User> {
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
let password_hash = argon2
.hash_password(password.as_bytes(), &salt)?
.to_string();
let now = chrono::offset::Utc::now();
let user = User {
id: UserId::new(),
username: username.into(),
password_hash,
created_at: now,
updated_at: now,
};
query!(
r#"insert into users
(id, username, password_hash, created_at, updated_at)
values
($1, $2, $3, $4, $5)"#,
user.id.into_uuid(),
user.username,
user.password_hash,
user.created_at,
user.updated_at
)
.execute(&self.pool)
.await
.map_err(|err| {
if let sqlx::Error::Database(db_err) = &err
&& let Some(constraint) = db_err.constraint()
&& constraint == "users_username_unique"
{
DatabaseError::UserAlreadyExists
} else {
err.into()
}
})?;
Ok(user)
}
pub async fn get_user(&self, get_user_by: GetUserBy<'_>) -> Result<User> {
Ok(match get_user_by {
GetUserBy::Username(username) => {
query_as!(
User,
r#"
select
id, username, password_hash,
created_at, updated_at
from
users
where
username = $1"#,
username
)
.fetch_one(&self.pool)
.await?
}
GetUserBy::Id(id) => {
query_as!(
User,
r#"
select
id, username, password_hash,
created_at, updated_at
from
users
where
id = $1"#,
id.into_uuid()
)
.fetch_one(&self.pool)
.await?
}
})
}
pub async fn login(
&self,
username: &str,
password: &str,
) -> core::result::Result<LoginToken, ServerError> {
let user = self.get_user(GetUserBy::Username(username)).await?;
let parsed_hash = PasswordHash::new(&user.password_hash).map_err(DatabaseError::from)?;
Argon2::default()
.verify_password(password.as_bytes(), &parsed_hash)
.map_err(|err| match err {
argon2::password_hash::Error::Password => ServerError::InvalidCredentials,
err => ServerError::DatabaseError(err.into()),
})?;
let token = LoginToken::new(user.id);
query!(
r#" insert into login_tokens
(token, user_id, created_at, expires_at)
values
($1, $2, $3, $4)"#,
token.token,
token.user_id.into_uuid(),
token.created_at,
token.expires_at
)
.execute(&self.pool)
.await
.map_err(Into::<DatabaseError>::into)?;
Ok(token)
}
pub async fn check_token(&self, token: &str) -> core::result::Result<User, ServerError> {
let token = query_as!(
LoginToken,
r#" select
token, user_id, created_at, expires_at
from
login_tokens
where
token = $1
and
expires_at > now()
"#,
token
)
.fetch_one(&self.pool)
.await
.map_err(Into::<DatabaseError>::into)
.map_err(|err| match err {
DatabaseError::NotFound => ServerError::ExpiredToken,
_ => err.into(),
})?;
if Utc::now() >= token.expires_at {
return Err(ServerError::ExpiredToken);
}
Ok(self.get_user(GetUserBy::Id(token.user_id)).await?)
}
}
#[derive(Debug, Clone, FromRow, Encode, Decode)]
pub struct User {
pub id: UserId,
pub username: String,
pub password_hash: String,
pub created_at: chrono::DateTime<Utc>,
pub updated_at: chrono::DateTime<Utc>,
}

View File

@ -1,37 +1,56 @@
mod client;
mod communication;
mod connection;
mod db;
mod game;
mod host;
mod lobby;
mod runner;
mod saver;
// mod saver;
use axum::{
Router,
BoxError, Router,
error_handling::HandleErrorLayer,
extract::{Path, State},
http::{Request, StatusCode, header},
response::IntoResponse,
routing::{any, get},
routing::{any, get, post, put},
};
use axum_extra::{
TypedHeader,
headers::{self, Authorization},
};
use axum_extra::headers;
use communication::lobby::LobbyComms;
use connection::JoinedPlayers;
use core::{fmt::Display, net::SocketAddr, str::FromStr};
use core::{fmt::Display, net::SocketAddr, str::FromStr, time::Duration};
use fast_qr::convert::{Builder, Shape, svg::SvgBuilder};
use runner::IdentifiedClientMessage;
use std::{env, io::Write, path::Path};
use sqlx::postgres::PgPoolOptions;
use std::{env, io::Write};
use tokio::sync::{broadcast, mpsc};
use tower::{ServiceBuilder, buffer::BufferLayer, limit::RateLimitLayer};
use werewolves_proto::{
cbor::Cbor,
error::ServerError,
id::GameId,
limited::FixedLenString,
token::{Token, TokenLogin},
user::UserLogin,
};
use crate::{
communication::{Comms, connect::ConnectUpdate, host::HostComms, player::PlayerIdComms},
saver::FileSaver,
db::Database,
// saver::FileSaver,
};
const DEFAULT_PORT: u16 = 8080;
const DEFAULT_HOST: &str = "127.0.0.1";
const DEFAULT_SAVE_DIR: &str = "werewolves-saves/";
const DEFAULT_QRCODE_URL: &str = "https://wolf.emilis.dev/";
const DEFAULT_MAX_PG_CONNECTIONS: u32 = 30;
const DEFAULT_PG_CONN_STRING: &str = "postgres:///ww?host=/var/run/postgresql";
#[tokio::main]
async fn main() {
// pretty_env_logger::init();
@ -106,39 +125,60 @@ async fn main() {
let jp_clone = joined_players.clone();
let path = Path::new(option_env!("SAVE_PATH").unwrap_or(DEFAULT_SAVE_DIR));
let pg_pool = PgPoolOptions::new()
.max_connections(
std::env::var("MAX_DB_CONNECTIONS")
.ok()
.and_then(|val| u32::from_str(&val).ok())
.unwrap_or(DEFAULT_MAX_PG_CONNECTIONS),
)
.connect(
std::env::var("PG_CONN_STRING")
.unwrap_or_else(|_| String::from(DEFAULT_PG_CONN_STRING))
.as_str(),
)
.await
.expect("could not init db");
let db = Database::new(pg_pool);
db.migrate().await;
if let Err(err) = std::fs::create_dir(path)
&& !matches!(err.kind(), std::io::ErrorKind::AlreadyExists)
{
panic!("creating save dir at [{path:?}]: {err}")
}
// Check if we can write to the path
{
let test_file_path = path.join(".test");
if let Err(err) = std::fs::File::create(&test_file_path) {
panic!("can't create files in {path:?}: {err}")
}
std::fs::remove_file(&test_file_path).log_err();
}
let saver = FileSaver::new(path.canonicalize().expect("canonicalizing path"));
tokio::spawn(async move {
crate::runner::run_game(jp_clone, lobby_comms, saver).await;
// let saver = FileSaver::new(path.canonicalize().expect("canonicalizing path"));
tokio::spawn({
let db = db.clone();
async move {
crate::runner::run_game(jp_clone, lobby_comms, db).await;
panic!("game over");
}
});
let state = AppState {
joined_players,
host_recv,
host_send,
send,
db,
};
let app = Router::new()
.route("/connect/client", any(client::handler))
.route("/connect/host", any(host::handler))
.route("/qrcode", get(handle_qr_code))
.route("/s/users", put(signup))
.route("/s/tokens", post(signin))
.route(
"/s/tokens/check",
get(check_token).layer(
ServiceBuilder::new()
.layer(HandleErrorLayer::new(|err: BoxError| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled error: {}", err),
)
}))
.layer(BufferLayer::new(0x100))
.layer(RateLimitLayer::new(100, Duration::from_secs(10))),
),
)
.route("/s/games/{id}", get(get_game_by_id))
.with_state(state)
.fallback(get(handle_http_static));
let listener = tokio::net::TcpListener::bind(listen_addr).await.unwrap();
@ -151,15 +191,61 @@ async fn main() {
.unwrap();
}
async fn get_game_by_id(
State(AppState { db, .. }): State<AppState>,
Path(game_id): Path<GameId>,
) -> Result<impl IntoResponse, ServerError> {
let story = db.game().get_game_story(game_id).await?;
Ok(Cbor(story))
}
async fn check_token(
State(AppState { db, .. }): State<AppState>,
TypedHeader(Authorization(login)): TypedHeader<Authorization<TokenLogin>>,
) -> Result<impl IntoResponse, ServerError> {
db.user().check_token(&login.0).await?;
Ok(StatusCode::OK)
}
async fn signin(
State(AppState { db, .. }): State<AppState>,
Cbor(UserLogin { username, password }): Cbor<UserLogin>,
) -> Result<impl IntoResponse, ServerError> {
let token = db.user().login(&username, &password).await?;
Ok(Cbor(Token {
username,
token: FixedLenString::new(token.token.clone()).ok_or_else(|| {
ServerError::InternalServerError(format!(
"could not get a fixed len string for token [{}]",
token.token
))
})?,
created_at: token.created_at,
expires_at: token.expires_at,
})
.into_response())
}
async fn signup(
State(AppState { db, .. }): State<AppState>,
Cbor(UserLogin { username, password }): Cbor<UserLogin>,
) -> Result<impl IntoResponse, ServerError> {
db.user().create(&username, &password).await?;
Ok(StatusCode::CREATED)
}
struct AppState {
joined_players: JoinedPlayers,
send: broadcast::Sender<IdentifiedClientMessage>,
host_send: tokio::sync::mpsc::Sender<werewolves_proto::message::host::HostMessage>,
host_recv: broadcast::Receiver<werewolves_proto::message::host::ServerToHostMessage>,
db: Database,
}
impl Clone for AppState {
fn clone(&self) -> Self {
Self {
db: self.db.clone(),
joined_players: self.joined_players.clone(),
send: self.send.clone(),
host_send: self.host_send.clone(),

View File

@ -2,7 +2,11 @@ use core::{num::NonZeroU8, time::Duration};
use std::sync::Arc;
use werewolves_proto::{
message::{ClientMessage, Identification, host::HostMessage},
error::{GameError, ServerError},
message::{
ClientMessage, Identification,
host::{HostMessage, ServerToHostMessage},
},
player::PlayerId,
};
@ -10,9 +14,9 @@ use crate::{
LogError,
communication::lobby::LobbyComms,
connection::JoinedPlayers,
db::Database,
game::{GameEnd, GameRunner},
lobby::{Lobby, LobbyPlayers},
saver::Saver,
};
#[derive(Debug, Clone, PartialEq)]
@ -21,7 +25,7 @@ pub struct IdentifiedClientMessage {
pub message: ClientMessage,
}
pub async fn run_game(joined_players: JoinedPlayers, comms: LobbyComms, mut saver: impl Saver) {
pub async fn run_game(joined_players: JoinedPlayers, comms: LobbyComms, db: Database) {
let mut lobby = Lobby::new(joined_players, comms);
if let Some(dummies) = option_env!("DUMMY_PLAYERS").and_then(|p| p.parse::<NonZeroU8>().ok()) {
log::info!("creating {dummies} dummy players");
@ -32,43 +36,48 @@ pub async fn run_game(joined_players: JoinedPlayers, comms: LobbyComms, mut save
loop {
match &mut state {
RunningState::Lobby(lobby) => {
if let Some(game) = lobby.next().await {
if let Some(mut game) = lobby.next().await {
if let Err(err) = db.game().new_game(game.proto_game()).await {
log::error!("saving new game: {err}; reverting to lobby");
game.comms()
.host()
.send(ServerToHostMessage::Error(GameError::ServerError(
ServerError::DatabaseError(err),
)))
.log_err();
state = RunningState::Lobby(game.into_lobby());
continue;
}
state = RunningState::Game(game)
}
}
RunningState::Game(game) => {
if let Some(result) = game.next().await {
match saver.save(game.proto_game()) {
Ok(path) => {
log::info!("saved game to {path}");
}
Err(err) => {
log::error!("saving game: {err}");
if let Err(err) = db.game().update_game(game.proto_game()).await {
log::error!("saving game ({}): {err}", game.proto_game().game_id());
let game_clone = game.proto_game().clone();
let mut saver_clone = saver.clone();
let db_clone = db.game();
tokio::spawn(async move {
let started = chrono::Utc::now();
loop {
tokio::time::sleep(Duration::from_secs(30)).await;
match saver_clone.save(&game_clone) {
Ok(path) => {
log::info!("saved game from {started} to {path}");
return;
}
Err(err) => {
if let Err(err) = db_clone.update_game(&game_clone).await {
log::error!("saving game from {started}: {err}")
}
}
}
});
}
}
state = match state {
RunningState::Game(game) => {
RunningState::GameOver(GameEnd::new(game, result))
}
_ => unsafe { core::hint::unreachable_unchecked() },
};
} else {
if let Err(err) = db.game().update_game(game.proto_game()).await {
log::error!("updating game ({}): {err}", game.proto_game().game_id());
}
}
}
RunningState::GameOver(end) => {

View File

@ -1,7 +1,8 @@
[build]
target = "index.html" # The index HTML file to drive the bundling process.
html_output = "index.html" # The name of the output HTML file.
release = true # Build in release mode.
release = false # Build in release mode.
# release = true # Build in release mode.
dist = "dist" # The output dir for all final assets.
public_url = "/" # The public URL from which assets are to be served.
filehash = true # Whether to include hash values in the output file names.
@ -9,6 +10,6 @@ inject_scripts = true # Whether to inject scripts (and module preloads) in
offline = false # Run without network access
frozen = false # Require Cargo.lock and cache are up to date
locked = false # Require Cargo.lock is up to date
# minify = "on_release" # Control minification: can be one of: never, on_release, always
minify = "always" # Control minification: can be one of: never, on_release, always
minify = "on_release" # Control minification: can be one of: never, on_release, always
# minify = "always" # Control minification: can be one of: never, on_release, always
no_sri = false # Allow disabling sub-resource integrity (SRI)