wip persistence

This commit is contained in:
emilis 2025-11-04 22:25:50 +00:00
parent 15a6454ae2
commit e91a019872
No known key found for this signature in database
19 changed files with 2152 additions and 93 deletions

904
Cargo.lock generated

File diff suppressed because it is too large Load Diff

57
migrations/1_init.sql Normal file
View File

@ -0,0 +1,57 @@
drop table if exists users cascade;
create table users (
id uuid not null default gen_random_uuid() primary key,
name text,
username text not null,
password_hash text not null,
created_at timestamp with time zone not null,
updated_at timestamp with time zone not null,
check (created_at <= updated_at)
);
drop index if exists users_username_idx;
create index users_username_idx on users (username);
drop index if exists users_username_unique;
create unique index users_username_unique on users (lower(username));
drop table if exists login_tokens cascade;
create table login_tokens (
token text not null primary key,
user_id uuid not null references users(id),
created_at timestamp with time zone not null,
expires_at timestamp with time zone not null,
check (created_at < expires_at)
);
drop type if exists game_outcome cascade;
create type game_outcome as enum (
'village_victory',
'wolves_victory'
);
drop table if exists games cascade;
create table games (
id uuid not null primary key,
outcome game_outcome,
state json not null,
story json not null,
started_at timestamp with time zone not null,
updated_at timestamp with time zone not null default now()
);
drop table if exists players;
create table players (
id uuid not null primary key,
user_id uuid references users(id)
);
drop table if exists game_players;
create table game_players (
game_id uuid not null references games(id),
player_id uuid not null references players(id),
primary key (game_id, player_id)
);

View File

@ -11,8 +11,25 @@ serde = { version = "1.0", features = ["derive"] }
uuid = { version = "1.17", features = ["v4", "serde"] }
rand = { version = "0.9" }
werewolves-macros = { path = "../werewolves-macros" }
axum = { version = "*", optional = true }
argon2 = { version = "*", optional = true }
sqlx = { version = "*", optional = true }
ciborium = { version = "*", optional = true }
bytes = { version = "1.10.1", features = ["serde"], optional = true }
axum-extra = { version = "*", optional = true }
chrono = { version = "0.4", features = ["serde"] }
[dev-dependencies]
pretty_assertions = { version = "1" }
pretty_env_logger = { version = "0.5" }
colored = { version = "3.0" }
[features]
server = [
"dep:axum",
"dep:sqlx",
"dep:argon2",
"dep:ciborium",
"dep:bytes",
"dep:axum-extra",
]

View File

@ -0,0 +1,156 @@
use axum::{
body::Bytes,
extract::{FromRequest, Request, rejection::BytesRejection},
http::{HeaderMap, HeaderValue, StatusCode, header},
response::{IntoResponse, Response},
};
use axum_extra::headers::Mime;
use bytes::{BufMut, BytesMut};
use core::fmt::Display;
use serde::{Serialize, de::DeserializeOwned};
const CBOR_CONTENT_TYPE: &str = "application/cbor";
const PLAIN_CONTENT_TYPE: &str = "text/plain";
#[must_use]
pub struct Cbor<T>(pub T);
impl<T> Cbor<T> {
pub const fn new(t: T) -> Self {
Self(t)
}
}
impl<T, S> FromRequest<S> for Cbor<T>
where
T: DeserializeOwned,
S: Send + Sync,
{
type Rejection = CborRejection;
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
if !cbor_content_type(req.headers()) {
return Err(CborRejection::MissingCborContentType);
}
let bytes = Bytes::from_request(req, state).await?;
Ok(Self(ciborium::from_reader::<T, _>(&*bytes)?))
}
}
impl<T> IntoResponse for Cbor<T>
where
T: Serialize,
{
fn into_response(self) -> axum::response::Response {
// Extracted into separate fn so it's only compiled once for all T.
fn make_response(buf: BytesMut, ser_result: Result<(), CborRejection>) -> Response {
match ser_result {
Ok(()) => (
[(
header::CONTENT_TYPE,
HeaderValue::from_static(CBOR_CONTENT_TYPE),
)],
buf.freeze(),
)
.into_response(),
Err(err) => err.into_response(),
}
}
// Use a small initial capacity of 128 bytes like serde_json::to_vec
// https://docs.rs/serde_json/1.0.82/src/serde_json/ser.rs.html#2189
let mut buf = BytesMut::with_capacity(128).writer();
let res = ciborium::into_writer(&self.0, &mut buf)
.map_err(|err| CborRejection::SerdeRejection(err.to_string()));
make_response(buf.into_inner(), res)
}
}
#[derive(Debug)]
pub enum CborRejection {
MissingCborContentType,
BytesRejection(BytesRejection),
DeserializeRejection(String),
SerdeRejection(String),
}
impl<T: Display> From<ciborium::de::Error<T>> for CborRejection {
fn from(value: ciborium::de::Error<T>) -> Self {
Self::SerdeRejection(match value {
ciborium::de::Error::Io(err) => format!("i/o: {err}"),
ciborium::de::Error::Syntax(offset) => format!("syntax error at {offset}"),
ciborium::de::Error::Semantic(offset, err) => format!(
"semantic parse: {err}{}",
offset
.map(|offset| format!(" at {offset}"))
.unwrap_or_default(),
),
ciborium::de::Error::RecursionLimitExceeded => {
String::from("the input caused serde to recurse too much")
}
})
}
}
impl From<BytesRejection> for CborRejection {
fn from(value: BytesRejection) -> Self {
Self::BytesRejection(value)
}
}
impl IntoResponse for CborRejection {
fn into_response(self) -> axum::response::Response {
match self {
CborRejection::MissingCborContentType => (
StatusCode::BAD_REQUEST,
[(
header::CONTENT_TYPE,
HeaderValue::from_static(PLAIN_CONTENT_TYPE),
)],
String::from("missing cbor content type"),
),
CborRejection::BytesRejection(err) => (
err.status(),
[(
header::CONTENT_TYPE,
HeaderValue::from_static(PLAIN_CONTENT_TYPE),
)],
format!("bytes rejection: {}", err.body_text()),
),
CborRejection::SerdeRejection(err) => (
StatusCode::BAD_REQUEST,
[(
header::CONTENT_TYPE,
HeaderValue::from_static(PLAIN_CONTENT_TYPE),
)],
err,
),
CborRejection::DeserializeRejection(err) => (
StatusCode::INTERNAL_SERVER_ERROR,
[(
header::CONTENT_TYPE,
HeaderValue::from_static(PLAIN_CONTENT_TYPE),
)],
err,
),
}
.into_response()
}
}
fn cbor_content_type(headers: &HeaderMap) -> bool {
let Some(content_type) = headers.get(header::CONTENT_TYPE) else {
return false;
};
let Ok(content_type) = content_type.to_str() else {
return false;
};
let Ok(mime) = content_type.parse::<Mime>() else {
return false;
};
mime.type_() == "application"
&& (mime.subtype() == "cbor" || mime.suffix().is_some_and(|name| name == "cbor"))
}

