fix: no longer self-wake future checking for connects

This commit is contained in:
emilis 2025-12-09 20:50:08 +00:00
parent 9b989389b5
commit f6cf92bc40
No known key found for this signature in database
11 changed files with 350 additions and 593 deletions

667
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -5,18 +5,16 @@ edition = "2024"
[dependencies]
axum = { version = "0.8", features = ["ws"] }
tokio = { version = "1.47", features = ["full"] }
tokio = { version = "1.48", features = ["full"] }
log = { version = "0.4" }
pretty_env_logger = { version = "0.5" }
# env_logger = { version = "0.11" }
futures = "0.3.31"
anyhow = { version = "1" }
werewolves-proto = { path = "../werewolves-proto" }
werewolves-macros = { path = "../werewolves-macros" }
mime-sniffer = { version = "0.1" }
chrono = { version = "0.4" }
atom_syndication = { version = "0.12" }
axum-extra = { version = "0.10", features = ["typed-header"] }
axum-extra = { version = "0.12", features = ["typed-header"] }
rand = { version = "0.9" }
serde_json = { version = "1.0" }
serde = { version = "1.0", features = ["derive"] }

View File

@ -15,9 +15,9 @@
use core::{net::SocketAddr, time::Duration};
use crate::{
AppState, XForwardedFor,
AppState, LogError, XForwardedFor,
connection::{ConnectionId, JoinedPlayer},
runner::IdentifiedClientMessage,
runner::{ClientUpdate, IdentifiedClientMessage},
};
use axum::{
extract::{
@ -73,20 +73,33 @@ pub async fn handler(
)
.await
};
state
.send
.send(IdentifiedClientMessage {
identity: ident.clone(),
update: ClientUpdate::ConnectStateUpdate,
})
.log_debug();
Client::new(
ident.clone(),
connection_id.clone(),
socket,
who.to_string(),
state.send,
state.send.clone(),
recv,
)
.run()
.await;
// log::debug!("ending connection with {who}");
player_list.disconnect(&connection_id).await;
state
.send
.send(IdentifiedClientMessage {
identity: ident.clone(),
update: ClientUpdate::ConnectStateUpdate,
})
.log_debug();
})
}
@ -220,7 +233,7 @@ impl Client {
}
self.sender.send(IdentifiedClientMessage {
message,
update: ClientUpdate::Message(message),
identity: self.ident.clone(),
})?;

View File

@ -1,74 +0,0 @@
// Copyright (C) 2025 Emilis Bliūdžius
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
use core::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::Mutex;
use werewolves_proto::player::PlayerId;
#[derive(Debug, Clone)]
pub struct ConnectUpdate {
updated: Arc<AtomicBool>,
connected: Arc<Mutex<Vec<PlayerId>>>,
}
impl ConnectUpdate {
pub fn new() -> Self {
Self {
updated: Arc::new(AtomicBool::new(false)),
connected: Arc::new(Mutex::new(Vec::new())),
}
}
pub async fn connect(&self, pid: PlayerId) {
let mut connected = self.connected.lock().await;
if connected.iter().any(|c| c == &pid) {
return;
}
connected.push(pid);
self.updated.store(true, Ordering::SeqCst);
}
pub async fn disconnect(&self, pid: PlayerId) {
let mut connected = self.connected.lock().await;
if let Some(idx) = connected
.iter()
.enumerate()
.find_map(|(idx, c)| (c == &pid).then_some(idx))
{
connected.swap_remove(idx);
self.updated.store(true, Ordering::SeqCst);
}
}
}
impl Future for ConnectUpdate {
type Output = Arc<[PlayerId]>;
fn poll(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
if self.updated.load(Ordering::SeqCst)
&& let Ok(connected) = self.connected.try_lock()
{
self.updated.store(false, Ordering::SeqCst);
std::task::Poll::Ready(connected.clone().into())
} else {
cx.waker().wake_by_ref();
std::task::Poll::Pending
}
}
}

View File

