diff --git a/luz/src/connection/mod.rs b/luz/src/connection/mod.rs index 3ad2648..85cf7cc 100644 --- a/luz/src/connection/mod.rs +++ b/luz/src/connection/mod.rs @@ -1,15 +1,22 @@ -use std::ops::{Deref, DerefMut}; +// TODO: consider if this needs to be handled by a supervisor or could be handled by luz directly + +use std::{ + ops::{Deref, DerefMut}, + sync::Arc, + time::Duration, +}; use jabber::{connection::Tls, jabber_stream::bound_stream::BoundJabberStream}; -use read::ReadControlHandle; +use jid::JID; +use read::{ReadControl, ReadControlHandle}; use sqlx::SqlitePool; use tokio::{ - sync::{mpsc, oneshot}, + sync::{mpsc, oneshot, Mutex}, task::{JoinHandle, JoinSet}, }; -use write::{WriteControlHandle, WriteHandle, WriteMessage}; +use write::{WriteControl, WriteControlHandle, WriteHandle, WriteMessage}; -use crate::UpdateMessage; +use crate::{error::Error, UpdateMessage}; mod read; pub(crate) mod write; @@ -21,16 +28,21 @@ pub struct Supervisor { SqlitePool, mpsc::Sender, tokio::task::JoinSet<()>, + mpsc::Sender, + WriteHandle, )>, sender: mpsc::Sender, writer_handle: WriteControlHandle, reader_handle: ReadControlHandle, on_shutdown: oneshot::Sender<()>, + jid: Arc>, + password: Arc, } pub enum SupervisorCommand { Disconnect, // for if there was a stream error, require to reconnect + // couldn't stream errors just cause a crash? lol Reconnect, } @@ -38,11 +50,19 @@ impl Supervisor { fn new( connection_commands: mpsc::Receiver, writer_crash: oneshot::Receiver<(WriteMessage, mpsc::Receiver)>, - reader_crash: oneshot::Receiver<(SqlitePool, mpsc::Sender, JoinSet<()>)>, + reader_crash: oneshot::Receiver<( + SqlitePool, + mpsc::Sender, + JoinSet<()>, + mpsc::Sender, + WriteHandle, + )>, sender: mpsc::Sender, writer_handle: WriteControlHandle, reader_handle: ReadControlHandle, on_shutdown: oneshot::Sender<()>, + jid: Arc>, + password: Arc, ) -> Self { Self { connection_commands, @@ -52,27 +72,137 @@ impl Supervisor { reader_handle, reader_crash, on_shutdown, + jid, + password, } } - async fn handle_command_message(&mut self, msg: SupervisorCommand) {} - async fn run(mut self) { loop { tokio::select! { Some(msg) = self.connection_commands.recv() => { - self.handle_command_message(msg).await; + match msg { + SupervisorCommand::Disconnect => { + let _ = self.writer_handle.send(WriteControl::Disconnect).await; + let _ = self.reader_handle.send(ReadControl::Disconnect).await; + tokio::select! { + _ = async { tokio::join!( + async { let _ = (&mut self.writer_handle.handle).await; }, + async { let _ = (&mut self.reader_handle.handle).await; } + ) } => {}, + _ = async { tokio::time::sleep(Duration::from_secs(5)) } => { + (&mut self.reader_handle.handle).abort(); + (&mut self.writer_handle.handle).abort(); + } + } + break; + }, + SupervisorCommand::Reconnect => { + // TODO: please omfg + // send abort to read stream, as already done, consider + todo!() + }, + } }, - error = &mut self.writer_crash => { + Ok((write_msg, mut write_recv)) = &mut self.writer_crash => { + // consider awaiting/aborting the read and write threads + let (send, recv) = oneshot::channel(); + let _ = self.reader_handle.send(ReadControl::Abort(send)).await; + let (db, update_sender, tasks, supervisor_command, write_sender) = tokio::select! { + Ok(s) = recv => s, + Ok(s) = &mut self.reader_crash => s, + // in case, just break as irrecoverable + else => break, + }; + let mut jid = self.jid.lock().await; + let mut domain = jid.domainpart.clone(); + let connection = jabber::connect_and_login(&mut jid, &*self.password, &mut domain).await; + match connection { + Ok(c) => { + let (read, write) = c.split(); + let (send, recv) = oneshot::channel(); + self.writer_crash = recv; + self.writer_handle = + WriteControlHandle::reconnect_retry(write, send, write_msg, write_recv); + let (send, recv) = oneshot::channel(); + self.reader_crash = recv; + self.reader_handle = ReadControlHandle::reconnect( + read, + send, + db, + update_sender, + supervisor_command, + write_sender, + tasks + ); + }, + Err(e) => { + // if reconnection failure, respond to all current write messages with lost connection error. the received processes should complete themselves. + write_recv.close(); + let _ = write_msg.respond_to.send(Err(Error::LostConnection)); + while let Some(msg) = write_recv.recv().await { + let _ = msg.respond_to.send(Err(Error::LostConnection)); + } + let _ = self.sender.send(UpdateMessage::Error(e.into())).await; + break; + }, + } }, - error = &mut self.reader_crash => { + Ok((db, update_sender, tasks, supervisor_control, write_handle)) = &mut self.reader_crash => { + let (send, recv) = oneshot::channel(); + let _ = self.writer_handle.send(WriteControl::Abort(send)).await; + let (retry_msg, mut write_receiver) = tokio::select! { + Ok(s) = recv => (None, s), + Ok(s) = &mut self.writer_crash => (Some(s.0), s.1), + // in case, just break as irrecoverable + else => break, + }; + let mut jid = self.jid.lock().await; + let mut domain = jid.domainpart.clone(); + let connection = jabber::connect_and_login(&mut jid, &*self.password, &mut domain).await; + match connection { + Ok(c) => { + let (read, write) = c.split(); + let (send, recv) = oneshot::channel(); + self.writer_crash = recv; + if let Some(msg) = retry_msg { + self.writer_handle = + WriteControlHandle::reconnect_retry(write, send, msg, write_receiver); + } else { + self.writer_handle = WriteControlHandle::reconnect(write, send, write_receiver) + } + let (send, recv) = oneshot::channel(); + self.reader_crash = recv; + self.reader_handle = ReadControlHandle::reconnect( + read, + send, + db, + update_sender, + supervisor_control, + write_handle, + tasks + ); + }, + Err(e) => { + // if reconnection failure, respond to all current messages with lost connection error. + write_receiver.close(); + if let Some(msg) = retry_msg { + msg.respond_to.send(Err(Error::LostConnection)); + } + while let Some(msg) = write_receiver.recv().await { + msg.respond_to.send(Err(Error::LostConnection)); + } + let _ = self.sender.send(UpdateMessage::Error(e.into())).await; + break; + }, + } }, else => break, } } - self.on_shutdown.send(()); + let _ = self.on_shutdown.send(()); } } @@ -120,6 +250,8 @@ impl SupervisorHandle { update_sender: mpsc::Sender, db: SqlitePool, on_shutdown: oneshot::Sender<()>, + jid: Arc>, + password: Arc, ) -> (WriteHandle, Self) { let (command_sender, command_receiver) = mpsc::channel(20); let (writer_error_sender, writer_error_receiver) = oneshot::channel(); @@ -145,6 +277,8 @@ impl SupervisorHandle { write_control_handle, jabber_reader_control_handle, on_shutdown, + jid, + password, ); let handle = tokio::spawn(async move { actor.run().await }); diff --git a/luz/src/connection/read.rs b/luz/src/connection/read.rs index 7800d56..edc6cdb 100644 --- a/luz/src/connection/read.rs +++ b/luz/src/connection/read.rs @@ -1,3 +1,8 @@ +use std::{ + ops::{Deref, DerefMut}, + time::Duration, +}; + use jabber::{connection::Tls, jabber_stream::bound_stream::BoundJabberReader}; use sqlx::SqlitePool; use stanza::client::Stanza; @@ -6,7 +11,7 @@ use tokio::{ task::{JoinHandle, JoinSet}, }; -use crate::UpdateMessage; +use crate::{error::Error, UpdateMessage}; use super::{ write::{WriteHandle, WriteMessage}, @@ -17,25 +22,41 @@ pub struct Read { // TODO: place iq hashmap here control_receiver: mpsc::Receiver, stream: BoundJabberReader, - on_crash: oneshot::Sender<(SqlitePool, mpsc::Sender, JoinSet<()>)>, + on_crash: oneshot::Sender<( + SqlitePool, + mpsc::Sender, + JoinSet<()>, + mpsc::Sender, + WriteHandle, + )>, db: SqlitePool, update_sender: mpsc::Sender, supervisor_control: mpsc::Sender, write_handle: WriteHandle, tasks: JoinSet<()>, + disconnecting: bool, + disconnect_timedout: oneshot::Receiver<()>, } impl Read { fn new( control_receiver: mpsc::Receiver, stream: BoundJabberReader, - on_crash: oneshot::Sender<(SqlitePool, mpsc::Sender, JoinSet<()>)>, + on_crash: oneshot::Sender<( + SqlitePool, + mpsc::Sender, + JoinSet<()>, + mpsc::Sender, + WriteHandle, + )>, db: SqlitePool, update_sender: mpsc::Sender, // jabber server must be able to both terminate the connection from error, and ask for data from the client (such as supported XEPs) supervisor_control: mpsc::Sender, - write_sender: WriteHandle, + write_handle: WriteHandle, + tasks: JoinSet<()>, ) -> Self { + let (send, recv) = oneshot::channel(); Self { control_receiver, stream, @@ -43,26 +64,73 @@ impl Read { db, update_sender, supervisor_control, - write_handle: write_sender, - tasks: JoinSet::new(), + write_handle, + tasks, + disconnecting: false, + disconnect_timedout: recv, } } async fn run(mut self) { loop { tokio::select! { + // if still haven't received the end tag in time, just kill itself + _ = &mut self.disconnect_timedout => { + break; + } Some(msg) = self.control_receiver.recv() => { match msg { - ReadControl::Disconnect => todo!(), - ReadControl::Abort(sender) => todo!(), + // when disconnect received, + ReadControl::Disconnect => { + let (send, recv) = oneshot::channel(); + self.disconnect_timedout = recv; + self.disconnecting = true; + tokio::spawn(async { + tokio::time::sleep(Duration::from_secs(10)).await; + let _ = send.send(()); + }) + }, + ReadControl::Abort(sender) => { + let _ = sender.send((self.db, self.update_sender, self.tasks, self.supervisor_control, self.write_handle)); + break; + }, }; }, stanza = self.stream.read::() => { match stanza { - Ok(_) => todo!(), - Err(_) => todo!(), + Ok(s) => { + self.tasks.spawn(handle_stanza(s, self.update_sender.clone(), self.db.clone(), self.supervisor_control.clone(), self.write_handle.clone())); + }, + Err(e) => { + // TODO: NEXT write the correct error stanza depending on error, decide whether to reconnect or properly disconnect, depending on if disconnecting is true + // match e { + // peanuts::Error::ReadError(error) => todo!(), + // peanuts::Error::Utf8Error(utf8_error) => todo!(), + // peanuts::Error::ParseError(_) => todo!(), + // peanuts::Error::EntityProcessError(_) => todo!(), + // peanuts::Error::InvalidCharRef(_) => todo!(), + // peanuts::Error::DuplicateNameSpaceDeclaration(namespace_declaration) => todo!(), + // peanuts::Error::DuplicateAttribute(_) => todo!(), + // peanuts::Error::UnqualifiedNamespace(_) => todo!(), + // peanuts::Error::MismatchedEndTag(name, name1) => todo!(), + // peanuts::Error::NotInElement(_) => todo!(), + // peanuts::Error::ExtraData(_) => todo!(), + // peanuts::Error::UndeclaredNamespace(_) => todo!(), + // peanuts::Error::IncorrectName(name) => todo!(), + // peanuts::Error::DeserializeError(_) => todo!(), + // peanuts::Error::Deserialize(deserialize_error) => todo!(), + // peanuts::Error::RootElementEnded => todo!(), + // } + // TODO: make sure this only happens when an end tag is received + if self.disconnecting == true { + break; + } else { + // AAAAAAAAAAAAAAAAAAAAA i should really just have this stored in the supervisor and not gaf bout passing these references around + let _ = self.on_crash.send((self.db, self.update_sender, self.tasks, self.supervisor_control, self.write_handle)); + } + break; + }, } - self.tasks.spawn(); }, else => break } @@ -70,30 +138,63 @@ impl Read { } } -trait Task { - async fn handle(); +// what do stanza processes do? +// - update ui +// - access database +// - disconnect proper, reconnect +// - respond to server requests +async fn handle_stanza( + stanza: Stanza, + update_sender: mpsc::Sender, + db: SqlitePool, + supervisor_control: mpsc::Sender, + write_handle: WriteHandle, +) { + todo!() } -impl Task for Stanza { - async fn handle() { - todo!() - } -} - -enum ReadControl { +pub enum ReadControl { Disconnect, - Abort(oneshot::Sender>), + Abort( + oneshot::Sender<( + SqlitePool, + mpsc::Sender, + JoinSet<()>, + mpsc::Sender, + WriteHandle, + )>, + ), } pub struct ReadControlHandle { sender: mpsc::Sender, - handle: JoinHandle<()>, + pub(crate) handle: JoinHandle<()>, +} + +impl Deref for ReadControlHandle { + type Target = mpsc::Sender; + + fn deref(&self) -> &Self::Target { + &self.sender + } +} + +impl DerefMut for ReadControlHandle { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.sender + } } impl ReadControlHandle { pub fn new( stream: BoundJabberReader, - on_crash: oneshot::Sender<(SqlitePool, mpsc::Sender, JoinSet<()>)>, + on_crash: oneshot::Sender<( + SqlitePool, + mpsc::Sender, + JoinSet<()>, + mpsc::Sender, + WriteHandle, + )>, db: SqlitePool, sender: mpsc::Sender, supervisor_control: mpsc::Sender, @@ -109,6 +210,42 @@ impl ReadControlHandle { sender, supervisor_control, jabber_write, + JoinSet::new(), + ); + let handle = tokio::spawn(async move { actor.run().await }); + + Self { + sender: control_sender, + handle, + } + } + + pub fn reconnect( + stream: BoundJabberReader, + on_crash: oneshot::Sender<( + SqlitePool, + mpsc::Sender, + JoinSet<()>, + mpsc::Sender, + WriteHandle, + )>, + db: SqlitePool, + sender: mpsc::Sender, + supervisor_control: mpsc::Sender, + jabber_write: WriteHandle, + tasks: JoinSet<()>, + ) -> Self { + let (control_sender, control_receiver) = mpsc::channel(20); + + let actor = Read::new( + control_receiver, + stream, + on_crash, + db, + sender, + supervisor_control, + jabber_write, + tasks, ); let handle = tokio::spawn(async move { actor.run().await }); diff --git a/luz/src/connection/write.rs b/luz/src/connection/write.rs index 9c01519..09638a8 100644 --- a/luz/src/connection/write.rs +++ b/luz/src/connection/write.rs @@ -19,10 +19,10 @@ pub struct Write { pub struct WriteMessage { stanza: Stanza, - respond_to: oneshot::Sender>, + pub respond_to: oneshot::Sender>, } -enum WriteControl { +pub enum WriteControl { Disconnect, Abort(oneshot::Sender>), } @@ -46,38 +46,66 @@ impl Write { Ok(self.stream.write(stanza).await?) } + async fn run_reconnected(mut self, retry_msg: WriteMessage) { + // try to retry sending the message that failed to send previously + let result = self.stream.write(&retry_msg.stanza).await; + match result { + Err(e) => match &e { + peanuts::Error::ReadError(_error) => { + // make sure message is not lost from error, supervisor handles retry and reporting + // TODO: upon reconnect, make sure we are not stuck in a reconnection loop + let _ = self.on_crash.send((retry_msg, self.stanza_receiver)); + return; + } + _ => { + let _ = retry_msg.respond_to.send(Err(e.into())); + } + }, + _ => { + let _ = retry_msg.respond_to.send(Ok(())); + } + } + // return to normal loop + self.run().await + } + async fn run(mut self) { loop { tokio::select! { Some(msg) = self.control_receiver.recv() => { match msg { WriteControl::Disconnect => { - // TODO: close the stanza_receiver channel and drain out all of the remaining stanzas to send + // close the stanza_receiver channel and drain out all of the remaining stanzas to send self.stanza_receiver.close(); // TODO: put this in some kind of function to avoid code duplication - for msg in self.stanza_receiver.recv().await { + while let Some(msg) = self.stanza_receiver.recv().await { let result = self.stream.write(&msg.stanza).await; match result { Err(e) => match &e { - peanuts::Error::ReadError(error) => { - // make sure message is not lost from error, supervisor handles retry and reporting - self.on_crash.send((msg, self.stanza_receiver)); + peanuts::Error::ReadError(_error) => { + // if connection lost during disconnection, just send lost connection error to the write requests + let _ = msg.respond_to.send(Err(Error::LostConnection)); + while let Some(msg) = self.stanza_receiver.recv().await { + let _ = msg.respond_to.send(Err(Error::LostConnection)); + } break; } + // otherwise complete sending all the stanzas currently in the queue _ => { - msg.respond_to.send(Err(e.into())); + let _ = msg.respond_to.send(Err(e.into())); } }, _ => { - msg.respond_to.send(Ok(())); + let _ = msg.respond_to.send(Ok(())); } } } - self.stream.try_close().await; + let _ = self.stream.try_close().await; break; }, + // in case of abort, stream is already fucked, just send the receiver ready for a reconnection at the same resource WriteControl::Abort(sender) => { - sender.send(self.stanza_receiver); + let _ = sender.send(self.stanza_receiver); break; }, } @@ -86,21 +114,20 @@ impl Write { let result = self.stream.write(&msg.stanza).await; match result { Err(e) => match &e { - peanuts::Error::ReadError(error) => { + peanuts::Error::ReadError(_error) => { // make sure message is not lost from error, supervisor handles retry and reporting - self.on_crash.send((msg, self.stanza_receiver)); + let _ = self.on_crash.send((msg, self.stanza_receiver)); break; } _ => { - msg.respond_to.send(Err(e.into())); + let _ = msg.respond_to.send(Err(e.into())); } }, _ => { - msg.respond_to.send(Ok(())); + let _ = msg.respond_to.send(Ok(())); } } }, - // TODO: check if this is ok to do else => break, } } @@ -128,7 +155,21 @@ impl DerefMut for WriteHandle { pub struct WriteControlHandle { sender: mpsc::Sender, - handle: JoinHandle<()>, + pub(crate) handle: JoinHandle<()>, +} + +impl Deref for WriteControlHandle { + type Target = mpsc::Sender; + + fn deref(&self) -> &Self::Target { + &self.sender + } +} + +impl DerefMut for WriteControlHandle { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.sender + } } impl WriteControlHandle { @@ -153,6 +194,23 @@ impl WriteControlHandle { ) } + pub fn reconnect_retry( + stream: BoundJabberWriter, + supervisor: oneshot::Sender<(WriteMessage, mpsc::Receiver)>, + retry_msg: WriteMessage, + stanza_receiver: mpsc::Receiver, + ) -> Self { + let (control_sender, control_receiver) = mpsc::channel(20); + + let actor = Write::new(stanza_receiver, control_receiver, stream, supervisor); + let handle = tokio::spawn(async move { actor.run_reconnected(retry_msg).await }); + + Self { + sender: control_sender, + handle, + } + } + pub fn reconnect( stream: BoundJabberWriter, supervisor: oneshot::Sender<(WriteMessage, mpsc::Receiver)>, diff --git a/luz/src/error.rs b/luz/src/error.rs index d9dfaba..2809e8d 100644 --- a/luz/src/error.rs +++ b/luz/src/error.rs @@ -5,6 +5,7 @@ pub enum Error { SQL(sqlx::Error), JID(jid::ParseError), AlreadyDisconnected, + LostConnection, } impl From for Error { diff --git a/luz/src/lib.rs b/luz/src/lib.rs index 0dfc30c..333d8eb 100644 --- a/luz/src/lib.rs +++ b/luz/src/lib.rs @@ -20,7 +20,7 @@ pub struct Luz { receiver: mpsc::Receiver, jid: Arc>, // TODO: use a dyn passwordprovider trait to avoid storing password in memory - password: String, + password: Arc, connected: Arc>>, db: SqlitePool, sender: mpsc::Sender, @@ -43,7 +43,7 @@ impl Luz { ) -> Self { Self { jid, - password, + password: Arc::new(password), connected, db, receiver, @@ -75,7 +75,7 @@ impl Luz { let mut domain = jid.domainpart.clone(); // TODO: check what happens upon reconnection with same resource (this is probably what one wants to do and why jid should be mutated from a bare jid to one with a resource) let streams_result = - jabber::connect_and_login(&mut jid, &self.password, &mut domain) + jabber::connect_and_login(&mut jid, &*self.password, &mut domain) .await; match streams_result { Ok(s) => { @@ -85,6 +85,8 @@ impl Luz { self.sender.clone(), self.db.clone(), shutdown_send, + self.jid.clone(), + self.password.clone(), ); self.connection_supervisor_shutdown = shutdown_recv; *self.connected.lock().await = Some((writer, supervisor));