View File

@ -1,3 +1,6 @@
#[cfg(feature = "server")]
use core::fmt::Display;
use serde::{Deserialize, Serialize};
use thiserror::Error;
@ -81,4 +84,157 @@ pub enum GameError {
MissingTime(GameTime),
#[error("no previous during day")]
NoPreviousDuringDay,
#[error("server error: {0}")]
ServerError(#[from] ServerError),
}
#[derive(Debug, Clone, PartialEq, Error, Serialize, Deserialize)]
pub enum DatabaseError {
#[error("user already exists")]
UserAlreadyExists,
#[error("password hashing error: {0}")]
PasswordHashError(String),
#[error("sqlx error: {0}")]
SqlxError(String),
#[error("not found")]
NotFound,
#[error("serde_json: {0}")]
SerdeJson(String),
}
#[cfg(feature = "server")]
impl From<serde_json::Error> for DatabaseError {
fn from(value: serde_json::Error) -> Self {
Self::SerdeJson(value.to_string())
}
}
#[cfg(feature = "server")]
impl axum::response::IntoResponse for DatabaseError {
fn into_response(self) -> axum::response::Response {
use axum::http::StatusCode;
use crate::cbor::Cbor;
(
match self {
DatabaseError::UserAlreadyExists => StatusCode::BAD_REQUEST,
DatabaseError::NotFound => StatusCode::NOT_FOUND,
_ => StatusCode::INTERNAL_SERVER_ERROR,
},
Cbor(self),
)
.into_response()
}
}
#[cfg(feature = "server")]
impl From<sqlx::Error> for DatabaseError {
fn from(err: sqlx::Error) -> Self {
match err {
sqlx::Error::RowNotFound => Self::NotFound,
_ => Self::SqlxError(err.to_string()),
}
}
}
#[cfg(feature = "server")]
impl From<argon2::password_hash::Error> for DatabaseError {
fn from(err: argon2::password_hash::Error) -> Self {
Self::PasswordHashError(err.to_string())
}
}
#[derive(Debug, Clone, PartialEq, Error, Serialize, Deserialize)]
pub enum ServerError {
#[error("database error: {0}")]
DatabaseError(DatabaseError),
#[error("invalid credentials")]
InvalidCredentials,
#[error("token expired")]
ExpiredToken,
#[error("internal server error: {0}")]
InternalServerError(String),
#[error("connection error")]
ConnectionError,
#[error("invalid request: {0}")]
InvalidRequest(String),
#[error("not found")]
NotFound,
}
impl<I: Into<DatabaseError>> From<I> for ServerError {
fn from(value: I) -> Self {
let database_err: DatabaseError = value.into();
if let DatabaseError::NotFound = &database_err {
return Self::NotFound;
}
Self::DatabaseError(database_err)
}
}
impl From<GameError> for ServerError {
fn from(value: GameError) -> Self {
Self::InvalidRequest(value.to_string())
}
}
#[cfg(feature = "server")]
impl<T: Display> From<ciborium::de::Error<T>> for ServerError {
fn from(_: ciborium::de::Error<T>) -> Self {
Self::InvalidRequest(String::from("could not decode request"))
}
}
#[cfg(feature = "server")]
impl axum::response::IntoResponse for ServerError {
fn into_response(self) -> axum::response::Response {
use axum::http::StatusCode;
use crate::cbor::Cbor;
match self {
ServerError::ExpiredToken => {
(StatusCode::UNAUTHORIZED, Cbor(ServerError::ExpiredToken)).into_response()
}
ServerError::NotFound | ServerError::DatabaseError(DatabaseError::NotFound) => {
(StatusCode::NOT_FOUND, Cbor(ServerError::NotFound)).into_response()
}
ServerError::DatabaseError(DatabaseError::UserAlreadyExists) => (
StatusCode::BAD_REQUEST,
Cbor(ServerError::InvalidRequest(String::from("username taken"))),
)
.into_response(),
ServerError::DatabaseError(err) => {
use uuid::Uuid;
let error_id = Uuid::new_v4();
log::error!("database error[{error_id}]: {err}");
(
StatusCode::INTERNAL_SERVER_ERROR,
Cbor(ServerError::InternalServerError(format!(
"internal server error. error id: {error_id}"
))),
)
.into_response()
}
ServerError::InvalidCredentials => (
StatusCode::UNAUTHORIZED,
Cbor(ServerError::InvalidCredentials),
)
.into_response(),
ServerError::InternalServerError(_) => {
(StatusCode::INTERNAL_SERVER_ERROR, Cbor(self)).into_response()
}
ServerError::ConnectionError => {
(StatusCode::BAD_REQUEST, Cbor(ServerError::ConnectionError)).into_response()
}
ServerError::InvalidRequest(reason) => (
StatusCode::BAD_REQUEST,
Cbor(ServerError::InvalidRequest(reason)),
)
.into_response(),
}
}
}

View File

@ -10,6 +10,7 @@ use core::{
ops::{Deref, Range, RangeBounds},
};
use chrono::{DateTime, Utc};
use rand::{Rng, seq::SliceRandom};
use serde::{Deserialize, Serialize};
@ -20,6 +21,7 @@ use crate::{
night::{Night, ServerAction},
story::{DayDetail, GameActions, GameStory, NightDetails},
},
id::GameId,
message::{
CharacterState, Identification,
host::{HostDayMessage, HostGameMessage, HostNightMessage, ServerToHostMessage},
@ -36,6 +38,8 @@ type Result<T> = core::result::Result<T, GameError>;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Game {
id: GameId,
started_at: DateTime<Utc>,
history: GameStory,
state: GameState,
}
@ -44,12 +48,36 @@ impl Game {
pub fn new(players: &[Identification], settings: GameSettings) -> Result<Self> {
let village = Village::new(players, settings)?;
Ok(Self {
id: GameId::new(),
started_at: Utc::now(),
history: GameStory::new(village.clone()),
state: GameState::Night {
night: Night::new(village)?,
},
})
}
#[cfg(feature = "server")]
pub const fn new_from_parts(
id: GameId,
started_at: DateTime<Utc>,
history: GameStory,
state: GameState,
) -> Self {
Self {
id,
started_at,
history,
state,
}
}
pub const fn game_id(&self) -> GameId {
self.id
}
pub const fn started_at(&self) -> DateTime<Utc> {
self.started_at
}
pub const fn village(&self) -> &Village {
match &self.state {

View File

@ -0,0 +1 @@
crate::id_impl!(GameId);

View File

@ -1,13 +1,109 @@
#![allow(clippy::new_without_default)]
#[cfg(feature = "server")]
pub mod cbor;
pub mod character;
pub mod diedto;
pub mod error;
pub mod game;
#[cfg(test)]
mod game_test;
pub mod id;
pub mod limited;
pub mod message;
pub mod modifier;
pub mod nonzero;
pub mod player;
pub mod role;
pub mod token;
pub mod user;
#[macro_export]
macro_rules! id_impl {
($name:ident) => {
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct $name(uuid::Uuid);
#[cfg(feature = "server")]
impl sqlx::TypeInfo for $name {
fn is_null(&self) -> bool {
self.0 == uuid::Uuid::nil()
}
fn name(&self) -> &str {
"uuid"
}
}
#[cfg(feature = "server")]
impl sqlx::Type<sqlx::Postgres> for $name {
fn type_info() -> <sqlx::Postgres as sqlx::Database>::TypeInfo {
<uuid::Uuid as sqlx::Type<sqlx::Postgres>>::type_info()
}
}
#[cfg(feature = "server")]
impl<'q> sqlx::Encode<'q, sqlx::Postgres> for $name {
fn encode_by_ref(
&self,
buf: &mut <sqlx::Postgres as sqlx::Database>::ArgumentBuffer<'q>,
) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
self.0.encode_by_ref(buf)
}
}
#[cfg(feature = "server")]
impl<'r> sqlx::Decode<'r, sqlx::Postgres> for $name {
fn decode(
value: <sqlx::Postgres as sqlx::Database>::ValueRef<'r>,
) -> Result<Self, sqlx::error::BoxDynError> {
Ok(Self(uuid::Uuid::decode(value)?))
}
}
impl From<uuid::Uuid> for $name {
fn from(value: uuid::Uuid) -> Self {
Self::from_uuid(value)
}
}
impl From<$name> for uuid::Uuid {
fn from(value: $name) -> Self {
value.into_uuid()
}
}
impl Default for $name {
fn default() -> Self {
Self::new()
}
}
impl $name {
pub fn new() -> Self {
Self(uuid::Uuid::new_v4())
}
pub const fn from_uuid(uuid: uuid::Uuid) -> Self {
Self(uuid)
}
pub const fn into_uuid(self) -> uuid::Uuid {
self.0
}
}
impl core::fmt::Display for $name {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl core::str::FromStr for $name {
type Err = uuid::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(Self(uuid::Uuid::from_str(s)?))
}
}
};
}

View File

@ -0,0 +1,129 @@
use core::{
fmt::Display,
ops::{Deref, RangeInclusive},
};
use serde::{Deserialize, Deserializer, Serialize};
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct FixedLenString<const LEN: usize>(String);
impl<const LEN: usize> Display for FixedLenString<LEN> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl<const LEN: usize> FixedLenString<LEN> {
pub fn new(s: String) -> Option<Self> {
(s.chars().take(LEN + 1).count() == LEN).then_some(Self(s))
}
}
impl<const LEN: usize> Deref for FixedLenString<LEN> {
type Target = String;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<'de, const LEN: usize> Deserialize<'de> for FixedLenString<LEN> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ExpectedLen(usize);
impl serde::de::Expected for ExpectedLen {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "a string exactly {} characters long", self.0)
}
}
<String as Deserialize>::deserialize(deserializer).and_then(|s| {
let char_count = s.chars().take(LEN.saturating_add(1)).count();
if char_count != LEN {
Err(serde::de::Error::invalid_length(
char_count,
&ExpectedLen(LEN),
))
} else {
Ok(Self(s))
}
})
}
}
impl<const LEN: usize> Serialize for FixedLenString<LEN> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(self.0.as_str())
}
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ClampedString<const MIN: usize, const MAX: usize>(String);
impl<const MIN: usize, const MAX: usize> ClampedString<MIN, MAX> {
pub fn new(s: String) -> Result<Self, RangeInclusive<usize>> {
let str_len = s.chars().take(MAX.saturating_add(1)).count();
(str_len >= MIN && str_len <= MAX)
.then_some(Self(s))
.ok_or(MIN..=MAX)
}
pub fn into_inner(self) -> String {
self.0
}
}
impl<const MIN: usize, const MAX: usize> Display for ClampedString<MIN, MAX> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl<const MIN: usize, const MAX: usize> Deref for ClampedString<MIN, MAX> {
type Target = String;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<'de, const MIN: usize, const MAX: usize> Deserialize<'de> for ClampedString<MIN, MAX> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ExpectedLen(usize, usize);
impl serde::de::Expected for ExpectedLen {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"a string between {} and {} characters long",
self.0, self.1
)
}
}
<String as Deserialize>::deserialize(deserializer).and_then(|s| {
let char_count = s.chars().take(MAX.saturating_add(1)).count();
if char_count < MIN || char_count > MAX {
Err(serde::de::Error::invalid_length(
char_count,
&ExpectedLen(MIN, MAX),
))
} else {
Ok(Self(s))
}
})
}
}
impl<const MIN: usize, const MAX: usize> Serialize for ClampedString<MIN, MAX> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(self.0.as_str())
}
}