@ -14,29 +14,21 @@
// along with this program. If not, see <https://www.gnu.org/licenses/>.
use werewolves_proto::error::GameError;
use crate::{
communication::{Comms, connect::ConnectUpdate},
runner::Message,
};
use crate::{communication::Comms, runner::Message};
use super::{HostComms, player::PlayerIdComms};
pub struct LobbyComms {
comms: Comms,
// TODO: move this to not use a receiver
connect_recv: ConnectUpdate,
}
impl LobbyComms {
pub fn new(comms: Comms, connect_recv: ConnectUpdate) -> Self {
Self {
comms,
connect_recv,
}
pub fn new(comms: Comms) -> Self {
Self { comms }
}
pub fn into_inner(self) -> (Comms, ConnectUpdate) {
(self.comms, self.connect_recv)
pub fn into_inner(self) -> Comms {
self.comms
}
#[allow(unused)]
@ -49,17 +41,12 @@ impl LobbyComms {
}
pub async fn next_message(&mut self) -> Result<Message, GameError> {
tokio::select! {
r = self.comms.message() => {
match r {
Ok(val) => Ok(val),
Err(GameError::GenericError(err)) => Err(GameError::GenericError(format!("comms message: {err}"))),
Err(err) => Err(err),
}
}
r = self.connect_recv.clone() => {
Ok(Message::ConnectedList(r))
}
match self.comms.message().await {
Ok(val) => Ok(val),
Err(GameError::GenericError(err)) => {
Err(GameError::GenericError(format!("comms message: {err}")))
}
Err(err) => Err(err),
}
}
}

View File

@ -19,7 +19,6 @@ use crate::{
runner::Message,
};
pub mod connect;
pub mod host;
pub mod lobby;
pub mod player;

View File

@ -27,7 +27,7 @@ use werewolves_proto::{
player::PlayerId,
};
use crate::{LogError, communication::connect::ConnectUpdate};
use crate::LogError;
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ConnectionId(PlayerId, Instant);
@ -80,13 +80,11 @@ impl JoinedPlayer {
#[derive(Debug, Clone)]
pub struct JoinedPlayers {
players: Arc<Mutex<HashMap<PlayerId, JoinedPlayer>>>,
connect_state: ConnectUpdate,
}
impl JoinedPlayers {
pub fn new(connect_state: ConnectUpdate) -> Self {
pub fn new() -> Self {
Self {
connect_state,
players: Arc::new(Mutex::new(HashMap::new())),
}
}
@ -172,8 +170,6 @@ impl JoinedPlayers {
pub async fn disconnect(&self, connection: &ConnectionId) -> Option<JoinedPlayer> {
let mut map = self.players.lock().await;
self.connect_state.disconnect(connection.0).await;
if map
.get(connection.player_id())
.map(|p| p.active_connection == *connection && !p.in_game)
@ -206,8 +202,6 @@ impl JoinedPlayers {
old.receiver
} else {
self.connect_state.connect(player_id).await;
unsafe { map.get(&player_id).unwrap_unchecked() }
.receiver
.resubscribe()

View File

@ -13,14 +13,13 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
use core::{ops::Not, time::Duration};
use std::sync::Arc;
use crate::{
LogError,
communication::{Comms, connect::ConnectUpdate, lobby::LobbyComms},
communication::{Comms, lobby::LobbyComms},
connection::JoinedPlayers,
lobby::{Lobby, LobbyPlayers},
runner::{IdentifiedClientMessage, Message},
runner::{ClientUpdate, IdentifiedClientMessage, Message},
};
use tokio::time::Instant;
use werewolves_proto::{
@ -39,7 +38,6 @@ type Result<T> = core::result::Result<T, GameError>;
pub struct GameRunner {
game: Game,
comms: Comms,
connect_recv: ConnectUpdate,
player_sender: LobbyPlayers,
roles_revealed: bool,
joined_players: JoinedPlayers,
@ -50,13 +48,11 @@ impl GameRunner {
game: Game,
comms: Comms,
player_sender: LobbyPlayers,
connect_recv: ConnectUpdate,
joined_players: JoinedPlayers,
) -> Self {
Self {
game,
comms,
connect_recv,
player_sender,
joined_players,
roles_revealed: false,
@ -68,10 +64,7 @@ impl GameRunner {
}
pub fn into_lobby(self) -> Lobby {
let mut lobby = Lobby::new(
self.joined_players,
LobbyComms::new(self.comms, self.connect_recv),
);
let mut lobby = Lobby::new(self.joined_players, LobbyComms::new(self.comms));
lobby.set_settings(self.game.village().settings());
lobby.set_players_in_lobby(self.player_sender);
lobby
@ -155,7 +148,6 @@ impl GameRunner {
};
let mut last_err_log = tokio::time::Instant::now() - tokio::time::Duration::from_secs(60);
let mut connect_list: Arc<[PlayerId]> = Arc::new([]);
while acks.iter().any(|(_, ackd)| !*ackd) {
const PING_TIME: Duration = Duration::from_secs(1);
let sleep_fut = tokio::time::sleep(PING_TIME);
@ -198,7 +190,7 @@ impl GameRunner {
player_id,
public: _,
},
message: ClientMessage::GetState,
update: ClientUpdate::Message(ClientMessage::GetState),
}) => {
let Some(sender) = self.joined_players.get_sender(player_id).await else {
continue;
@ -232,7 +224,7 @@ impl GameRunner {
player_id,
public: _,
},
message: ClientMessage::RoleAck,
update: ClientUpdate::Message(ClientMessage::RoleAck),
}) => {
if let Some((_, ackd)) =
acks.iter_mut().find(|(t, _)| t.player_id() == player_id)
@ -248,21 +240,9 @@ impl GameRunner {
}
}
Message::Client(IdentifiedClientMessage {
identity:
Identification {
player_id,
public: _,
},
message: _,
identity: Identification { player_id, .. },
..
}) => (notify_of_role)(player_id, self.game.village(), &self.joined_players).await,
Message::ConnectedList(c) => {
let newly_connected = c.iter().filter(|c| connect_list.contains(*c));
for connected in newly_connected {
(notify_of_role)(*connected, self.game.village(), &self.joined_players)
.await
}
connect_list = c;
}
}
}
@ -277,7 +257,10 @@ impl GameRunner {
pub async fn next(&mut self) -> Option<GameOver> {
let msg = match self.comms.message().await {
Ok(Message::ConnectedList(_)) => return None,
Ok(Message::Client(IdentifiedClientMessage {
update: ClientUpdate::ConnectStateUpdate,
..
})) => return None,
Ok(Message::Client(IdentifiedClientMessage {
identity: Identification { player_id, .. },
..
@ -459,6 +442,10 @@ impl GameEnd {
GameError::InvalidMessageForGameState,
))
.log_debug(),
Message::Client(IdentifiedClientMessage {
update: ClientUpdate::ConnectStateUpdate,
..
}) => {}
Message::Client(IdentifiedClientMessage { identity, .. }) => {
let story = self.game().ok()?.game.story();
return Some(ProcessOutcome::SendPlayer(
@ -466,7 +453,6 @@ impl GameEnd {
ServerMessage::Story(story),
));
}
Message::ConnectedList(_) => {}
}
None
}

View File

@ -33,7 +33,7 @@ use crate::{
communication::lobby::LobbyComms,
connection::JoinedPlayers,
game::GameRunner,
runner::{IdentifiedClientMessage, Message},
runner::{ClientUpdate, IdentifiedClientMessage, Message},
};
pub struct Lobby {
@ -147,12 +147,8 @@ impl Lobby {
.log_warn(),
Err((
Message::Client(IdentifiedClientMessage {
identity:
Identification {
player_id,
public: _,
},
message: _,
identity: Identification { player_id, .. },
..
}),
GameError::InvalidMessageForGameState,
)) => {
@ -160,10 +156,17 @@ impl Lobby {
.players_in_lobby
.send_if_present(player_id, ServerMessage::InvalidMessageForGameState);
}
Err((
Message::Client(IdentifiedClientMessage {
update: ClientUpdate::ConnectStateUpdate,
..
}),
_,
)) => {}
Err((
Message::Client(IdentifiedClientMessage {
identity: Identification { player_id, public },
message: _,
..
}),
err,
)) => {
@ -172,7 +175,6 @@ impl Lobby {
.players_in_lobby
.send_if_present(player_id, ServerMessage::Reset);
}
Err((Message::ConnectedList(_), _)) => {}
}
None
}
@ -249,18 +251,17 @@ impl Lobby {
let game = Game::new(&playing_players, self.settings.clone())?;
assert_eq!(game.village().characters().len(), playing_players.len());
let (comms, recv) = self.comms.take().unwrap().into_inner();
let comms = self.comms.take().unwrap().into_inner();
return Ok(Some(GameRunner::new(
game,
comms,
self.players_in_lobby.clone(),
recv,
self.joined_players.clone(),
)));
}
Message::Client(IdentifiedClientMessage {
identity,
message: ClientMessage::Hello,
update: ClientUpdate::Message(ClientMessage::Hello),
}) => {
if self
.players_in_lobby
@ -278,12 +279,8 @@ impl Lobby {
}
Message::Host(HostMessage::Lobby(HostLobbyMessage::Kick(player_id)))
| Message::Client(IdentifiedClientMessage {
identity:
Identification {
player_id,
public: _,
},
message: ClientMessage::Goodbye,
identity: Identification { player_id, .. },
update: ClientUpdate::Message(ClientMessage::Goodbye),
}) => {
if let Some(remove_idx) = self
.players_in_lobby
@ -297,12 +294,8 @@ impl Lobby {
}
}
Message::Client(IdentifiedClientMessage {
identity:
Identification {
player_id,
public: _,
},
message: ClientMessage::GetState,
identity: Identification { player_id, .. },
update: ClientUpdate::Message(ClientMessage::GetState),
}) => {
let msg = ServerMessage::LobbyInfo {
joined: self
@ -320,12 +313,12 @@ impl Lobby {
}
}
Message::Client(IdentifiedClientMessage {
identity: _,
message: ClientMessage::RoleAck,
update: ClientUpdate::Message(ClientMessage::RoleAck),
..
}) => return Err(GameError::InvalidMessageForGameState),
Message::Client(IdentifiedClientMessage {
identity: Identification { player_id, public },
message: ClientMessage::UpdateSelf(_),
update: ClientUpdate::Message(ClientMessage::UpdateSelf(_)),
}) => {
self.joined_players
.update(&player_id, |p| {
@ -344,7 +337,10 @@ impl Lobby {
self.send_lobby_info_to_clients().await;
self.send_lobby_info_to_host().await.log_debug();
}
Message::ConnectedList(_) => self.send_lobby_info_to_host().await?,
Message::Client(IdentifiedClientMessage {
update: ClientUpdate::ConnectStateUpdate,
..
}) => self.send_lobby_info_to_host().await?,
Message::Host(HostMessage::Echo(msg)) => {
self.comms()?.host().send(msg).log_warn();
}

View File

@ -37,7 +37,7 @@ use std::{env, io::Write, path::Path};
use tokio::sync::{broadcast, mpsc};
use crate::{
communication::{Comms, connect::ConnectUpdate, host::HostComms, player::PlayerIdComms},
communication::{Comms, host::HostComms, player::PlayerIdComms},
saver::FileSaver,
};
@ -105,18 +105,14 @@ async fn main() {
let listen_addr =
SocketAddr::from_str(format!("{host}:{port}").as_str()).expect("invalid host/port");
let (send, recv) = broadcast::channel(100);
let (server_send, host_recv) = broadcast::channel(100);
let (host_send, server_recv) = mpsc::channel(100);
let conn_update = ConnectUpdate::new();
let joined_players = JoinedPlayers::new(conn_update.clone());
let lobby_comms = LobbyComms::new(
Comms::new(
HostComms::new(server_send, server_recv),
PlayerIdComms::new(recv),
),
conn_update,
);
let (send, recv) = broadcast::channel(1000);
let (server_send, host_recv) = broadcast::channel(1000);
let (host_send, server_recv) = mpsc::channel(1000);
let joined_players = JoinedPlayers::new();
let lobby_comms = LobbyComms::new(Comms::new(
HostComms::new(server_send, server_recv),
PlayerIdComms::new(recv),
));
let jp_clone = joined_players.clone();

View File

@ -29,10 +29,16 @@ use crate::{
saver::Saver,
};
#[derive(Debug, Clone, PartialEq)]
pub enum ClientUpdate {
ConnectStateUpdate,
Message(ClientMessage),
}
#[derive(Debug, Clone, PartialEq)]
pub struct IdentifiedClientMessage {
pub identity: Identification,
pub message: ClientMessage,
pub update: ClientUpdate,
}
pub async fn run_game(joined_players: JoinedPlayers, comms: LobbyComms, mut saver: impl Saver) {
@ -100,7 +106,6 @@ pub async fn run_game(joined_players: JoinedPlayers, comms: LobbyComms, mut save
pub enum Message {
Host(HostMessage),
Client(IdentifiedClientMessage),
ConnectedList(Arc<[PlayerId]>),
}
pub enum RunningState {