fix: no longer self-wake future checking for connects

This commit is contained in:
emilis 2025-12-09 20:50:08 +00:00
parent 9b989389b5
commit f6cf92bc40
No known key found for this signature in database
11 changed files with 350 additions and 593 deletions

667
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -5,18 +5,16 @@ edition = "2024"
[dependencies] [dependencies]
axum = { version = "0.8", features = ["ws"] } axum = { version = "0.8", features = ["ws"] }
tokio = { version = "1.47", features = ["full"] } tokio = { version = "1.48", features = ["full"] }
log = { version = "0.4" } log = { version = "0.4" }
pretty_env_logger = { version = "0.5" } pretty_env_logger = { version = "0.5" }
# env_logger = { version = "0.11" }
futures = "0.3.31" futures = "0.3.31"
anyhow = { version = "1" } anyhow = { version = "1" }
werewolves-proto = { path = "../werewolves-proto" } werewolves-proto = { path = "../werewolves-proto" }
werewolves-macros = { path = "../werewolves-macros" } werewolves-macros = { path = "../werewolves-macros" }
mime-sniffer = { version = "0.1" } mime-sniffer = { version = "0.1" }
chrono = { version = "0.4" } chrono = { version = "0.4" }
atom_syndication = { version = "0.12" } axum-extra = { version = "0.12", features = ["typed-header"] }
axum-extra = { version = "0.10", features = ["typed-header"] }
rand = { version = "0.9" } rand = { version = "0.9" }
serde_json = { version = "1.0" } serde_json = { version = "1.0" }
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }

View File

@ -15,9 +15,9 @@
use core::{net::SocketAddr, time::Duration}; use core::{net::SocketAddr, time::Duration};
use crate::{ use crate::{
AppState, XForwardedFor, AppState, LogError, XForwardedFor,
connection::{ConnectionId, JoinedPlayer}, connection::{ConnectionId, JoinedPlayer},
runner::IdentifiedClientMessage, runner::{ClientUpdate, IdentifiedClientMessage},
}; };
use axum::{ use axum::{
extract::{ extract::{
@ -73,20 +73,33 @@ pub async fn handler(
) )
.await .await
}; };
state
.send
.send(IdentifiedClientMessage {
identity: ident.clone(),
update: ClientUpdate::ConnectStateUpdate,
})
.log_debug();
Client::new( Client::new(
ident.clone(), ident.clone(),
connection_id.clone(), connection_id.clone(),
socket, socket,
who.to_string(), who.to_string(),
state.send, state.send.clone(),
recv, recv,
) )
.run() .run()
.await; .await;
// log::debug!("ending connection with {who}");
player_list.disconnect(&connection_id).await; player_list.disconnect(&connection_id).await;
state
.send
.send(IdentifiedClientMessage {
identity: ident.clone(),
update: ClientUpdate::ConnectStateUpdate,
})
.log_debug();
}) })
} }
@ -220,7 +233,7 @@ impl Client {
} }
self.sender.send(IdentifiedClientMessage { self.sender.send(IdentifiedClientMessage {
message, update: ClientUpdate::Message(message),
identity: self.ident.clone(), identity: self.ident.clone(),
})?; })?;

View File

@ -1,74 +0,0 @@
// Copyright (C) 2025 Emilis Bliūdžius
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
use core::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::Mutex;
use werewolves_proto::player::PlayerId;
#[derive(Debug, Clone)]
pub struct ConnectUpdate {
updated: Arc<AtomicBool>,
connected: Arc<Mutex<Vec<PlayerId>>>,
}
impl ConnectUpdate {
pub fn new() -> Self {
Self {
updated: Arc::new(AtomicBool::new(false)),
connected: Arc::new(Mutex::new(Vec::new())),
}
}
pub async fn connect(&self, pid: PlayerId) {
let mut connected = self.connected.lock().await;
if connected.iter().any(|c| c == &pid) {
return;
}
connected.push(pid);
self.updated.store(true, Ordering::SeqCst);
}
pub async fn disconnect(&self, pid: PlayerId) {
let mut connected = self.connected.lock().await;
if let Some(idx) = connected
.iter()
.enumerate()
.find_map(|(idx, c)| (c == &pid).then_some(idx))
{
connected.swap_remove(idx);
self.updated.store(true, Ordering::SeqCst);
}
}
}
impl Future for ConnectUpdate {
type Output = Arc<[PlayerId]>;
fn poll(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
if self.updated.load(Ordering::SeqCst)
&& let Ok(connected) = self.connected.try_lock()
{
self.updated.store(false, Ordering::SeqCst);
std::task::Poll::Ready(connected.clone().into())
} else {
cx.waker().wake_by_ref();
std::task::Poll::Pending
}
}
}