View File

@ -17,6 +17,12 @@ impl PlayerId {
pub const fn from_u128(v: u128) -> Self {
Self(uuid::Uuid::from_u128(v))
}
pub const fn from_uuid(v: uuid::Uuid) -> Self {
Self(v)
}
pub const fn into_uuid(self) -> uuid::Uuid {
self.0
}
}
impl Display for PlayerId {

View File

@ -0,0 +1,40 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use crate::{limited::FixedLenString, user::Username};
pub const TOKEN_LEN: usize = 0x20;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Token {
pub token: FixedLenString<TOKEN_LEN>,
pub username: Username,
pub created_at: DateTime<Utc>,
pub expires_at: DateTime<Utc>,
}
impl Token {
pub fn login_token(&self) -> TokenLogin {
TokenLogin(self.token.clone())
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TokenLogin(pub FixedLenString<TOKEN_LEN>);
#[cfg(feature = "server")]
impl axum_extra::headers::authorization::Credentials for TokenLogin {
const SCHEME: &'static str = "Bearer";
fn decode(value: &axum::http::HeaderValue) -> Option<Self> {
value
.to_str()
.ok()
.and_then(|v| FixedLenString::new(v.strip_prefix("Bearer ").unwrap_or(v).to_string()))
.map(Self)
}
fn encode(&self) -> axum::http::HeaderValue {
axum::http::HeaderValue::from_str(self.0.as_str()).expect("bearer token encode")
}
}

View File

@ -0,0 +1,14 @@
use serde::{Deserialize, Serialize};
use crate::limited::ClampedString;
pub type Username = ClampedString<1, 0x40>;
pub type Password = ClampedString<6, 0x100>;
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct UserLogin {
pub username: Username,
pub password: Password,
}
crate::id_impl!(UserId);

View File

@ -11,12 +11,12 @@ pretty_env_logger = { version = "0.5" }
# env_logger = { version = "0.11" }
futures = "0.3.31"
anyhow = { version = "1" }
werewolves-proto = { path = "../werewolves-proto" }
werewolves-proto = { path = "../werewolves-proto", features = ["server"] }
werewolves-macros = { path = "../werewolves-macros" }
mime-sniffer = { version = "0.1" }
chrono = { version = "0.4" }
atom_syndication = { version = "0.12" }
axum-extra = { version = "0.10", features = ["typed-header"] }
axum-extra = { version = "0.12", features = ["typed-header"] }
rand = { version = "0.9" }
serde_json = { version = "1.0" }
serde = { version = "1.0", features = ["derive"] }
@ -24,8 +24,27 @@ thiserror = { version = "2" }
ciborium = { version = "0.2", optional = true }
colored = { version = "3.0" }
fast_qr = { version = "0.13", features = ["svg"] }
ron = "0.8"
ron = "0.11"
bytes = { version = "1.10" }
sqlx = { version = "0.8", features = [
"runtime-tokio",
"postgres",
"derive",
"macros",
"uuid",
"chrono",
] }
argon2 = { version = "0.5" }
tower-http = { version = "0.6", features = ["cors"] }
tower = { version = "0.5.2", features = [
"limit",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
"buffer",
"timeout",
] }
[features]

View File

@ -0,0 +1,150 @@
use sqlx::{Pool, Postgres, query};
use werewolves_proto::{
error::DatabaseError,
game::{Game, GameOver, story::GameStory},
id::GameId,
player::PlayerId,
user::UserId,
};
#[derive(Debug, Clone)]
pub struct GameDatabase {
pub(super) pool: Pool<Postgres>,
}
impl GameDatabase {
pub async fn new_game(&self, game: &Game) -> Result<(), DatabaseError> {
let state = serde_json::to_value(game.game_state())?;
let story = serde_json::to_value(&game.story())?;
let mut tx = self.pool.begin().await?;
query!(
r#" insert into
games (id, started_at, state, story)
values
($1, $2, $3, $4)"#,
game.game_id().into_uuid(),
game.started_at(),
state,
story,
)
.execute(&mut *tx)
.await?;
let player_ids = game
.village()
.characters()
.into_iter()
.map(|c| c.player_id().into_uuid())
.collect::<Box<[_]>>();
let user_ids = query!(
r#" select
id, user_id
from
players
where
id = any($1::uuid[])"#,
&*player_ids,
)
.fetch_all(&mut *tx)
.await?
.into_iter()
.map(|r| (PlayerId::from_uuid(r.id), r.user_id.map(UserId::from_uuid)))
.unzip::<PlayerId, Option<UserId>, Vec<PlayerId>, Vec<Option<UserId>>>();
let game_id = game.game_id().into_uuid();
let game_ids = (0..player_ids.len()).map(|_| game_id).collect::<Box<[_]>>();
query!(
r#" with
game_ids as (select row_number() over(), * from unnest($1::uuid[]) as game_id),
player_ids as (select row_number() over(), * from unnest($2::uuid[]) as player_id)
insert into
game_players
select
game_ids.game_id, player_ids.player_id
from
game_ids
join
player_ids on game_ids.row_number = player_ids.row_number
"#,
&*game_ids,
&*player_ids,
)
.execute(&mut *tx)
.await?;
tx.commit().await?;
Ok(())
}
pub async fn update_game(&self, game: &Game) -> Result<(), DatabaseError> {
let state = serde_json::to_value(game.game_state())?;
let story = serde_json::to_value(game.story())?;
query!(
r#" update
games
set
story = $2,
state = $3,
outcome = $4,
updated_at = now()
where
id = $1"#,
game.game_id().into_uuid(),
story,
state,
game.game_over().map(Self::outcome_to_db_outcome) as _
)
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn get_active_game(&self) -> Result<Game, DatabaseError> {
let game = query!(
r#" select
id, state, story, started_at
from
games
where
outcome is null"#
)
.fetch_one(&self.pool)
.await?;
Ok(Game::new_from_parts(
GameId::from_uuid(game.id),
game.started_at,
serde_json::from_value(game.story)?,
serde_json::from_value(game.state)?,
))
}
pub async fn get_game_story(&self, id: GameId) -> Result<GameStory, DatabaseError> {
let game = query!(
r#" select
story
from
games
where
id = $1
and
outcome is not null"#,
id.into_uuid(),
)
.fetch_one(&self.pool)
.await?;
Ok(serde_json::from_value(game.story)?)
}
fn outcome_to_db_outcome(outcome: GameOver) -> &'static str {
match outcome {
GameOver::VillageWins => "village_victory",
GameOver::WolvesWin => "wolves_victory",
}
}
}

View File

@ -0,0 +1,40 @@
pub mod game;
pub mod user;
use sqlx::{Pool, Postgres};
use werewolves_proto::error::DatabaseError;
use crate::db::{game::GameDatabase, user::UserDatabase};
type Result<T> = core::result::Result<T, DatabaseError>;
#[derive(Debug, Clone)]
pub struct Database {
pool: Pool<Postgres>,
}
impl Database {
pub const fn new(pool: Pool<Postgres>) -> Self {
Self { pool }
}
pub fn user(&self) -> UserDatabase {
UserDatabase {
pool: self.pool.clone(),
}
}
pub fn game(&self) -> GameDatabase {
GameDatabase {
pool: self.pool.clone(),
}
}
pub async fn migrate(&self) {
log::info!("running migrations");
sqlx::migrate!("../migrations")
.run(&self.pool)
.await
.expect("run migrations");
log::info!("migrations done");
}
}

View File

@ -0,0 +1,214 @@
use super::Result;
use argon2::{
Argon2, PasswordHash, PasswordVerifier,
password_hash::{PasswordHasher, SaltString, rand_core::OsRng},
};
use chrono::{TimeDelta, Utc};
use rand::distr::SampleString;
use sqlx::{Decode, Encode, Pool, Postgres, prelude::FromRow, query, query_as};
use werewolves_proto::{
error::{DatabaseError, ServerError},
token,
user::UserId,
};
#[derive(Debug, Clone)]
pub struct UserDatabase {
pub(super) pool: Pool<Postgres>,
}
#[derive(Debug, Clone, FromRow)]
pub struct LoginToken {
pub token: String,
pub user_id: UserId,
pub created_at: chrono::DateTime<Utc>,
pub expires_at: chrono::DateTime<Utc>,
}
impl LoginToken {
const TOKEN_LONGEVITY: TimeDelta = TimeDelta::days(30);
pub fn new(user_id: UserId) -> Self {
let created_at = Utc::now();
let expires_at = created_at
.checked_add_signed(Self::TOKEN_LONGEVITY)
.unwrap_or_else(|| {
panic!(
"could not add {} time to {created_at}",
Self::TOKEN_LONGEVITY
)
});
let token = rand::distr::Alphanumeric.sample_string(&mut rand::rng(), token::TOKEN_LEN);
Self {
token,
user_id,
created_at,
expires_at,
}
}
}
pub enum GetUserBy<'a> {
Username(&'a str),
Id(UserId),
}
impl UserDatabase {
pub async fn create(&self, username: &str, password: &str) -> Result<User> {
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
let password_hash = argon2
.hash_password(password.as_bytes(), &salt)?
.to_string();
let now = chrono::offset::Utc::now();
let user = User {
id: UserId::new(),
username: username.into(),
password_hash,
created_at: now,
updated_at: now,
};
query!(
r#"insert into users
(id, username, password_hash, created_at, updated_at)
values
($1, $2, $3, $4, $5)"#,
user.id.into_uuid(),
user.username,
user.password_hash,
user.created_at,
user.updated_at
)
.execute(&self.pool)
.await
.map_err(|err| {
if let sqlx::Error::Database(db_err) = &err
&& let Some(constraint) = db_err.constraint()
&& constraint == "users_username_unique"
{
DatabaseError::UserAlreadyExists
} else {
err.into()
}
})?;
Ok(user)
}
pub async fn get_user(&self, get_user_by: GetUserBy<'_>) -> Result<User> {
Ok(match get_user_by {
GetUserBy::Username(username) => {
query_as!(
User,
r#"
select
id, username, password_hash,
created_at, updated_at
from
users
where
username = $1"#,
username
)
.fetch_one(&self.pool)
.await?
}
GetUserBy::Id(id) => {
query_as!(
User,
r#"
select
id, username, password_hash,
created_at, updated_at
from
users
where
id = $1"#,
id.into_uuid()
)
.fetch_one(&self.pool)
.await?
}
})
}
pub async fn login(
&self,
username: &str,
password: &str,
) -> core::result::Result<LoginToken, ServerError> {
let user = self.get_user(GetUserBy::Username(username)).await?;
let parsed_hash = PasswordHash::new(&user.password_hash).map_err(DatabaseError::from)?;
Argon2::default()
.verify_password(password.as_bytes(), &parsed_hash)
.map_err(|err| match err {
argon2::password_hash::Error::Password => ServerError::InvalidCredentials,
err => ServerError::DatabaseError(err.into()),
})?;
let token = LoginToken::new(user.id);
query!(
r#" insert into login_tokens
(token, user_id, created_at, expires_at)
values
($1, $2, $3, $4)"#,
token.token,
token.user_id.into_uuid(),
token.created_at,
token.expires_at
)
.execute(&self.pool)
.await
.map_err(Into::<DatabaseError>::into)?;
Ok(token)
}
pub async fn check_token(&self, token: &str) -> core::result::Result<User, ServerError> {
let token = query_as!(
LoginToken,
r#" select
token, user_id, created_at, expires_at
from
login_tokens
where
token = $1
and
expires_at > now()
"#,
token
)
.fetch_one(&self.pool)
.await
.map_err(Into::<DatabaseError>::into)
.map_err(|err| match err {
DatabaseError::NotFound => ServerError::ExpiredToken,
_ => err.into(),
})?;
if Utc::now() >= token.expires_at {
return Err(ServerError::ExpiredToken);
}
Ok(self.get_user(GetUserBy::Id(token.user_id)).await?)
}
}
#[derive(Debug, Clone, FromRow, Encode, Decode)]
pub struct User {
pub id: UserId,
pub username: String,
pub password_hash: String,
pub created_at: chrono::DateTime<Utc>,
pub updated_at: chrono::DateTime<Utc>,
}

View File

@ -1,37 +1,56 @@
mod client;
mod communication;
mod connection;
mod db;
mod game;
mod host;
mod lobby;
mod runner;
mod saver;
// mod saver;
use axum::{
Router,
BoxError, Router,
error_handling::HandleErrorLayer,
extract::{Path, State},
http::{Request, StatusCode, header},
response::IntoResponse,
routing::{any, get},
routing::{any, get, post, put},
};
use axum_extra::{
TypedHeader,
headers::{self, Authorization},
};
use axum_extra::headers;
use communication::lobby::LobbyComms;
use connection::JoinedPlayers;
use core::{fmt::Display, net::SocketAddr, str::FromStr};
use core::{fmt::Display, net::SocketAddr, str::FromStr, time::Duration};
use fast_qr::convert::{Builder, Shape, svg::SvgBuilder};
use runner::IdentifiedClientMessage;
use std::{env, io::Write, path::Path};
use sqlx::postgres::PgPoolOptions;
use std::{env, io::Write};
use tokio::sync::{broadcast, mpsc};
use tower::{ServiceBuilder, buffer::BufferLayer, limit::RateLimitLayer};
use werewolves_proto::{
cbor::Cbor,
error::ServerError,
id::GameId,
limited::FixedLenString,
token::{Token, TokenLogin},
user::UserLogin,
};
use crate::{
communication::{Comms, connect::ConnectUpdate, host::HostComms, player::PlayerIdComms},
saver::FileSaver,
db::Database,
// saver::FileSaver,
};
const DEFAULT_PORT: u16 = 8080;
const DEFAULT_HOST: &str = "127.0.0.1";
const DEFAULT_SAVE_DIR: &str = "werewolves-saves/";
const DEFAULT_QRCODE_URL: &str = "https://wolf.emilis.dev/";
const DEFAULT_MAX_PG_CONNECTIONS: u32 = 30;
const DEFAULT_PG_CONN_STRING: &str = "postgres:///ww?host=/var/run/postgresql";
#[tokio::main]
async fn main() {
// pretty_env_logger::init();
@ -106,39 +125,60 @@ async fn main() {
let jp_clone = joined_players.clone();
let path = Path::new(option_env!("SAVE_PATH").unwrap_or(DEFAULT_SAVE_DIR));
let pg_pool = PgPoolOptions::new()
.max_connections(
std::env::var("MAX_DB_CONNECTIONS")
.ok()
.and_then(|val| u32::from_str(&val).ok())
.unwrap_or(DEFAULT_MAX_PG_CONNECTIONS),
)
.connect(
std::env::var("PG_CONN_STRING")
.unwrap_or_else(|_| String::from(DEFAULT_PG_CONN_STRING))
.as_str(),
)
.await
.expect("could not init db");
let db = Database::new(pg_pool);
db.migrate().await;
if let Err(err) = std::fs::create_dir(path)
&& !matches!(err.kind(), std::io::ErrorKind::AlreadyExists)
{
panic!("creating save dir at [{path:?}]: {err}")
}
// Check if we can write to the path
{
let test_file_path = path.join(".test");
if let Err(err) = std::fs::File::create(&test_file_path) {
panic!("can't create files in {path:?}: {err}")
// let saver = FileSaver::new(path.canonicalize().expect("canonicalizing path"));
tokio::spawn({
let db = db.clone();
async move {
crate::runner::run_game(jp_clone, lobby_comms, db).await;
panic!("game over");
}
std::fs::remove_file(&test_file_path).log_err();
}
let saver = FileSaver::new(path.canonicalize().expect("canonicalizing path"));
tokio::spawn(async move {
crate::runner::run_game(jp_clone, lobby_comms, saver).await;
panic!("game over");
});
let state = AppState {
joined_players,
host_recv,
host_send,
send,
db,
};
let app = Router::new()
.route("/connect/client", any(client::handler))
.route("/connect/host", any(host::handler))
.route("/qrcode", get(handle_qr_code))
.route("/s/users", put(signup))
.route("/s/tokens", post(signin))
.route(
"/s/tokens/check",
get(check_token).layer(
ServiceBuilder::new()
.layer(HandleErrorLayer::new(|err: BoxError| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled error: {}", err),
)
}))
.layer(BufferLayer::new(0x100))
.layer(RateLimitLayer::new(100, Duration::from_secs(10))),
),
)
.route("/s/games/{id}", get(get_game_by_id))
.with_state(state)
.fallback(get(handle_http_static));
let listener = tokio::net::TcpListener::bind(listen_addr).await.unwrap();
@ -151,15 +191,61 @@ async fn main() {
.unwrap();
}
async fn get_game_by_id(
State(AppState { db, .. }): State<AppState>,
Path(game_id): Path<GameId>,
) -> Result<impl IntoResponse, ServerError> {
let story = db.game().get_game_story(game_id).await?;
Ok(Cbor(story))
}
async fn check_token(
State(AppState { db, .. }): State<AppState>,
TypedHeader(Authorization(login)): TypedHeader<Authorization<TokenLogin>>,
) -> Result<impl IntoResponse, ServerError> {
db.user().check_token(&login.0).await?;
Ok(StatusCode::OK)
}
async fn signin(
State(AppState { db, .. }): State<AppState>,
Cbor(UserLogin { username, password }): Cbor<UserLogin>,
) -> Result<impl IntoResponse, ServerError> {
let token = db.user().login(&username, &password).await?;
Ok(Cbor(Token {
username,
token: FixedLenString::new(token.token.clone()).ok_or_else(|| {
ServerError::InternalServerError(format!(
"could not get a fixed len string for token [{}]",
token.token
))
})?,
created_at: token.created_at,
expires_at: token.expires_at,
})
.into_response())
}
async fn signup(
State(AppState { db, .. }): State<AppState>,
Cbor(UserLogin { username, password }): Cbor<UserLogin>,
) -> Result<impl IntoResponse, ServerError> {
db.user().create(&username, &password).await?;
Ok(StatusCode::CREATED)
}
struct AppState {
joined_players: JoinedPlayers,
send: broadcast::Sender<IdentifiedClientMessage>,
host_send: tokio::sync::mpsc::Sender<werewolves_proto::message::host::HostMessage>,
host_recv: broadcast::Receiver<werewolves_proto::message::host::ServerToHostMessage>,
db: Database,
}
impl Clone for AppState {
fn clone(&self) -> Self {
Self {
db: self.db.clone(),
joined_players: self.joined_players.clone(),
send: self.send.clone(),
host_send: self.host_send.clone(),

View File

@ -2,7 +2,11 @@ use core::{num::NonZeroU8, time::Duration};
use std::sync::Arc;
use werewolves_proto::{
message::{ClientMessage, Identification, host::HostMessage},
error::{GameError, ServerError},
message::{
ClientMessage, Identification,
host::{HostMessage, ServerToHostMessage},
},
player::PlayerId,
};
@ -10,9 +14,9 @@ use crate::{
LogError,
communication::lobby::LobbyComms,
connection::JoinedPlayers,
db::Database,
game::{GameEnd, GameRunner},
lobby::{Lobby, LobbyPlayers},
saver::Saver,
};
#[derive(Debug, Clone, PartialEq)]
@ -21,7 +25,7 @@ pub struct IdentifiedClientMessage {
pub message: ClientMessage,
}
pub async fn run_game(joined_players: JoinedPlayers, comms: LobbyComms, mut saver: impl Saver) {
pub async fn run_game(joined_players: JoinedPlayers, comms: LobbyComms, db: Database) {
let mut lobby = Lobby::new(joined_players, comms);
if let Some(dummies) = option_env!("DUMMY_PLAYERS").and_then(|p| p.parse::<NonZeroU8>().ok()) {
log::info!("creating {dummies} dummy players");
@ -32,36 +36,37 @@ pub async fn run_game(joined_players: JoinedPlayers, comms: LobbyComms, mut save
loop {
match &mut state {
RunningState::Lobby(lobby) => {
if let Some(game) = lobby.next().await {
if let Some(mut game) = lobby.next().await {
if let Err(err) = db.game().new_game(game.proto_game()).await {
log::error!("saving new game: {err}; reverting to lobby");
game.comms()
.host()
.send(ServerToHostMessage::Error(GameError::ServerError(
ServerError::DatabaseError(err),
)))
.log_err();
state = RunningState::Lobby(game.into_lobby());
continue;
}
state = RunningState::Game(game)
}
}
RunningState::Game(game) => {
if let Some(result) = game.next().await {
match saver.save(game.proto_game()) {
Ok(path) => {
log::info!("saved game to {path}");
}
Err(err) => {
log::error!("saving game: {err}");
let game_clone = game.proto_game().clone();
let mut saver_clone = saver.clone();
tokio::spawn(async move {
let started = chrono::Utc::now();
loop {
tokio::time::sleep(Duration::from_secs(30)).await;
match saver_clone.save(&game_clone) {
Ok(path) => {
log::info!("saved game from {started} to {path}");
return;
}
Err(err) => {
log::error!("saving game from {started}: {err}")
}
}
if let Err(err) = db.game().update_game(game.proto_game()).await {
log::error!("saving game ({}): {err}", game.proto_game().game_id());
let game_clone = game.proto_game().clone();
let db_clone = db.game();
tokio::spawn(async move {
let started = chrono::Utc::now();
loop {
tokio::time::sleep(Duration::from_secs(30)).await;
if let Err(err) = db_clone.update_game(&game_clone).await {
log::error!("saving game from {started}: {err}")
}
});
}
}
});
}
state = match state {
RunningState::Game(game) => {
@ -69,6 +74,10 @@ pub async fn run_game(joined_players: JoinedPlayers, comms: LobbyComms, mut save
}
_ => unsafe { core::hint::unreachable_unchecked() },
};
} else {
if let Err(err) = db.game().update_game(game.proto_game()).await {
log::error!("updating game ({}): {err}", game.proto_game().game_id());
}
}
}
RunningState::GameOver(end) => {

View File

@ -1,7 +1,8 @@
[build]
target = "index.html" # The index HTML file to drive the bundling process.
html_output = "index.html" # The name of the output HTML file.
release = true # Build in release mode.
release = false # Build in release mode.
# release = true # Build in release mode.
dist = "dist" # The output dir for all final assets.
public_url = "/" # The public URL from which assets are to be served.
filehash = true # Whether to include hash values in the output file names.
@ -9,6 +10,6 @@ inject_scripts = true # Whether to inject scripts (and module preloads) in
offline = false # Run without network access
frozen = false # Require Cargo.lock and cache are up to date
locked = false # Require Cargo.lock is up to date
# minify = "on_release" # Control minification: can be one of: never, on_release, always
minify = "always" # Control minification: can be one of: never, on_release, always
minify = "on_release" # Control minification: can be one of: never, on_release, always
# minify = "always" # Control minification: can be one of: never, on_release, always
no_sri = false # Allow disabling sub-resource integrity (SRI)