From 36348285317f6e073581479821564ddf825777c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?cel=20=F0=9F=8C=B8?= Date: Tue, 11 Feb 2025 10:54:16 +0000 Subject: [PATCH] add iq hashmap for iq requests --- luz/src/connection/mod.rs | 16 ++++++++++++---- luz/src/connection/read.rs | 26 +++++++++++++++++++++++--- luz/src/lib.rs | 9 +++++++-- 3 files changed, 42 insertions(+), 9 deletions(-) diff --git a/luz/src/connection/mod.rs b/luz/src/connection/mod.rs index 85cf7cc..f8cf18b 100644 --- a/luz/src/connection/mod.rs +++ b/luz/src/connection/mod.rs @@ -1,6 +1,7 @@ // TODO: consider if this needs to be handled by a supervisor or could be handled by luz directly use std::{ + collections::HashMap, ops::{Deref, DerefMut}, sync::Arc, time::Duration, @@ -10,6 +11,7 @@ use jabber::{connection::Tls, jabber_stream::bound_stream::BoundJabberStream}; use jid::JID; use read::{ReadControl, ReadControlHandle}; use sqlx::SqlitePool; +use stanza::client::Stanza; use tokio::{ sync::{mpsc, oneshot, Mutex}, task::{JoinHandle, JoinSet}, @@ -30,6 +32,7 @@ pub struct Supervisor { tokio::task::JoinSet<()>, mpsc::Sender, WriteHandle, + Arc>>>>, )>, sender: mpsc::Sender, writer_handle: WriteControlHandle, @@ -56,6 +59,7 @@ impl Supervisor { JoinSet<()>, mpsc::Sender, WriteHandle, + Arc>>>>, )>, sender: mpsc::Sender, writer_handle: WriteControlHandle, @@ -108,7 +112,7 @@ impl Supervisor { // consider awaiting/aborting the read and write threads let (send, recv) = oneshot::channel(); let _ = self.reader_handle.send(ReadControl::Abort(send)).await; - let (db, update_sender, tasks, supervisor_command, write_sender) = tokio::select! { + let (db, update_sender, tasks, supervisor_command, write_sender, pending_iqs) = tokio::select! { Ok(s) = recv => s, Ok(s) = &mut self.reader_crash => s, // in case, just break as irrecoverable @@ -134,7 +138,8 @@ impl Supervisor { update_sender, supervisor_command, write_sender, - tasks + tasks, + pending_iqs, ); }, Err(e) => { @@ -149,7 +154,7 @@ impl Supervisor { }, } }, - Ok((db, update_sender, tasks, supervisor_control, write_handle)) = &mut self.reader_crash => { + Ok((db, update_sender, tasks, supervisor_control, write_handle, pending_iqs)) = &mut self.reader_crash => { let (send, recv) = oneshot::channel(); let _ = self.writer_handle.send(WriteControl::Abort(send)).await; let (retry_msg, mut write_receiver) = tokio::select! { @@ -182,7 +187,8 @@ impl Supervisor { update_sender, supervisor_control, write_handle, - tasks + tasks, + pending_iqs, ); }, Err(e) => { @@ -252,6 +258,7 @@ impl SupervisorHandle { on_shutdown: oneshot::Sender<()>, jid: Arc>, password: Arc, + pending_iqs: Arc>>>>, ) -> (WriteHandle, Self) { let (command_sender, command_receiver) = mpsc::channel(20); let (writer_error_sender, writer_error_receiver) = oneshot::channel(); @@ -267,6 +274,7 @@ impl SupervisorHandle { update_sender.clone(), command_sender.clone(), write_handle.clone(), + pending_iqs, ); let actor = Supervisor::new( diff --git a/luz/src/connection/read.rs b/luz/src/connection/read.rs index edc6cdb..c1e37b4 100644 --- a/luz/src/connection/read.rs +++ b/luz/src/connection/read.rs @@ -1,5 +1,7 @@ use std::{ + collections::HashMap, ops::{Deref, DerefMut}, + sync::Arc, time::Duration, }; @@ -7,7 +9,7 @@ use jabber::{connection::Tls, jabber_stream::bound_stream::BoundJabberReader}; use sqlx::SqlitePool; use stanza::client::Stanza; use tokio::{ - sync::{mpsc, oneshot}, + sync::{mpsc, oneshot, Mutex}, task::{JoinHandle, JoinSet}, }; @@ -28,6 +30,7 @@ pub struct Read { JoinSet<()>, mpsc::Sender, WriteHandle, + Arc>>>>, )>, db: SqlitePool, update_sender: mpsc::Sender, @@ -36,6 +39,8 @@ pub struct Read { tasks: JoinSet<()>, disconnecting: bool, disconnect_timedout: oneshot::Receiver<()>, + // TODO: use proper stanza ids + pending_iqs: Arc>>>>, } impl Read { @@ -48,6 +53,7 @@ impl Read { JoinSet<()>, mpsc::Sender, WriteHandle, + Arc>>>>, )>, db: SqlitePool, update_sender: mpsc::Sender, @@ -55,6 +61,7 @@ impl Read { supervisor_control: mpsc::Sender, write_handle: WriteHandle, tasks: JoinSet<()>, + pending_iqs: Arc>>>>, ) -> Self { let (send, recv) = oneshot::channel(); Self { @@ -68,6 +75,7 @@ impl Read { tasks, disconnecting: false, disconnect_timedout: recv, + pending_iqs, } } @@ -91,7 +99,7 @@ impl Read { }) }, ReadControl::Abort(sender) => { - let _ = sender.send((self.db, self.update_sender, self.tasks, self.supervisor_control, self.write_handle)); + let _ = sender.send((self.db, self.update_sender, self.tasks, self.supervisor_control, self.write_handle, self.pending_iqs)); break; }, }; @@ -126,7 +134,7 @@ impl Read { break; } else { // AAAAAAAAAAAAAAAAAAAAA i should really just have this stored in the supervisor and not gaf bout passing these references around - let _ = self.on_crash.send((self.db, self.update_sender, self.tasks, self.supervisor_control, self.write_handle)); + let _ = self.on_crash.send((self.db, self.update_sender, self.tasks, self.supervisor_control, self.write_handle, self.pending_iqs)); } break; }, @@ -134,6 +142,11 @@ impl Read { }, else => break } + // when it aborts, must clear iq map no matter what + let mut iqs = self.pending_iqs.lock().await; + for (_id, sender) in iqs.drain() { + let _ = sender.send(Err(Error::LostConnection)); + } } } } @@ -162,6 +175,7 @@ pub enum ReadControl { JoinSet<()>, mpsc::Sender, WriteHandle, + Arc>>>>, )>, ), } @@ -194,11 +208,13 @@ impl ReadControlHandle { JoinSet<()>, mpsc::Sender, WriteHandle, + Arc>>>>, )>, db: SqlitePool, sender: mpsc::Sender, supervisor_control: mpsc::Sender, jabber_write: WriteHandle, + pending_iqs: Arc>>>>, ) -> Self { let (control_sender, control_receiver) = mpsc::channel(20); @@ -211,6 +227,7 @@ impl ReadControlHandle { supervisor_control, jabber_write, JoinSet::new(), + pending_iqs, ); let handle = tokio::spawn(async move { actor.run().await }); @@ -228,12 +245,14 @@ impl ReadControlHandle { JoinSet<()>, mpsc::Sender, WriteHandle, + Arc>>>>, )>, db: SqlitePool, sender: mpsc::Sender, supervisor_control: mpsc::Sender, jabber_write: WriteHandle, tasks: JoinSet<()>, + pending_iqs: Arc>>>>, ) -> Self { let (control_sender, control_receiver) = mpsc::channel(20); @@ -246,6 +265,7 @@ impl ReadControlHandle { supervisor_control, jabber_write, tasks, + pending_iqs, ); let handle = tokio::spawn(async move { actor.run().await }); diff --git a/luz/src/lib.rs b/luz/src/lib.rs index 333d8eb..9d8ea66 100644 --- a/luz/src/lib.rs +++ b/luz/src/lib.rs @@ -1,9 +1,9 @@ -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; use connection::SupervisorSender; use jabber::JID; use sqlx::SqlitePool; -use stanza::roster; +use stanza::{client::Stanza, roster}; use tokio::{ sync::{mpsc, oneshot, Mutex}, task::JoinSet, @@ -22,6 +22,7 @@ pub struct Luz { // TODO: use a dyn passwordprovider trait to avoid storing password in memory password: Arc, connected: Arc>>, + pending_iqs: Arc>>>>, db: SqlitePool, sender: mpsc::Sender, /// if connection was shut down due to e.g. server shutdown, supervisor must be able to mark client as disconnected @@ -50,6 +51,7 @@ impl Luz { sender, tasks: JoinSet::new(), connection_supervisor_shutdown, + pending_iqs: Arc::new(Mutex::new(HashMap::new())), } } @@ -87,6 +89,7 @@ impl Luz { shutdown_send, self.jid.clone(), self.password.clone(), + self.pending_iqs.clone(), ); self.connection_supervisor_shutdown = shutdown_recv; *self.connected.lock().await = Some((writer, supervisor)); @@ -121,6 +124,7 @@ impl Luz { self.db.clone(), self.sender.clone(), // TODO: iq hashmap + self.pending_iqs.clone() )), None => self.tasks.spawn(msg.handle_offline( self.jid.clone(), @@ -155,6 +159,7 @@ impl CommandMessage { jid: Arc>, db: SqlitePool, sender: mpsc::Sender, + pending_iqs: Arc>>>>, ) { todo!() }