View File

@ -14,29 +14,21 @@
// along with this program. If not, see <https://www.gnu.org/licenses/>. // along with this program. If not, see <https://www.gnu.org/licenses/>.
use werewolves_proto::error::GameError; use werewolves_proto::error::GameError;
use crate::{ use crate::{communication::Comms, runner::Message};
communication::{Comms, connect::ConnectUpdate},
runner::Message,
};
use super::{HostComms, player::PlayerIdComms}; use super::{HostComms, player::PlayerIdComms};
pub struct LobbyComms { pub struct LobbyComms {
comms: Comms, comms: Comms,
// TODO: move this to not use a receiver
connect_recv: ConnectUpdate,
} }
impl LobbyComms { impl LobbyComms {
pub fn new(comms: Comms, connect_recv: ConnectUpdate) -> Self { pub fn new(comms: Comms) -> Self {
Self { Self { comms }
comms,
connect_recv,
}
} }
pub fn into_inner(self) -> (Comms, ConnectUpdate) { pub fn into_inner(self) -> Comms {
(self.comms, self.connect_recv) self.comms
} }
#[allow(unused)] #[allow(unused)]
@ -49,17 +41,12 @@ impl LobbyComms {
} }
pub async fn next_message(&mut self) -> Result<Message, GameError> { pub async fn next_message(&mut self) -> Result<Message, GameError> {
tokio::select! { match self.comms.message().await {
r = self.comms.message() => { Ok(val) => Ok(val),
match r { Err(GameError::GenericError(err)) => {
Ok(val) => Ok(val), Err(GameError::GenericError(format!("comms message: {err}")))
Err(GameError::GenericError(err)) => Err(GameError::GenericError(format!("comms message: {err}"))), }
Err(err) => Err(err), Err(err) => Err(err),
}
}
r = self.connect_recv.clone() => {
Ok(Message::ConnectedList(r))
}
} }
} }
} }

View File

@ -19,7 +19,6 @@ use crate::{
runner::Message, runner::Message,
}; };
pub mod connect;
pub mod host; pub mod host;
pub mod lobby; pub mod lobby;
pub mod player; pub mod player;

View File

@ -27,7 +27,7 @@ use werewolves_proto::{
player::PlayerId, player::PlayerId,
}; };
use crate::{LogError, communication::connect::ConnectUpdate}; use crate::LogError;
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ConnectionId(PlayerId, Instant); pub struct ConnectionId(PlayerId, Instant);
@ -80,13 +80,11 @@ impl JoinedPlayer {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct JoinedPlayers { pub struct JoinedPlayers {
players: Arc<Mutex<HashMap<PlayerId, JoinedPlayer>>>, players: Arc<Mutex<HashMap<PlayerId, JoinedPlayer>>>,
connect_state: ConnectUpdate,
} }
impl JoinedPlayers { impl JoinedPlayers {
pub fn new(connect_state: ConnectUpdate) -> Self { pub fn new() -> Self {
Self { Self {
connect_state,
players: Arc::new(Mutex::new(HashMap::new())), players: Arc::new(Mutex::new(HashMap::new())),
} }
} }
@ -172,8 +170,6 @@ impl JoinedPlayers {
pub async fn disconnect(&self, connection: &ConnectionId) -> Option<JoinedPlayer> { pub async fn disconnect(&self, connection: &ConnectionId) -> Option<JoinedPlayer> {
let mut map = self.players.lock().await; let mut map = self.players.lock().await;
self.connect_state.disconnect(connection.0).await;
if map if map
.get(connection.player_id()) .get(connection.player_id())
.map(|p| p.active_connection == *connection && !p.in_game) .map(|p| p.active_connection == *connection && !p.in_game)
@ -206,8 +202,6 @@ impl JoinedPlayers {
old.receiver old.receiver
} else { } else {
self.connect_state.connect(player_id).await;
unsafe { map.get(&player_id).unwrap_unchecked() } unsafe { map.get(&player_id).unwrap_unchecked() }
.receiver .receiver
.resubscribe() .resubscribe()

View File

@ -13,14 +13,13 @@
// You should have received a copy of the GNU Affero General Public License // You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>. // along with this program. If not, see <https://www.gnu.org/licenses/>.
use core::{ops::Not, time::Duration}; use core::{ops::Not, time::Duration};
use std::sync::Arc;
use crate::{ use crate::{
LogError, LogError,
communication::{Comms, connect::ConnectUpdate, lobby::LobbyComms}, communication::{Comms, lobby::LobbyComms},
connection::JoinedPlayers, connection::JoinedPlayers,
lobby::{Lobby, LobbyPlayers}, lobby::{Lobby, LobbyPlayers},
runner::{IdentifiedClientMessage, Message}, runner::{ClientUpdate, IdentifiedClientMessage, Message},
}; };
use tokio::time::Instant; use tokio::time::Instant;
use werewolves_proto::{ use werewolves_proto::{
@ -39,7 +38,6 @@ type Result<T> = core::result::Result<T, GameError>;
pub struct GameRunner { pub struct GameRunner {
game: Game, game: Game,
comms: Comms, comms: Comms,
connect_recv: ConnectUpdate,
player_sender: LobbyPlayers, player_sender: LobbyPlayers,
roles_revealed: bool, roles_revealed: bool,
joined_players: JoinedPlayers, joined_players: JoinedPlayers,
@ -50,13 +48,11 @@ impl GameRunner {
game: Game, game: Game,
comms: Comms, comms: Comms,
player_sender: LobbyPlayers, player_sender: LobbyPlayers,
connect_recv: ConnectUpdate,
joined_players: JoinedPlayers, joined_players: JoinedPlayers,
) -> Self { ) -> Self {
Self { Self {
game, game,
comms, comms,
connect_recv,
player_sender, player_sender,
joined_players, joined_players,
roles_revealed: false, roles_revealed: false,
@ -68,10 +64,7 @@ impl GameRunner {
} }
pub fn into_lobby(self) -> Lobby { pub fn into_lobby(self) -> Lobby {
let mut lobby = Lobby::new( let mut lobby = Lobby::new(self.joined_players, LobbyComms::new(self.comms));
self.joined_players,
LobbyComms::new(self.comms, self.connect_recv),
);
lobby.set_settings(self.game.village().settings()); lobby.set_settings(self.game.village().settings());
lobby.set_players_in_lobby(self.player_sender); lobby.set_players_in_lobby(self.player_sender);
lobby lobby
@ -155,7 +148,6 @@ impl GameRunner {
}; };
let mut last_err_log = tokio::time::Instant::now() - tokio::time::Duration::from_secs(60); let mut last_err_log = tokio::time::Instant::now() - tokio::time::Duration::from_secs(60);
let mut connect_list: Arc<[PlayerId]> = Arc::new([]);
while acks.iter().any(|(_, ackd)| !*ackd) { while acks.iter().any(|(_, ackd)| !*ackd) {
const PING_TIME: Duration = Duration::from_secs(1); const PING_TIME: Duration = Duration::from_secs(1);
let sleep_fut = tokio::time::sleep(PING_TIME); let sleep_fut = tokio::time::sleep(PING_TIME);
@ -198,7 +190,7 @@ impl GameRunner {
player_id, player_id,
public: _, public: _,
}, },
message: ClientMessage::GetState, update: ClientUpdate::Message(ClientMessage::GetState),
}) => { }) => {
let Some(sender) = self.joined_players.get_sender(player_id).await else { let Some(sender) = self.joined_players.get_sender(player_id).await else {
continue; continue;
@ -232,7 +224,7 @@ impl GameRunner {
player_id, player_id,
public: _, public: _,
}, },
message: ClientMessage::RoleAck, update: ClientUpdate::Message(ClientMessage::RoleAck),
}) => { }) => {
if let Some((_, ackd)) = if let Some((_, ackd)) =
acks.iter_mut().find(|(t, _)| t.player_id() == player_id) acks.iter_mut().find(|(t, _)| t.player_id() == player_id)
@ -248,21 +240,9 @@ impl GameRunner {
} }
} }
Message::Client(IdentifiedClientMessage { Message::Client(IdentifiedClientMessage {
identity: identity: Identification { player_id, .. },
Identification { ..
player_id,
public: _,
},
message: _,
}) => (notify_of_role)(player_id, self.game.village(), &self.joined_players).await, }) => (notify_of_role)(player_id, self.game.village(), &self.joined_players).await,
Message::ConnectedList(c) => {
let newly_connected = c.iter().filter(|c| connect_list.contains(*c));
for connected in newly_connected {
(notify_of_role)(*connected, self.game.village(), &self.joined_players)
.await
}
connect_list = c;
}
} }
} }
@ -277,7 +257,10 @@ impl GameRunner {
pub async fn next(&mut self) -> Option<GameOver> { pub async fn next(&mut self) -> Option<GameOver> {
let msg = match self.comms.message().await { let msg = match self.comms.message().await {
Ok(Message::ConnectedList(_)) => return None, Ok(Message::Client(IdentifiedClientMessage {
update: ClientUpdate::ConnectStateUpdate,
..
})) => return None,
Ok(Message::Client(IdentifiedClientMessage { Ok(Message::Client(IdentifiedClientMessage {
identity: Identification { player_id, .. }, identity: Identification { player_id, .. },
.. ..
@ -459,6 +442,10 @@ impl GameEnd {
GameError::InvalidMessageForGameState, GameError::InvalidMessageForGameState,
)) ))
.log_debug(), .log_debug(),
Message::Client(IdentifiedClientMessage {
update: ClientUpdate::ConnectStateUpdate,
..
}) => {}
Message::Client(IdentifiedClientMessage { identity, .. }) => { Message::Client(IdentifiedClientMessage { identity, .. }) => {
let story = self.game().ok()?.game.story(); let story = self.game().ok()?.game.story();
return Some(ProcessOutcome::SendPlayer( return Some(ProcessOutcome::SendPlayer(
@ -466,7 +453,6 @@ impl GameEnd {
ServerMessage::Story(story), ServerMessage::Story(story),
)); ));
} }
Message::ConnectedList(_) => {}
} }
None None
} }

View File

@ -33,7 +33,7 @@ use crate::{
communication::lobby::LobbyComms, communication::lobby::LobbyComms,
connection::JoinedPlayers, connection::JoinedPlayers,
game::GameRunner, game::GameRunner,
runner::{IdentifiedClientMessage, Message}, runner::{ClientUpdate, IdentifiedClientMessage, Message},
}; };
pub struct Lobby { pub struct Lobby {
@ -147,12 +147,8 @@ impl Lobby {
.log_warn(), .log_warn(),
Err(( Err((
Message::Client(IdentifiedClientMessage { Message::Client(IdentifiedClientMessage {
identity: identity: Identification { player_id, .. },
Identification { ..
player_id,
public: _,
},
message: _,
}), }),
GameError::InvalidMessageForGameState, GameError::InvalidMessageForGameState,
)) => { )) => {
@ -160,10 +156,17 @@ impl Lobby {
.players_in_lobby .players_in_lobby
.send_if_present(player_id, ServerMessage::InvalidMessageForGameState); .send_if_present(player_id, ServerMessage::InvalidMessageForGameState);
} }
Err((
Message::Client(IdentifiedClientMessage {
update: ClientUpdate::ConnectStateUpdate,
..
}),
_,
)) => {}
Err(( Err((
Message::Client(IdentifiedClientMessage { Message::Client(IdentifiedClientMessage {
identity: Identification { player_id, public }, identity: Identification { player_id, public },
message: _, ..
}), }),
err, err,
)) => { )) => {
@ -172,7 +175,6 @@ impl Lobby {
.players_in_lobby .players_in_lobby
.send_if_present(player_id, ServerMessage::Reset); .send_if_present(player_id, ServerMessage::Reset);
} }
Err((Message::ConnectedList(_), _)) => {}
} }
None None
} }
@ -249,18 +251,17 @@ impl Lobby {
let game = Game::new(&playing_players, self.settings.clone())?; let game = Game::new(&playing_players, self.settings.clone())?;
assert_eq!(game.village().characters().len(), playing_players.len()); assert_eq!(game.village().characters().len(), playing_players.len());
let (comms, recv) = self.comms.take().unwrap().into_inner(); let comms = self.comms.take().unwrap().into_inner();
return Ok(Some(GameRunner::new( return Ok(Some(GameRunner::new(
game, game,
comms, comms,
self.players_in_lobby.clone(), self.players_in_lobby.clone(),
recv,
self.joined_players.clone(), self.joined_players.clone(),
))); )));
} }
Message::Client(IdentifiedClientMessage { Message::Client(IdentifiedClientMessage {
identity, identity,
message: ClientMessage::Hello, update: ClientUpdate::Message(ClientMessage::Hello),
}) => { }) => {
if self if self
.players_in_lobby .players_in_lobby
@ -278,12 +279,8 @@ impl Lobby {
} }
Message::Host(HostMessage::Lobby(HostLobbyMessage::Kick(player_id))) Message::Host(HostMessage::Lobby(HostLobbyMessage::Kick(player_id)))
| Message::Client(IdentifiedClientMessage { | Message::Client(IdentifiedClientMessage {
identity: identity: Identification { player_id, .. },
Identification { update: ClientUpdate::Message(ClientMessage::Goodbye),
player_id,
public: _,
},
message: ClientMessage::Goodbye,
}) => { }) => {
if let Some(remove_idx) = self if let Some(remove_idx) = self
.players_in_lobby .players_in_lobby
@ -297,12 +294,8 @@ impl Lobby {
} }
} }
Message::Client(IdentifiedClientMessage { Message::Client(IdentifiedClientMessage {
identity: identity: Identification { player_id, .. },
Identification { update: ClientUpdate::Message(ClientMessage::GetState),
player_id,
public: _,
},
message: ClientMessage::GetState,
}) => { }) => {
let msg = ServerMessage::LobbyInfo { let msg = ServerMessage::LobbyInfo {
joined: self joined: self
@ -320,12 +313,12 @@ impl Lobby {
} }
} }
Message::Client(IdentifiedClientMessage { Message::Client(IdentifiedClientMessage {
identity: _, update: ClientUpdate::Message(ClientMessage::RoleAck),
message: ClientMessage::RoleAck, ..
}) => return Err(GameError::InvalidMessageForGameState), }) => return Err(GameError::InvalidMessageForGameState),
Message::Client(IdentifiedClientMessage { Message::Client(IdentifiedClientMessage {
identity: Identification { player_id, public }, identity: Identification { player_id, public },
message: ClientMessage::UpdateSelf(_), update: ClientUpdate::Message(ClientMessage::UpdateSelf(_)),
}) => { }) => {
self.joined_players self.joined_players
.update(&player_id, |p| { .update(&player_id, |p| {
@ -344,7 +337,10 @@ impl Lobby {
self.send_lobby_info_to_clients().await; self.send_lobby_info_to_clients().await;
self.send_lobby_info_to_host().await.log_debug(); self.send_lobby_info_to_host().await.log_debug();
} }
Message::ConnectedList(_) => self.send_lobby_info_to_host().await?, Message::Client(IdentifiedClientMessage {
update: ClientUpdate::ConnectStateUpdate,
..
}) => self.send_lobby_info_to_host().await?,
Message::Host(HostMessage::Echo(msg)) => { Message::Host(HostMessage::Echo(msg)) => {
self.comms()?.host().send(msg).log_warn(); self.comms()?.host().send(msg).log_warn();
} }

View File

@ -37,7 +37,7 @@ use std::{env, io::Write, path::Path};
use tokio::sync::{broadcast, mpsc}; use tokio::sync::{broadcast, mpsc};
use crate::{ use crate::{
communication::{Comms, connect::ConnectUpdate, host::HostComms, player::PlayerIdComms}, communication::{Comms, host::HostComms, player::PlayerIdComms},
saver::FileSaver, saver::FileSaver,
}; };
@ -105,18 +105,14 @@ async fn main() {
let listen_addr = let listen_addr =
SocketAddr::from_str(format!("{host}:{port}").as_str()).expect("invalid host/port"); SocketAddr::from_str(format!("{host}:{port}").as_str()).expect("invalid host/port");
let (send, recv) = broadcast::channel(100); let (send, recv) = broadcast::channel(1000);
let (server_send, host_recv) = broadcast::channel(100); let (server_send, host_recv) = broadcast::channel(1000);
let (host_send, server_recv) = mpsc::channel(100); let (host_send, server_recv) = mpsc::channel(1000);
let conn_update = ConnectUpdate::new(); let joined_players = JoinedPlayers::new();
let joined_players = JoinedPlayers::new(conn_update.clone()); let lobby_comms = LobbyComms::new(Comms::new(
let lobby_comms = LobbyComms::new( HostComms::new(server_send, server_recv),
Comms::new( PlayerIdComms::new(recv),
HostComms::new(server_send, server_recv), ));
PlayerIdComms::new(recv),
),
conn_update,
);
let jp_clone = joined_players.clone(); let jp_clone = joined_players.clone();

View File

@ -29,10 +29,16 @@ use crate::{
saver::Saver, saver::Saver,
}; };
#[derive(Debug, Clone, PartialEq)]
pub enum ClientUpdate {
ConnectStateUpdate,
Message(ClientMessage),
}
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub struct IdentifiedClientMessage { pub struct IdentifiedClientMessage {
pub identity: Identification, pub identity: Identification,
pub message: ClientMessage, pub update: ClientUpdate,
} }
pub async fn run_game(joined_players: JoinedPlayers, comms: LobbyComms, mut saver: impl Saver) { pub async fn run_game(joined_players: JoinedPlayers, comms: LobbyComms, mut saver: impl Saver) {
@ -100,7 +106,6 @@ pub async fn run_game(joined_players: JoinedPlayers, comms: LobbyComms, mut save
pub enum Message { pub enum Message {
Host(HostMessage), Host(HostMessage),
Client(IdentifiedClientMessage), Client(IdentifiedClientMessage),
ConnectedList(Arc<[PlayerId]>),
} }
pub enum RunningState { pub enum RunningState {