Compare commits

...

14 Commits

Author SHA1 Message Date
cel 🌸 861db1197d feat(stanza): impl Clone for Stanza types 2025-02-25 23:30:27 +00:00
cel 🌸 4dac2dbe1d refactor(luz): error types 2025-02-25 23:29:44 +00:00
cel 🌸 d797061786 add `sqlx` feature to `jid` dependency 2025-02-25 20:52:19 +00:00
cel 🌸 76b00cd644 implement Clone for error types 2025-02-25 20:50:23 +00:00
cel 🌸 d30131e0fc implement Error for jabber crate error types 2025-02-25 20:31:10 +00:00
cel 🌸 53ea2951ae implement Error for stanza crate error types 2025-02-25 20:30:44 +00:00
cel 🌸 b859cd7f78 cleanup jabber crate 2025-02-25 19:50:46 +00:00
cel 🌸 ea87cc407c cleanup 2025-02-25 19:45:35 +00:00
cel 🌸 90a5af5c75 implement Error for stanza crate sasl error types 2025-02-25 19:45:20 +00:00
cel 🌸 4fe4ab9d83 implement Error for stanza crate error types 2025-02-25 19:11:25 +00:00
cel 🌸 3c412ea6b0 implement Error for jid crate error types 2025-02-25 18:52:14 +00:00
cel 🌸 20fc4b1966 feature gate sqlx for jid crate 2025-02-25 18:45:46 +00:00
cel 🌸 65e908e36c implement Into<Cow> for &JID 2025-02-25 18:43:12 +00:00
cel 🌸 eda4bd92ff fix crash by fusing oneshot 2025-02-24 10:03:12 +00:00
30 changed files with 682 additions and 412 deletions

View File

@ -30,6 +30,7 @@ futures = "0.3.31"
take_mut = "0.2.2" take_mut = "0.2.2"
pin-project-lite = "0.2.15" pin-project-lite = "0.2.15"
pin-project = "1.1.7" pin-project = "1.1.7"
thiserror = "2.0.11"
[dev-dependencies] [dev-dependencies]
test-log = { version = "0.2", features = ["trace"] } test-log = { version = "0.2", features = ["trace"] }

View File

@ -1,29 +1,17 @@
use std::{
borrow::Borrow,
future::Future,
pin::pin,
sync::Arc,
task::{ready, Poll},
};
use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt};
use jid::ParseError;
use rsasl::config::SASLConfig; use rsasl::config::SASLConfig;
use stanza::{ use stanza::{
client::Stanza,
sasl::Mechanisms, sasl::Mechanisms,
stream::{Feature, Features}, stream::{Feature, Features},
}; };
use tokio::sync::Mutex;
use crate::{ use crate::{
connection::{Tls, Unencrypted}, connection::{Tls, Unencrypted},
jabber_stream::bound_stream::{BoundJabberReader, BoundJabberStream}, jabber_stream::bound_stream::BoundJabberStream,
Connection, Error, JabberStream, Result, JID, Connection, Error, JabberStream, Result, JID,
}; };
pub async fn connect_and_login( pub async fn connect_and_login(
mut jid: &mut JID, jid: &mut JID,
password: impl AsRef<str>, password: impl AsRef<str>,
server: &mut String, server: &mut String,
) -> Result<BoundJabberStream<Tls>> { ) -> Result<BoundJabberStream<Tls>> {
@ -31,7 +19,8 @@ pub async fn connect_and_login(
None, None,
jid.localpart.clone().ok_or(Error::NoLocalpart)?, jid.localpart.clone().ok_or(Error::NoLocalpart)?,
password.as_ref().to_string(), password.as_ref().to_string(),
)?; )
.map_err(|e| Error::SASL(e.into()))?;
let mut conn_state = Connecting::start(&server).await?; let mut conn_state = Connecting::start(&server).await?;
loop { loop {
match conn_state { match conn_state {
@ -120,9 +109,8 @@ pub enum InsecureConnecting {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::{sync::Arc, time::Duration}; use std::time::Duration;
use futures::{SinkExt, StreamExt};
use jid::JID; use jid::JID;
use stanza::{ use stanza::{
client::{ client::{
@ -132,7 +120,7 @@ mod tests {
xep_0199::Ping, xep_0199::Ping,
}; };
use test_log::test; use test_log::test;
use tokio::{sync::Mutex, time::sleep}; use tokio::time::sleep;
use tracing::info; use tracing::info;
use super::connect_and_login; use super::connect_and_login;
@ -140,7 +128,7 @@ mod tests {
#[test(tokio::test)] #[test(tokio::test)]
async fn login() { async fn login() {
let mut jid: JID = "test@blos.sm".try_into().unwrap(); let mut jid: JID = "test@blos.sm".try_into().unwrap();
let client = connect_and_login(&mut jid, "slayed", &mut "blos.sm".to_string()) let _client = connect_and_login(&mut jid, "slayed", &mut "blos.sm".to_string())
.await .await
.unwrap(); .unwrap();
sleep(Duration::from_secs(5)).await sleep(Duration::from_secs(5)).await

View File

@ -1,9 +1,7 @@
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::str; use std::str;
use std::str::FromStr; use std::str::FromStr;
use std::sync::Arc;
use rsasl::config::SASLConfig;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_native_tls::native_tls::TlsConnector; use tokio_native_tls::native_tls::TlsConnector;
// TODO: use rustls // TODO: use rustls

View File

@ -1,87 +1,58 @@
use std::str::Utf8Error; use std::str::Utf8Error;
use std::sync::Arc;
use jid::ParseError; use jid::ParseError;
use rsasl::mechname::MechanismNameError; use rsasl::mechname::MechanismNameError;
use stanza::client::error::Error as ClientError; use stanza::client::error::Error as ClientError;
use stanza::sasl::Failure; use stanza::sasl::Failure;
use stanza::stream::Error as StreamError; use stanza::stream::Error as StreamError;
use tokio::task::JoinError; use thiserror::Error;
#[derive(Debug)] #[derive(Error, Debug, Clone)]
pub enum Error { pub enum Error {
#[error("connection")]
Connection, Connection,
Utf8Decode, #[error("utf8 decode: {0}")]
Utf8Decode(#[from] Utf8Error),
#[error("negotiation")]
Negotiation, Negotiation,
#[error("tls required")]
TlsRequired, TlsRequired,
#[error("already connected with tls")]
AlreadyTls, AlreadyTls,
// TODO: specify unsupported feature
#[error("unsupported feature")]
Unsupported, Unsupported,
#[error("jid missing localpart")]
NoLocalpart, NoLocalpart,
AlreadyConnecting, #[error("received unexpected element: {0:?}")]
StreamClosed,
UnexpectedElement(peanuts::Element), UnexpectedElement(peanuts::Element),
XML(peanuts::Error), #[error("xml error: {0}")]
Deserialization(peanuts::DeserializeError), XML(#[from] peanuts::Error),
SASL(SASLError), #[error("sasl error: {0}")]
JID(ParseError), SASL(#[from] SASLError),
Authentication(Failure), #[error("jid error: {0}")]
ClientError(ClientError), JID(#[from] ParseError),
StreamError(StreamError), #[error("client stanza error: {0}")]
ClientError(#[from] ClientError),
#[error("stream error: {0}")]
StreamError(#[from] StreamError),
#[error("error missing")]
MissingError, MissingError,
Disconnected,
Connecting,
JoinError(JoinError),
} }
#[derive(Debug)] #[derive(Error, Debug, Clone)]
pub enum SASLError { pub enum SASLError {
SASL(rsasl::prelude::SASLError), #[error("sasl error: {0}")]
MechanismName(MechanismNameError), SASL(Arc<rsasl::prelude::SASLError>),
#[error("mechanism error: {0}")]
MechanismName(#[from] MechanismNameError),
#[error("authentication failure: {0}")]
Authentication(#[from] Failure),
} }
impl From<rsasl::prelude::SASLError> for Error { impl From<rsasl::prelude::SASLError> for SASLError {
fn from(e: rsasl::prelude::SASLError) -> Self { fn from(e: rsasl::prelude::SASLError) -> Self {
Self::SASL(SASLError::SASL(e)) Self::SASL(Arc::new(e))
}
}
impl From<JoinError> for Error {
fn from(e: JoinError) -> Self {
Self::JoinError(e)
}
}
impl From<peanuts::DeserializeError> for Error {
fn from(e: peanuts::DeserializeError) -> Self {
Error::Deserialization(e)
}
}
impl From<MechanismNameError> for Error {
fn from(e: MechanismNameError) -> Self {
Self::SASL(SASLError::MechanismName(e))
}
}
impl From<SASLError> for Error {
fn from(e: SASLError) -> Self {
Self::SASL(e)
}
}
impl From<Utf8Error> for Error {
fn from(_e: Utf8Error) -> Self {
Self::Utf8Decode
}
}
impl From<peanuts::Error> for Error {
fn from(e: peanuts::Error) -> Self {
Self::XML(e)
}
}
impl From<ParseError> for Error {
fn from(e: ParseError) -> Self {
Self::JID(e)
} }
} }

View File

@ -1,10 +1,8 @@
use std::pin::pin;
use std::str::{self, FromStr}; use std::str::{self, FromStr};
use std::sync::Arc; use std::sync::Arc;
use futures::{sink, stream, StreamExt};
use jid::JID; use jid::JID;
use peanuts::element::{FromContent, IntoElement}; use peanuts::element::IntoElement;
use peanuts::{Reader, Writer}; use peanuts::{Reader, Writer};
use rsasl::prelude::{Mechname, SASLClient, SASLConfig}; use rsasl::prelude::{Mechname, SASLClient, SASLConfig};
use stanza::bind::{Bind, BindType, FullJidType, ResourceType}; use stanza::bind::{Bind, BindType, FullJidType, ResourceType};
@ -135,13 +133,16 @@ where
let sasl = SASLClient::new(sasl_config); let sasl = SASLClient::new(sasl_config);
let mut offered_mechs: Vec<&Mechname> = Vec::new(); let mut offered_mechs: Vec<&Mechname> = Vec::new();
for mechanism in &mechanisms.mechanisms { for mechanism in &mechanisms.mechanisms {
offered_mechs.push(Mechname::parse(mechanism.as_bytes())?) offered_mechs
.push(Mechname::parse(mechanism.as_bytes()).map_err(|e| Error::SASL(e.into()))?)
} }
debug!("{:?}", offered_mechs); debug!("{:?}", offered_mechs);
let mut session = sasl.start_suggested(&offered_mechs)?; let mut session = sasl
.start_suggested(&offered_mechs)
.map_err(|e| Error::SASL(e.into()))?;
let selected_mechanism = session.get_mechname().as_str().to_owned(); let selected_mechanism = session.get_mechname().as_str().to_owned();
debug!("selected mech: {:?}", selected_mechanism); debug!("selected mech: {:?}", selected_mechanism);
let mut data: Option<Vec<u8>> = None; let mut data: Option<Vec<u8>>;
if !session.are_we_first() { if !session.are_we_first() {
// if not first mention the mechanism then get challenge data // if not first mention the mechanism then get challenge data
@ -176,7 +177,7 @@ where
ServerResponse::Success(success) => { ServerResponse::Success(success) => {
data = success.clone().map(|success| success.as_bytes().to_vec()) data = success.clone().map(|success| success.as_bytes().to_vec())
} }
ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)), ServerResponse::Failure(failure) => return Err(Error::SASL(failure.into())),
} }
debug!("we went first"); debug!("we went first");
} }
@ -207,7 +208,7 @@ where
ServerResponse::Success(success) => { ServerResponse::Success(success) => {
data = success.clone().map(|success| success.as_bytes().to_vec()) data = success.clone().map(|success| success.as_bytes().to_vec())
} }
ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)), ServerResponse::Failure(failure) => return Err(Error::SASL(failure.into())),
} }
} }
} }
@ -409,13 +410,7 @@ impl std::fmt::Debug for JabberStream<Unencrypted> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::time::Duration;
use super::*;
use crate::connection::Connection;
use futures::sink;
use test_log::test; use test_log::test;
use tokio::time::sleep;
#[test(tokio::test)] #[test(tokio::test)]
async fn start_stream() { async fn start_stream() {

View File

@ -1,9 +1,6 @@
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
use peanuts::{Reader, Writer}; use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
use crate::Error;
use super::{JabberReader, JabberStream, JabberWriter}; use super::{JabberReader, JabberStream, JabberWriter};

View File

@ -8,7 +8,6 @@ pub mod error;
pub mod jabber_stream; pub mod jabber_stream;
pub use connection::Connection; pub use connection::Connection;
use connection::Tls;
pub use error::Error; pub use error::Error;
pub use jabber_stream::JabberStream; pub use jabber_stream::JabberStream;
pub use jid::JID; pub use jid::JID;

View File

@ -3,5 +3,8 @@ name = "jid"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
[features]
sqlx = ["dep:sqlx"]
[dependencies] [dependencies]
sqlx = { version = "0.8.3", features = ["sqlite"] } sqlx = { version = "0.8.3", features = ["sqlite"], optional = true }

View File

@ -1,8 +1,9 @@
use std::{error::Error, fmt::Display, str::FromStr}; use std::{borrow::Cow, error::Error, fmt::Display, str::FromStr};
#[cfg(feature = "sqlx")]
use sqlx::Sqlite; use sqlx::Sqlite;
#[derive(PartialEq, Debug, Clone, sqlx::Type, sqlx::Encode, Eq, Hash)] #[derive(PartialEq, Debug, Clone, Eq, Hash)]
pub struct JID { pub struct JID {
// TODO: validate localpart (length, char] // TODO: validate localpart (length, char]
pub localpart: Option<String>, pub localpart: Option<String>,
@ -10,13 +11,36 @@ pub struct JID {
pub resourcepart: Option<String>, pub resourcepart: Option<String>,
} }
// TODO: feature gate impl<'a> Into<Cow<'a, str>> for &'a JID {
fn into(self) -> Cow<'a, str> {
let a = self.to_string();
Cow::Owned(a)
}
}
impl Display for JID {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if let Some(localpart) = &self.localpart {
f.write_str(localpart)?;
f.write_str("@")?;
}
f.write_str(&self.domainpart)?;
if let Some(resourcepart) = &self.resourcepart {
f.write_str("/")?;
f.write_str(resourcepart)?;
}
Ok(())
}
}
#[cfg(feature = "sqlx")]
impl sqlx::Type<Sqlite> for JID { impl sqlx::Type<Sqlite> for JID {
fn type_info() -> <Sqlite as sqlx::Database>::TypeInfo { fn type_info() -> <Sqlite as sqlx::Database>::TypeInfo {
<&str as sqlx::Type<Sqlite>>::type_info() <&str as sqlx::Type<Sqlite>>::type_info()
} }
} }
#[cfg(feature = "sqlx")]
impl sqlx::Decode<'_, Sqlite> for JID { impl sqlx::Decode<'_, Sqlite> for JID {
fn decode( fn decode(
value: <Sqlite as sqlx::Database>::ValueRef<'_>, value: <Sqlite as sqlx::Database>::ValueRef<'_>,
@ -27,6 +51,7 @@ impl sqlx::Decode<'_, Sqlite> for JID {
} }
} }
#[cfg(feature = "sqlx")]
impl sqlx::Encode<'_, Sqlite> for JID { impl sqlx::Encode<'_, Sqlite> for JID {
fn encode_by_ref( fn encode_by_ref(
&self, &self,
@ -37,12 +62,24 @@ impl sqlx::Encode<'_, Sqlite> for JID {
} }
} }
#[derive(Debug, Clone)]
pub enum JIDError { pub enum JIDError {
NoResourcePart, NoResourcePart,
ParseError(ParseError), ParseError(ParseError),
} }
#[derive(Debug)] impl Display for JIDError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
JIDError::NoResourcePart => f.write_str("resourcepart missing"),
JIDError::ParseError(parse_error) => parse_error.fmt(f),
}
}
}
impl Error for JIDError {}
#[derive(Debug, Clone)]
pub enum ParseError { pub enum ParseError {
Empty, Empty,
Malformed(String), Malformed(String),
@ -147,21 +184,6 @@ impl TryFrom<&str> for JID {
} }
} }
impl std::fmt::Display for JID {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}{}{}",
self.localpart.clone().map(|l| l + "@").unwrap_or_default(),
self.domainpart,
self.resourcepart
.clone()
.map(|r| "/".to_owned() + &r)
.unwrap_or_default()
)
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

1
luz/.gitignore vendored
View File

@ -1 +1,2 @@
luz.db luz.db
.sqlx/

View File

@ -7,7 +7,7 @@ edition = "2021"
futures = "0.3.31" futures = "0.3.31"
jabber = { version = "0.1.0", path = "../jabber" } jabber = { version = "0.1.0", path = "../jabber" }
peanuts = { version = "0.1.0", path = "../../peanuts" } peanuts = { version = "0.1.0", path = "../../peanuts" }
jid = { version = "0.1.0", path = "../jid" } jid = { version = "0.1.0", path = "../jid", features = ["sqlx"] }
sqlx = { version = "0.8.3", features = ["sqlite", "runtime-tokio", "uuid"] } sqlx = { version = "0.8.3", features = ["sqlite", "runtime-tokio", "uuid"] }
stanza = { version = "0.1.0", path = "../stanza" } stanza = { version = "0.1.0", path = "../stanza" }
tokio = "1.42.0" tokio = "1.42.0"
@ -16,3 +16,4 @@ tokio-util = "0.7.13"
tracing = "0.1.41" tracing = "0.1.41"
tracing-subscriber = "0.3.19" tracing-subscriber = "0.3.19"
uuid = { version = "1.13.1", features = ["v4"] } uuid = { version = "1.13.1", features = ["v4"] }
thiserror = "2.0.11"

View File

@ -20,7 +20,7 @@ use write::{WriteControl, WriteControlHandle, WriteHandle, WriteMessage};
use crate::{ use crate::{
db::Db, db::Db,
error::{Error, Reason}, error::{Error, ReadError, WriteError},
UpdateMessage, UpdateMessage,
}; };
@ -36,7 +36,7 @@ pub struct Supervisor {
tokio::task::JoinSet<()>, tokio::task::JoinSet<()>,
mpsc::Sender<SupervisorCommand>, mpsc::Sender<SupervisorCommand>,
WriteHandle, WriteHandle,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>, Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
)>, )>,
sender: mpsc::Sender<UpdateMessage>, sender: mpsc::Sender<UpdateMessage>,
writer_handle: WriteControlHandle, writer_handle: WriteControlHandle,
@ -62,7 +62,7 @@ pub enum State {
tokio::task::JoinSet<()>, tokio::task::JoinSet<()>,
mpsc::Sender<SupervisorCommand>, mpsc::Sender<SupervisorCommand>,
WriteHandle, WriteHandle,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>, Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
), ),
), ),
} }
@ -77,7 +77,7 @@ impl Supervisor {
JoinSet<()>, JoinSet<()>,
mpsc::Sender<SupervisorCommand>, mpsc::Sender<SupervisorCommand>,
WriteHandle, WriteHandle,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>, Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
)>, )>,
sender: mpsc::Sender<UpdateMessage>, sender: mpsc::Sender<UpdateMessage>,
writer_handle: WriteControlHandle, writer_handle: WriteControlHandle,
@ -180,7 +180,7 @@ impl Supervisor {
// if reconnection failure, respond to all current write messages with lost connection error. the received processes should complete themselves. // if reconnection failure, respond to all current write messages with lost connection error. the received processes should complete themselves.
write_state.close(); write_state.close();
while let Some(msg) = write_state.recv().await { while let Some(msg) = write_state.recv().await {
let _ = msg.respond_to.send(Err(Reason::LostConnection)); let _ = msg.respond_to.send(Err(WriteError::LostConnection));
} }
// TODO: is this the correct error? // TODO: is this the correct error?
let _ = self.sender.send(UpdateMessage::Error(Error::LostConnection)).await; let _ = self.sender.send(UpdateMessage::Error(Error::LostConnection)).await;
@ -227,9 +227,9 @@ impl Supervisor {
Err(e) => { Err(e) => {
// if reconnection failure, respond to all current write messages with lost connection error. the received processes should complete themselves. // if reconnection failure, respond to all current write messages with lost connection error. the received processes should complete themselves.
write_recv.close(); write_recv.close();
let _ = write_msg.respond_to.send(Err(Reason::LostConnection)); let _ = write_msg.respond_to.send(Err(WriteError::LostConnection));
while let Some(msg) = write_recv.recv().await { while let Some(msg) = write_recv.recv().await {
let _ = msg.respond_to.send(Err(Reason::LostConnection)); let _ = msg.respond_to.send(Err(WriteError::LostConnection));
} }
// TODO: is this the correct error to send? // TODO: is this the correct error to send?
let _ = self.sender.send(UpdateMessage::Error(Error::LostConnection)).await; let _ = self.sender.send(UpdateMessage::Error(Error::LostConnection)).await;
@ -278,10 +278,10 @@ impl Supervisor {
// if reconnection failure, respond to all current messages with lost connection error. // if reconnection failure, respond to all current messages with lost connection error.
write_receiver.close(); write_receiver.close();
if let Some(msg) = retry_msg { if let Some(msg) = retry_msg {
msg.respond_to.send(Err(Reason::LostConnection)); msg.respond_to.send(Err(WriteError::LostConnection));
} }
while let Some(msg) = write_receiver.recv().await { while let Some(msg) = write_receiver.recv().await {
msg.respond_to.send(Err(Reason::LostConnection)); msg.respond_to.send(Err(WriteError::LostConnection));
} }
// TODO: is this the correct error? // TODO: is this the correct error?
let _ = self.sender.send(UpdateMessage::Error(Error::LostConnection)).await; let _ = self.sender.send(UpdateMessage::Error(Error::LostConnection)).await;
@ -342,7 +342,7 @@ impl SupervisorHandle {
on_shutdown: oneshot::Sender<()>, on_shutdown: oneshot::Sender<()>,
jid: Arc<Mutex<JID>>, jid: Arc<Mutex<JID>>,
password: Arc<String>, password: Arc<String>,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>, pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
) -> (WriteHandle, Self) { ) -> (WriteHandle, Self) {
let (command_sender, command_receiver) = mpsc::channel(20); let (command_sender, command_receiver) = mpsc::channel(20);
let (writer_error_sender, writer_error_receiver) = oneshot::channel(); let (writer_error_sender, writer_error_receiver) = oneshot::channel();

View File

@ -18,16 +18,13 @@ use uuid::Uuid;
use crate::{ use crate::{
chat::{Body, Message}, chat::{Body, Message},
db::Db, db::Db,
error::{Error, IqError, PresenceError, Reason, RecvMessageError}, error::{Error, IqError, MessageRecvError, PresenceError, ReadError, RosterError},
presence::{Offline, Online, Presence, Show}, presence::{Offline, Online, Presence, Show},
roster::Contact, roster::Contact,
UpdateMessage, UpdateMessage,
}; };
use super::{ use super::{write::WriteHandle, SupervisorCommand};
write::{WriteHandle, WriteMessage},
SupervisorCommand,
};
pub struct Read { pub struct Read {
control_receiver: mpsc::Receiver<ReadControl>, control_receiver: mpsc::Receiver<ReadControl>,
@ -38,7 +35,7 @@ pub struct Read {
JoinSet<()>, JoinSet<()>,
mpsc::Sender<SupervisorCommand>, mpsc::Sender<SupervisorCommand>,
WriteHandle, WriteHandle,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>, Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
)>, )>,
db: Db, db: Db,
update_sender: mpsc::Sender<UpdateMessage>, update_sender: mpsc::Sender<UpdateMessage>,
@ -48,7 +45,7 @@ pub struct Read {
disconnecting: bool, disconnecting: bool,
disconnect_timedout: oneshot::Receiver<()>, disconnect_timedout: oneshot::Receiver<()>,
// TODO: use proper stanza ids // TODO: use proper stanza ids
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>, pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
} }
impl Read { impl Read {
@ -61,7 +58,7 @@ impl Read {
JoinSet<()>, JoinSet<()>,
mpsc::Sender<SupervisorCommand>, mpsc::Sender<SupervisorCommand>,
WriteHandle, WriteHandle,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>, Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
)>, )>,
db: Db, db: Db,
update_sender: mpsc::Sender<UpdateMessage>, update_sender: mpsc::Sender<UpdateMessage>,
@ -69,9 +66,9 @@ impl Read {
supervisor_control: mpsc::Sender<SupervisorCommand>, supervisor_control: mpsc::Sender<SupervisorCommand>,
write_handle: WriteHandle, write_handle: WriteHandle,
tasks: JoinSet<()>, tasks: JoinSet<()>,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>, pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
) -> Self { ) -> Self {
let (send, recv) = oneshot::channel(); let (_send, recv) = oneshot::channel();
Self { Self {
control_receiver, control_receiver,
stream, stream,
@ -162,7 +159,7 @@ impl Read {
// when it aborts, must clear iq map no matter what // when it aborts, must clear iq map no matter what
let mut iqs = self.pending_iqs.lock().await; let mut iqs = self.pending_iqs.lock().await;
for (_id, sender) in iqs.drain() { for (_id, sender) in iqs.drain() {
let _ = sender.send(Err(Reason::LostConnection)); let _ = sender.send(Err(ReadError::LostConnection));
} }
} }
} }
@ -178,7 +175,7 @@ async fn handle_stanza(
db: Db, db: Db,
supervisor_control: mpsc::Sender<SupervisorCommand>, supervisor_control: mpsc::Sender<SupervisorCommand>,
write_handle: WriteHandle, write_handle: WriteHandle,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>, pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
) { ) {
match stanza { match stanza {
Stanza::Message(stanza_message) => { Stanza::Message(stanza_message) => {
@ -207,7 +204,9 @@ async fn handle_stanza(
if let Err(e) = result { if let Err(e) = result {
tracing::error!("messagecreate"); tracing::error!("messagecreate");
let _ = update_sender let _ = update_sender
.send(UpdateMessage::Error(Error::CacheUpdate(e.into()))) .send(UpdateMessage::Error(Error::MessageRecv(
MessageRecvError::MessageHistory(e.into()),
)))
.await; .await;
} }
let _ = update_sender let _ = update_sender
@ -215,8 +214,8 @@ async fn handle_stanza(
.await; .await;
} else { } else {
let _ = update_sender let _ = update_sender
.send(UpdateMessage::Error(Error::RecvMessage( .send(UpdateMessage::Error(Error::MessageRecv(
RecvMessageError::MissingFrom, MessageRecvError::MissingFrom,
))) )))
.await; .await;
} }
@ -229,9 +228,16 @@ async fn handle_stanza(
stanza::client::presence::PresenceType::Error => { stanza::client::presence::PresenceType::Error => {
// TODO: is there any other information that should go with the error? also MUST have an error, otherwise it's a different error. maybe it shoulnd't be an option. // TODO: is there any other information that should go with the error? also MUST have an error, otherwise it's a different error. maybe it shoulnd't be an option.
let _ = update_sender let _ = update_sender
.send(UpdateMessage::Error(Error::Presence(PresenceError::Error( .send(UpdateMessage::Error(Error::Presence(
Reason::Stanza(presence.errors.first().cloned()), // TODO: ughhhhhhhhhhhhh these stanza errors should probably just have an option, and custom display
)))) PresenceError::StanzaError(
presence
.errors
.first()
.cloned()
.expect("error MUST have error"),
),
)))
.await; .await;
} }
// should not happen (error to server) // should not happen (error to server)
@ -329,8 +335,8 @@ async fn handle_stanza(
let contact: Contact = item.into(); let contact: Contact = item.into();
if let Err(e) = db.upsert_contact(contact.clone()).await { if let Err(e) = db.upsert_contact(contact.clone()).await {
let _ = update_sender let _ = update_sender
.send(UpdateMessage::Error(Error::CacheUpdate( .send(UpdateMessage::Error(Error::Roster(
e.into(), RosterError::Cache(e.into()),
))) )))
.await; .await;
} }
@ -381,7 +387,7 @@ pub enum ReadControl {
JoinSet<()>, JoinSet<()>,
mpsc::Sender<SupervisorCommand>, mpsc::Sender<SupervisorCommand>,
WriteHandle, WriteHandle,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>, Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
)>, )>,
), ),
} }
@ -414,13 +420,13 @@ impl ReadControlHandle {
JoinSet<()>, JoinSet<()>,
mpsc::Sender<SupervisorCommand>, mpsc::Sender<SupervisorCommand>,
WriteHandle, WriteHandle,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>, Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
)>, )>,
db: Db, db: Db,
sender: mpsc::Sender<UpdateMessage>, sender: mpsc::Sender<UpdateMessage>,
supervisor_control: mpsc::Sender<SupervisorCommand>, supervisor_control: mpsc::Sender<SupervisorCommand>,
jabber_write: WriteHandle, jabber_write: WriteHandle,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>, pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
) -> Self { ) -> Self {
let (control_sender, control_receiver) = mpsc::channel(20); let (control_sender, control_receiver) = mpsc::channel(20);
@ -451,14 +457,14 @@ impl ReadControlHandle {
JoinSet<()>, JoinSet<()>,
mpsc::Sender<SupervisorCommand>, mpsc::Sender<SupervisorCommand>,
WriteHandle, WriteHandle,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>, Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
)>, )>,
db: Db, db: Db,
sender: mpsc::Sender<UpdateMessage>, sender: mpsc::Sender<UpdateMessage>,
supervisor_control: mpsc::Sender<SupervisorCommand>, supervisor_control: mpsc::Sender<SupervisorCommand>,
jabber_write: WriteHandle, jabber_write: WriteHandle,
tasks: JoinSet<()>, tasks: JoinSet<()>,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>, pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
) -> Self { ) -> Self {
let (control_sender, control_receiver) = mpsc::channel(20); let (control_sender, control_receiver) = mpsc::channel(20);

View File

@ -7,7 +7,7 @@ use tokio::{
task::JoinHandle, task::JoinHandle,
}; };
use crate::error::{Error, Reason}; use crate::error::WriteError;
// actor that receives jabber stanzas to write, and if there is an error, sends a message back to the supervisor then aborts, so the supervisor can spawn a new stream. // actor that receives jabber stanzas to write, and if there is an error, sends a message back to the supervisor then aborts, so the supervisor can spawn a new stream.
pub struct Write { pub struct Write {
@ -17,9 +17,10 @@ pub struct Write {
on_crash: oneshot::Sender<(WriteMessage, mpsc::Receiver<WriteMessage>)>, on_crash: oneshot::Sender<(WriteMessage, mpsc::Receiver<WriteMessage>)>,
} }
#[derive(Debug)]
pub struct WriteMessage { pub struct WriteMessage {
pub stanza: Stanza, pub stanza: Stanza,
pub respond_to: oneshot::Sender<Result<(), Reason>>, pub respond_to: oneshot::Sender<Result<(), WriteError>>,
} }
pub enum WriteControl { pub enum WriteControl {
@ -84,9 +85,9 @@ impl Write {
Err(e) => match &e { Err(e) => match &e {
peanuts::Error::ReadError(_error) => { peanuts::Error::ReadError(_error) => {
// if connection lost during disconnection, just send lost connection error to the write requests // if connection lost during disconnection, just send lost connection error to the write requests
let _ = msg.respond_to.send(Err(Reason::LostConnection)); let _ = msg.respond_to.send(Err(WriteError::LostConnection));
while let Some(msg) = self.stanza_receiver.recv().await { while let Some(msg) = self.stanza_receiver.recv().await {
let _ = msg.respond_to.send(Err(Reason::LostConnection)); let _ = msg.respond_to.send(Err(WriteError::LostConnection));
} }
break; break;
} }
@ -140,16 +141,16 @@ pub struct WriteHandle {
} }
impl WriteHandle { impl WriteHandle {
pub async fn write(&self, stanza: Stanza) -> Result<(), Reason> { pub async fn write(&self, stanza: Stanza) -> Result<(), WriteError> {
let (send, recv) = oneshot::channel(); let (send, recv) = oneshot::channel();
self.send(WriteMessage { self.send(WriteMessage {
stanza, stanza,
respond_to: send, respond_to: send,
}) })
.await .await
.map_err(|_| Reason::ChannelSend)?; .map_err(|e| WriteError::Actor(e.into()))?;
// TODO: timeout // TODO: timeout
recv.await? recv.await.map_err(|e| WriteError::Actor(e.into()))?
} }
} }

View File

@ -1,138 +1,158 @@
use stanza::client::Stanza; use std::sync::Arc;
use tokio::sync::oneshot::{self};
#[derive(Debug)] use stanza::client::Stanza;
use thiserror::Error;
use tokio::sync::{mpsc::error::SendError, oneshot::error::RecvError};
#[derive(Debug, Error, Clone)]
pub enum Error { pub enum Error {
#[error("already connected")]
AlreadyConnected, AlreadyConnected,
// TODO: change to Connecting(ConnectingError) // TODO: change to Connecting(ConnectingError)
Connection(ConnectionError), #[error("connecting: {0}")]
Presence(PresenceError), Connecting(#[from] ConnectionError),
SetStatus(Reason), #[error("presence: {0}")]
Roster(Reason), Presence(#[from] PresenceError),
Stream(stanza::stream::Error), #[error("set status: {0}")]
SendMessage(Reason), SetStatus(#[from] StatusError),
RecvMessage(RecvMessageError), // TODO: have different ones for get/update/set
#[error("roster: {0}")]
Roster(RosterError),
#[error("stream error: {0}")]
Stream(#[from] stanza::stream::Error),
#[error("message send error: {0}")]
MessageSend(MessageSendError),
#[error("message receive error: {0}")]
MessageRecv(MessageRecvError),
#[error("already disconnected")]
AlreadyDisconnected, AlreadyDisconnected,
#[error("lost connection")]
LostConnection, LostConnection,
// TODO: should all cache update errors include the context? // TODO: Display for Content
CacheUpdate(Reason), #[error("received unrecognized/unsupported content: {0:?}")]
UnrecognizedContent(peanuts::element::Content), UnrecognizedContent(peanuts::element::Content),
#[error("iq receive error: {0}")]
Iq(IqError), Iq(IqError),
Cloned, #[error("disconnected")]
Disconnected,
} }
// TODO: this is horrifying, maybe just use tracing to forward error events??? #[derive(Debug, Error, Clone)]
impl Clone for Error { pub enum MessageSendError {
fn clone(&self) -> Self { #[error("could not add to message history: {0}")]
Error::Cloned MessageHistory(#[from] DatabaseError),
}
} }
#[derive(Debug)] #[derive(Debug, Error, Clone)]
pub enum PresenceError { pub enum PresenceError {
Error(Reason), #[error("unsupported")]
Unsupported, Unsupported,
#[error("missing from")]
MissingFrom, MissingFrom,
#[error("stanza error: {0}")]
StanzaError(#[from] stanza::client::error::Error),
} }
#[derive(Debug)] #[derive(Debug, Error, Clone)]
// TODO: should probably have all iq query related errors here, including read, write, stanza error, etc.
pub enum IqError { pub enum IqError {
#[error("no iq with id matching `{0}`")]
NoMatchingId(String), NoMatchingId(String),
} }
#[derive(Debug)] #[derive(Debug, Error, Clone)]
pub enum RecvMessageError { pub enum MessageRecvError {
#[error("could not add to message history: {0}")]
MessageHistory(#[from] DatabaseError),
#[error("missing from")]
MissingFrom, MissingFrom,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone, Error)]
pub enum ConnectionError { pub enum ConnectionError {
ConnectionFailed(Reason), #[error("connection failed: {0}")]
RosterRetreival(Reason), ConnectionFailed(#[from] jabber::Error),
SendPresence(Reason), #[error("failed roster retreival: {0}")]
NoCachedStatus(Reason), RosterRetreival(#[from] RosterError),
#[error("failed to send available presence: {0}")]
SendPresence(#[from] WriteError),
#[error("cached status: {0}")]
StatusCacheError(#[from] DatabaseError),
} }
#[derive(Debug)] #[derive(Debug, Error, Clone)]
pub struct RosterError(pub Reason); pub enum RosterError {
#[error("cache: {0}")]
impl From<RosterError> for Error { Cache(#[from] DatabaseError),
fn from(e: RosterError) -> Self { #[error("stream write: {0}")]
Self::Roster(e.0) Write(#[from] WriteError),
} // TODO: display for stanza, to show as xml, same for read error types.
} #[error("unexpected reply: {0:?}")]
impl From<RosterError> for ConnectionError {
fn from(e: RosterError) -> Self {
Self::RosterRetreival(e.0)
}
}
pub struct StatusError(pub Reason);
impl From<StatusError> for Error {
fn from(e: StatusError) -> Self {
Error::SetStatus(e.0)
}
}
impl From<StatusError> for ConnectionError {
fn from(e: StatusError) -> Self {
Self::SendPresence(e.0)
}
}
#[derive(Debug)]
pub enum Reason {
// TODO: organisastion of error into internal error thing
Timeout,
Stream(stanza::stream_error::Error),
Stanza(Option<stanza::client::error::Error>),
Jabber(jabber::Error),
XML(peanuts::Error),
SQL(sqlx::Error),
// JID(jid::ParseError),
LostConnection,
OneshotRecv(oneshot::error::RecvError),
UnexpectedStanza(Stanza), UnexpectedStanza(Stanza),
Disconnected, #[error("stream read: {0}")]
ChannelSend, Read(#[from] ReadError),
Cloned, #[error("stanza error: {0}")]
StanzaError(#[from] stanza::client::error::Error),
} }
// TODO: same here #[derive(Debug, Error, Clone)]
impl Clone for Reason { #[error("database error: {0}")]
fn clone(&self) -> Self { pub struct DatabaseError(Arc<sqlx::Error>);
Reason::Cloned
}
}
impl From<oneshot::error::RecvError> for Reason { impl From<sqlx::Error> for DatabaseError {
fn from(e: oneshot::error::RecvError) -> Reason {
Self::OneshotRecv(e)
}
}
impl From<peanuts::Error> for Reason {
fn from(e: peanuts::Error) -> Self {
Self::XML(e)
}
}
// impl From<jid::ParseError> for Reason {
// fn from(e: jid::ParseError) -> Self {
// Self::JID(e)
// }
// }
impl From<sqlx::Error> for Reason {
fn from(e: sqlx::Error) -> Self { fn from(e: sqlx::Error) -> Self {
Self::SQL(e) Self(Arc::new(e))
} }
} }
impl From<jabber::Error> for Reason { #[derive(Debug, Error, Clone)]
fn from(e: jabber::Error) -> Self { pub enum StatusError {
Self::Jabber(e) #[error("cache: {0}")]
Cache(#[from] DatabaseError),
#[error("stream write: {0}")]
Write(#[from] WriteError),
}
#[derive(Debug, Error, Clone)]
pub enum WriteError {
#[error("xml: {0}")]
XML(#[from] peanuts::Error),
#[error("lost connection")]
LostConnection,
// TODO: should this be in writeerror or separate?
#[error("actor: {0}")]
Actor(#[from] ActorError),
#[error("disconnected")]
Disconnected,
}
// TODO: separate peanuts read and write error?
#[derive(Debug, Error, Clone)]
pub enum ReadError {
#[error("xml: {0}")]
XML(#[from] peanuts::Error),
#[error("lost connection")]
LostConnection,
}
#[derive(Debug, Error, Clone)]
pub enum ActorError {
#[error("receive timed out")]
Timeout,
#[error("could not send message to actor, channel closed")]
Send,
#[error("could not receive message from actor, channel closed")]
Receive,
}
impl<T> From<SendError<T>> for ActorError {
fn from(_e: SendError<T>) -> Self {
Self::Send
}
}
impl From<RecvError> for ActorError {
fn from(_e: RecvError) -> Self {
Self::Receive
} }
} }

View File

@ -7,7 +7,8 @@ use std::{
use chat::{Body, Chat, Message}; use chat::{Body, Chat, Message};
use connection::{write::WriteMessage, SupervisorSender}; use connection::{write::WriteMessage, SupervisorSender};
use db::Db; use db::Db;
use error::{ConnectionError, Reason, RosterError, StatusError}; use error::{ConnectionError, DatabaseError, ReadError, RosterError, StatusError, WriteError};
use futures::{future::Fuse, FutureExt};
use jabber::JID; use jabber::JID;
use presence::{Offline, Online, Presence}; use presence::{Offline, Online, Presence};
use roster::{Contact, ContactUpdate}; use roster::{Contact, ContactUpdate};
@ -20,7 +21,7 @@ use tokio::{
sync::{mpsc, oneshot, Mutex}, sync::{mpsc, oneshot, Mutex},
task::JoinSet, task::JoinSet,
}; };
use tracing::{debug, info, Instrument}; use tracing::{debug, info};
use user::User; use user::User;
use uuid::Uuid; use uuid::Uuid;
@ -43,11 +44,11 @@ pub struct Luz {
// TODO: use a dyn passwordprovider trait to avoid storing password in memory // TODO: use a dyn passwordprovider trait to avoid storing password in memory
password: Arc<String>, password: Arc<String>,
connected: Arc<Mutex<Option<(WriteHandle, SupervisorHandle)>>>, connected: Arc<Mutex<Option<(WriteHandle, SupervisorHandle)>>>,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>, pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
db: Db, db: Db,
sender: mpsc::Sender<UpdateMessage>, sender: mpsc::Sender<UpdateMessage>,
/// if connection was shut down due to e.g. server shutdown, supervisor must be able to mark client as disconnected /// if connection was shut down due to e.g. server shutdown, supervisor must be able to mark client as disconnected
connection_supervisor_shutdown: oneshot::Receiver<()>, connection_supervisor_shutdown: Fuse<oneshot::Receiver<()>>,
// TODO: will need to have an auto reconnect state as well (e.g. in case server shut down, to try and reconnect later) // TODO: will need to have an auto reconnect state as well (e.g. in case server shut down, to try and reconnect later)
// TODO: will grow forever at this point, maybe not required as tasks will naturally shut down anyway? // TODO: will grow forever at this point, maybe not required as tasks will naturally shut down anyway?
tasks: JoinSet<()>, tasks: JoinSet<()>,
@ -60,7 +61,7 @@ impl Luz {
jid: Arc<Mutex<JID>>, jid: Arc<Mutex<JID>>,
password: String, password: String,
connected: Arc<Mutex<Option<(WriteHandle, SupervisorHandle)>>>, connected: Arc<Mutex<Option<(WriteHandle, SupervisorHandle)>>>,
connection_supervisor_shutdown: oneshot::Receiver<()>, connection_supervisor_shutdown: Fuse<oneshot::Receiver<()>>,
db: SqlitePool, db: SqlitePool,
sender: mpsc::Sender<UpdateMessage>, sender: mpsc::Sender<UpdateMessage>,
) -> Self { ) -> Self {
@ -82,9 +83,8 @@ impl Luz {
loop { loop {
let msg = tokio::select! { let msg = tokio::select! {
// this is okay, as when created the supervisor (and connection) doesn't exist, but a bit messy // this is okay, as when created the supervisor (and connection) doesn't exist, but a bit messy
// THIS IS NOT OKAY LOLLLL // THIS IS NOT OKAY LOLLLL - apparently fusing is the best option???
_ = &mut self.connection_supervisor_shutdown => { _ = &mut self.connection_supervisor_shutdown => {
info!("got this");
*self.connected.lock().await = None; *self.connected.lock().await = None;
continue; continue;
} }
@ -130,6 +130,7 @@ impl Luz {
self.password.clone(), self.password.clone(),
self.pending_iqs.clone(), self.pending_iqs.clone(),
); );
let shutdown_recv = shutdown_recv.fuse();
self.connection_supervisor_shutdown = shutdown_recv; self.connection_supervisor_shutdown = shutdown_recv;
// TODO: get roster and send initial presence // TODO: get roster and send initial presence
let (send, recv) = oneshot::channel(); let (send, recv) = oneshot::channel();
@ -158,8 +159,8 @@ impl Luz {
let _ = self let _ = self
.sender .sender
.send(UpdateMessage::Error( .send(UpdateMessage::Error(
Error::Connection( Error::Connecting(
ConnectionError::NoCachedStatus( ConnectionError::StatusCacheError(
e.into(), e.into(),
), ),
), ),
@ -169,7 +170,11 @@ impl Luz {
} }
}; };
let (send, recv) = oneshot::channel(); let (send, recv) = oneshot::channel();
CommandMessage::SetStatus(online.clone(), send) CommandMessage::SendPresence(
None,
Presence::Online(online.clone()),
send,
)
.handle_online( .handle_online(
writer.clone(), writer.clone(),
supervisor.sender(), supervisor.sender(),
@ -197,13 +202,13 @@ impl Luz {
let _ = self let _ = self
.sender .sender
.send(UpdateMessage::Error( .send(UpdateMessage::Error(
Error::Connection(e.into()), Error::Connecting(e.into()),
)) ))
.await; .await;
} }
}, },
Err(e) => { Err(e) => {
let _ = self.sender.send(UpdateMessage::Error(Error::Connection(ConnectionError::SendPresence(e.into())))).await; let _ = self.sender.send(UpdateMessage::Error(Error::Connecting(ConnectionError::SendPresence(WriteError::Actor(e.into()))))).await;
} }
} }
} }
@ -211,7 +216,7 @@ impl Luz {
let _ = self let _ = self
.sender .sender
.send(UpdateMessage::Error( .send(UpdateMessage::Error(
Error::Connection(e.into()), Error::Connecting(e.into()),
)) ))
.await; .await;
} }
@ -220,8 +225,12 @@ impl Luz {
Err(e) => { Err(e) => {
let _ = self let _ = self
.sender .sender
.send(UpdateMessage::Error(Error::Connection( .send(UpdateMessage::Error(Error::Connecting(
ConnectionError::RosterRetreival(e.into()), ConnectionError::RosterRetreival(
RosterError::Write(WriteError::Actor(
e.into(),
)),
),
))) )))
.await; .await;
} }
@ -229,7 +238,7 @@ impl Luz {
} }
Err(e) => { Err(e) => {
let _ = let _ =
self.sender.send(UpdateMessage::Error(Error::Connection( self.sender.send(UpdateMessage::Error(Error::Connecting(
ConnectionError::ConnectionFailed(e.into()), ConnectionError::ConnectionFailed(e.into()),
))); )));
} }
@ -237,7 +246,7 @@ impl Luz {
} }
}; };
} }
CommandMessage::Disconnect(_offline) => { CommandMessage::Disconnect(offline) => {
match self.connected.lock().await.as_mut() { match self.connected.lock().await.as_mut() {
None => { None => {
let _ = self let _ = self
@ -247,15 +256,19 @@ impl Luz {
} }
mut c => { mut c => {
// TODO: send unavailable presence // TODO: send unavailable presence
if let Some((_write_handle, supervisor_handle)) = c.take() { if let Some((write_handle, supervisor_handle)) = c.take() {
let offline_presence: stanza::client::presence::Presence =
offline.clone().into();
let stanza = Stanza::Presence(offline_presence);
// TODO: timeout and error check
write_handle.write(stanza).await;
let _ = supervisor_handle.send(SupervisorCommand::Disconnect).await; let _ = supervisor_handle.send(SupervisorCommand::Disconnect).await;
c = None; let _ = self.sender.send(UpdateMessage::Offline(offline)).await;
} else { } else {
unreachable!() unreachable!()
}; };
} }
} }
info!("lock released")
} }
_ => { _ => {
match self.connected.lock().await.as_ref() { match self.connected.lock().await.as_ref() {
@ -281,7 +294,7 @@ impl Luz {
impl CommandMessage { impl CommandMessage {
pub async fn handle_offline( pub async fn handle_offline(
mut self, self,
jid: Arc<Mutex<JID>>, jid: Arc<Mutex<JID>>,
db: Db, db: Db,
update_sender: mpsc::Sender<UpdateMessage>, update_sender: mpsc::Sender<UpdateMessage>,
@ -296,7 +309,7 @@ impl CommandMessage {
let _ = sender.send(Ok(roster)); let _ = sender.send(Ok(roster));
} }
Err(e) => { Err(e) => {
let _ = sender.send(Err(RosterError(e.into()))); let _ = sender.send(Err(RosterError::Cache(e.into())));
} }
} }
} }
@ -326,45 +339,48 @@ impl CommandMessage {
} }
// TODO: offline queue to modify roster // TODO: offline queue to modify roster
CommandMessage::AddContact(jid, sender) => { CommandMessage::AddContact(jid, sender) => {
sender.send(Err(Reason::Disconnected)); sender.send(Err(RosterError::Write(WriteError::Disconnected)));
} }
CommandMessage::BuddyRequest(jid, sender) => { CommandMessage::BuddyRequest(jid, sender) => {
sender.send(Err(Reason::Disconnected)); sender.send(Err(WriteError::Disconnected));
} }
CommandMessage::SubscriptionRequest(jid, sender) => { CommandMessage::SubscriptionRequest(jid, sender) => {
sender.send(Err(Reason::Disconnected)); sender.send(Err(WriteError::Disconnected));
} }
CommandMessage::AcceptBuddyRequest(jid, sender) => { CommandMessage::AcceptBuddyRequest(jid, sender) => {
sender.send(Err(Reason::Disconnected)); sender.send(Err(WriteError::Disconnected));
} }
CommandMessage::AcceptSubscriptionRequest(jid, sender) => { CommandMessage::AcceptSubscriptionRequest(jid, sender) => {
sender.send(Err(Reason::Disconnected)); sender.send(Err(WriteError::Disconnected));
} }
CommandMessage::UnsubscribeFromContact(jid, sender) => { CommandMessage::UnsubscribeFromContact(jid, sender) => {
sender.send(Err(Reason::Disconnected)); sender.send(Err(WriteError::Disconnected));
} }
CommandMessage::UnsubscribeContact(jid, sender) => { CommandMessage::UnsubscribeContact(jid, sender) => {
sender.send(Err(Reason::Disconnected)); sender.send(Err(WriteError::Disconnected));
} }
CommandMessage::UnfriendContact(jid, sender) => { CommandMessage::UnfriendContact(jid, sender) => {
sender.send(Err(Reason::Disconnected)); sender.send(Err(WriteError::Disconnected));
} }
CommandMessage::DeleteContact(jid, sender) => { CommandMessage::DeleteContact(jid, sender) => {
sender.send(Err(Reason::Disconnected)); sender.send(Err(RosterError::Write(WriteError::Disconnected)));
} }
CommandMessage::UpdateContact(jid, contact_update, sender) => { CommandMessage::UpdateContact(jid, contact_update, sender) => {
sender.send(Err(Reason::Disconnected)); sender.send(Err(RosterError::Write(WriteError::Disconnected)));
} }
CommandMessage::SetStatus(online, sender) => { CommandMessage::SetStatus(online, sender) => {
let result = db let result = db
.upsert_cached_status(online) .upsert_cached_status(online)
.await .await
.map_err(|e| StatusError(e.into())); .map_err(|e| StatusError::Cache(e.into()));
sender.send(result); sender.send(result);
} }
// TODO: offline message queue // TODO: offline message queue
CommandMessage::SendMessage(jid, body, sender) => { CommandMessage::SendMessage(jid, body, sender) => {
sender.send(Err(Reason::Disconnected)); sender.send(Err(WriteError::Disconnected));
}
CommandMessage::SendPresence(jid, presence, sender) => {
sender.send(Err(WriteError::Disconnected));
} }
} }
} }
@ -377,7 +393,7 @@ impl CommandMessage {
client_jid: Arc<Mutex<JID>>, client_jid: Arc<Mutex<JID>>,
db: Db, db: Db,
update_sender: mpsc::Sender<UpdateMessage>, update_sender: mpsc::Sender<UpdateMessage>,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>, pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
) { ) {
match self { match self {
CommandMessage::Connect => unreachable!(), CommandMessage::Connect => unreachable!(),
@ -419,11 +435,12 @@ impl CommandMessage {
Ok(Ok(())) => info!("roster request sent"), Ok(Ok(())) => info!("roster request sent"),
Ok(Err(e)) => { Ok(Err(e)) => {
// TODO: log errors if fail to send // TODO: log errors if fail to send
let _ = result_sender.send(Err(RosterError(e.into()))); let _ = result_sender.send(Err(RosterError::Write(e.into())));
return; return;
} }
Err(e) => { Err(e) => {
let _ = result_sender.send(Err(RosterError(e.into()))); let _ = result_sender
.send(Err(RosterError::Write(WriteError::Actor(e.into()))));
return; return;
} }
}; };
@ -443,23 +460,41 @@ impl CommandMessage {
items.into_iter().map(|item| item.into()).collect(); items.into_iter().map(|item| item.into()).collect();
if let Err(e) = db.replace_cached_roster(contacts.clone()).await { if let Err(e) = db.replace_cached_roster(contacts.clone()).await {
update_sender update_sender
.send(UpdateMessage::Error(Error::CacheUpdate(e.into()))) .send(UpdateMessage::Error(Error::Roster(RosterError::Cache(
e.into(),
))))
.await; .await;
}; };
result_sender.send(Ok(contacts)); result_sender.send(Ok(contacts));
return; return;
} }
ref s @ Stanza::Iq(Iq {
from: _,
ref id,
to: _,
r#type,
lang: _,
query: _,
ref errors,
}) if *id == iq_id && r#type == IqType::Error => {
if let Some(error) = errors.first() {
result_sender.send(Err(RosterError::StanzaError(error.clone())));
} else {
result_sender.send(Err(RosterError::UnexpectedStanza(s.clone())));
}
return;
}
s => { s => {
result_sender.send(Err(RosterError(Reason::UnexpectedStanza(s)))); result_sender.send(Err(RosterError::UnexpectedStanza(s)));
return; return;
} }
}, },
Ok(Err(e)) => { Ok(Err(e)) => {
result_sender.send(Err(RosterError(e.into()))); result_sender.send(Err(RosterError::Read(e)));
return; return;
} }
Err(e) => { Err(e) => {
result_sender.send(Err(RosterError(e.into()))); result_sender.send(Err(RosterError::Write(WriteError::Actor(e.into()))));
return; return;
} }
} }
@ -520,8 +555,8 @@ impl CommandMessage {
} }
// TODO: write_handle send helper function // TODO: write_handle send helper function
let result = write_handle.write(set_stanza).await; let result = write_handle.write(set_stanza).await;
if let Err(_) = result { if let Err(e) = result {
sender.send(result); sender.send(Err(RosterError::Write(e)));
return; return;
} }
let iq_result = recv.await; let iq_result = recv.await;
@ -540,24 +575,24 @@ impl CommandMessage {
sender.send(Ok(())); sender.send(Ok(()));
return; return;
} }
Stanza::Iq(Iq { ref s @ Stanza::Iq(Iq {
from: _, from: _,
id, ref id,
to: _, to: _,
r#type, r#type,
lang: _, lang: _,
query: _, query: _,
errors, ref errors,
}) if id == iq_id && r#type == IqType::Error => { }) if *id == iq_id && r#type == IqType::Error => {
if let Some(error) = errors.first() { if let Some(error) = errors.first() {
sender.send(Err(Reason::Stanza(Some(error.clone())))); sender.send(Err(RosterError::StanzaError(error.clone())));
} else { } else {
sender.send(Err(Reason::Stanza(None))); sender.send(Err(RosterError::UnexpectedStanza(s.clone())));
} }
return; return;
} }
s => { s => {
sender.send(Err(Reason::UnexpectedStanza(s))); sender.send(Err(RosterError::UnexpectedStanza(s)));
return; return;
} }
}, },
@ -567,7 +602,7 @@ impl CommandMessage {
} }
}, },
Err(e) => { Err(e) => {
sender.send(Err(e.into())); sender.send(Err(RosterError::Write(WriteError::Actor(e.into()))));
return; return;
} }
} }
@ -765,8 +800,8 @@ impl CommandMessage {
pending_iqs.lock().await.insert(iq_id.clone(), send); pending_iqs.lock().await.insert(iq_id.clone(), send);
} }
let result = write_handle.write(set_stanza).await; let result = write_handle.write(set_stanza).await;
if let Err(_) = result { if let Err(e) = result {
sender.send(result); sender.send(Err(RosterError::Write(e)));
return; return;
} }
let iq_result = recv.await; let iq_result = recv.await;
@ -785,24 +820,24 @@ impl CommandMessage {
sender.send(Ok(())); sender.send(Ok(()));
return; return;
} }
Stanza::Iq(Iq { ref s @ Stanza::Iq(Iq {
from: _, from: _,
id, ref id,
to: _, to: _,
r#type, r#type,
lang: _, lang: _,
query: _, query: _,
errors, ref errors,
}) if id == iq_id && r#type == IqType::Error => { }) if *id == iq_id && r#type == IqType::Error => {
if let Some(error) = errors.first() { if let Some(error) = errors.first() {
sender.send(Err(Reason::Stanza(Some(error.clone())))); sender.send(Err(RosterError::StanzaError(error.clone())));
} else { } else {
sender.send(Err(Reason::Stanza(None))); sender.send(Err(RosterError::UnexpectedStanza(s.clone())));
} }
return; return;
} }
s => { s => {
sender.send(Err(Reason::UnexpectedStanza(s))); sender.send(Err(RosterError::UnexpectedStanza(s)));
return; return;
} }
}, },
@ -812,7 +847,7 @@ impl CommandMessage {
} }
}, },
Err(e) => { Err(e) => {
sender.send(Err(e.into())); sender.send(Err(RosterError::Write(WriteError::Actor(e.into()))));
return; return;
} }
} }
@ -853,8 +888,8 @@ impl CommandMessage {
pending_iqs.lock().await.insert(iq_id.clone(), send); pending_iqs.lock().await.insert(iq_id.clone(), send);
} }
let result = write_handle.write(set_stanza).await; let result = write_handle.write(set_stanza).await;
if let Err(_) = result { if let Err(e) = result {
sender.send(result); sender.send(Err(RosterError::Write(e)));
return; return;
} }
let iq_result = recv.await; let iq_result = recv.await;
@ -873,24 +908,24 @@ impl CommandMessage {
sender.send(Ok(())); sender.send(Ok(()));
return; return;
} }
Stanza::Iq(Iq { ref s @ Stanza::Iq(Iq {
from: _, from: _,
id, ref id,
to: _, to: _,
r#type, r#type,
lang: _, lang: _,
query: _, query: _,
errors, ref errors,
}) if id == iq_id && r#type == IqType::Error => { }) if *id == iq_id && r#type == IqType::Error => {
if let Some(error) = errors.first() { if let Some(error) = errors.first() {
sender.send(Err(Reason::Stanza(Some(error.clone())))); sender.send(Err(RosterError::StanzaError(error.clone())));
} else { } else {
sender.send(Err(Reason::Stanza(None))); sender.send(Err(RosterError::UnexpectedStanza(s.clone())));
} }
return; return;
} }
s => { s => {
sender.send(Err(Reason::UnexpectedStanza(s))); sender.send(Err(RosterError::UnexpectedStanza(s)));
return; return;
} }
}, },
@ -900,7 +935,7 @@ impl CommandMessage {
} }
}, },
Err(e) => { Err(e) => {
sender.send(Err(e.into())); sender.send(Err(RosterError::Write(WriteError::Actor(e.into()))));
return; return;
} }
} }
@ -909,13 +944,16 @@ impl CommandMessage {
let result = db.upsert_cached_status(online.clone()).await; let result = db.upsert_cached_status(online.clone()).await;
if let Err(e) = result { if let Err(e) = result {
let _ = update_sender let _ = update_sender
.send(UpdateMessage::Error(Error::CacheUpdate(e.into()))) .send(UpdateMessage::Error(Error::SetStatus(StatusError::Cache(
e.into(),
))))
.await; .await;
} }
let result = write_handle let result = write_handle
.write(Stanza::Presence(online.into())) .write(Stanza::Presence(online.into()))
.await .await
.map_err(|e| StatusError(e)); .map_err(|e| StatusError::Write(e));
// .map_err(|e| StatusError::Write(e));
let _ = sender.send(result); let _ = sender.send(result);
} }
// TODO: offline message queue // TODO: offline message queue
@ -951,7 +989,9 @@ impl CommandMessage {
}; };
if let Err(e) = db.create_message(message, jid).await.map_err(|e| e.into()) if let Err(e) = db.create_message(message, jid).await.map_err(|e| e.into())
{ {
let _ = update_sender.send(UpdateMessage::Error(Error::CacheUpdate(e))); let _ = update_sender.send(UpdateMessage::Error(Error::MessageSend(
error::MessageSendError::MessageHistory(e),
)));
} }
let _ = sender.send(Ok(())); let _ = sender.send(Ok(()));
} }
@ -960,6 +1000,15 @@ impl CommandMessage {
} }
} }
} }
CommandMessage::SendPresence(jid, presence, sender) => {
let mut presence: stanza::client::presence::Presence = presence.into();
if let Some(jid) = jid {
presence.to = Some(jid);
};
let result = write_handle.write(Stanza::Presence(presence)).await;
// .map_err(|e| StatusError::Write(e));
let _ = sender.send(result);
}
} }
} }
} }
@ -994,16 +1043,18 @@ impl DerefMut for LuzHandle {
} }
impl LuzHandle { impl LuzHandle {
// TODO: database creation separate
pub async fn new( pub async fn new(
jid: JID, jid: JID,
password: String, password: String,
db: &str, db: &str,
) -> Result<(Self, mpsc::Receiver<UpdateMessage>), Reason> { ) -> Result<(Self, mpsc::Receiver<UpdateMessage>), DatabaseError> {
let db = SqlitePool::connect(db).await?; let db = SqlitePool::connect(db).await?;
let (command_sender, command_receiver) = mpsc::channel(20); let (command_sender, command_receiver) = mpsc::channel(20);
let (update_sender, update_receiver) = mpsc::channel(20); let (update_sender, update_receiver) = mpsc::channel(20);
// might be bad, first supervisor shutdown notification oneshot is never used (disgusting) // might be bad, first supervisor shutdown notification oneshot is never used (disgusting)
let (sup_send, sup_recv) = oneshot::channel(); let (sup_send, sup_recv) = oneshot::channel();
let mut sup_recv = sup_recv.fuse();
let actor = Luz::new( let actor = Luz::new(
command_sender.clone(), command_sender.clone(),
@ -1024,8 +1075,59 @@ impl LuzHandle {
update_receiver, update_receiver,
)) ))
} }
pub async fn connect(&self) {
self.send(CommandMessage::Connect).await;
} }
pub async fn disconnect(&self, offline: Offline) {
self.send(CommandMessage::Disconnect(offline)).await;
}
// pub async fn get_roster(&self) -> Result<Vec<Contact>, RosterError> {
// let (send, recv) = oneshot::channel();
// self.send(CommandMessage::GetRoster(send)).await.map_err(|e| RosterError::)?;
// Ok(recv.await?)
// }
// pub async fn get_chats(&self) -> Result<Vec<Chat>, Error> {}
// pub async fn get_chat(&self, jid: JID) -> Result<Chat, Error> {}
// pub async fn get_messages(&self, jid: JID) -> Result<Vec<Message>, Error> {}
// pub async fn delete_chat(&self, jid: JID) -> Result<(), Error> {}
// pub async fn delete_message(&self, id: Uuid) -> Result<(), Error> {}
// pub async fn get_user(&self, jid: JID) -> Result<User, Error> {}
// pub async fn add_contact(&self, jid: JID) -> Result<(), Error> {}
// pub async fn buddy_request(&self, jid: JID) -> Result<(), Error> {}
// pub async fn subscription_request(&self, jid: JID) -> Result<(), Error> {}
// pub async fn accept_buddy_request(&self, jid: JID) -> Result<(), Error> {}
// pub async fn accept_subscription_request(&self, jid: JID) -> Result<(), Error> {}
// pub async fn unsubscribe_from_contact(&self, jid: JID) -> Result<(), Error> {}
// pub async fn unsubscribe_contact(&self, jid: JID) -> Result<(), Error> {}
// pub async fn unfriend_contact(&self, jid: JID) -> Result<(), Error> {}
// pub async fn delete_contact(&self, jid: JID) -> Result<(), Error> {}
// pub async fn update_contact(&self, jid: JID, update: ContactUpdate) -> Result<(), Error> {}
// pub async fn set_status(&self, online: Online) -> Result<(), Error> {}
// pub async fn send_message(&self, jid: JID, body: Body) -> Result<(), Error> {}
}
// TODO: generate methods for each with a macro
pub enum CommandMessage { pub enum CommandMessage {
// TODO: login invisible xep-0186 // TODO: login invisible xep-0186
/// connect to XMPP chat server. gets roster and publishes initial presence. /// connect to XMPP chat server. gets roster and publishes initial presence.
@ -1036,46 +1138,51 @@ pub enum CommandMessage {
GetRoster(oneshot::Sender<Result<Vec<Contact>, RosterError>>), GetRoster(oneshot::Sender<Result<Vec<Contact>, RosterError>>),
/// get all chats. chat will include 10 messages in their message Vec (enough for chat previews) /// get all chats. chat will include 10 messages in their message Vec (enough for chat previews)
// TODO: paging and filtering // TODO: paging and filtering
GetChats(oneshot::Sender<Result<Vec<Chat>, Reason>>), GetChats(oneshot::Sender<Result<Vec<Chat>, DatabaseError>>),
/// get a specific chat by jid /// get a specific chat by jid
GetChat(JID, oneshot::Sender<Result<Chat, Reason>>), GetChat(JID, oneshot::Sender<Result<Chat, DatabaseError>>),
/// get message history for chat (does appropriate mam things) /// get message history for chat (does appropriate mam things)
// TODO: paging and filtering // TODO: paging and filtering
GetMessages(JID, oneshot::Sender<Result<Vec<Message>, Reason>>), GetMessages(JID, oneshot::Sender<Result<Vec<Message>, DatabaseError>>),
/// delete a chat from your chat history, along with all the corresponding messages /// delete a chat from your chat history, along with all the corresponding messages
DeleteChat(JID, oneshot::Sender<Result<(), Reason>>), DeleteChat(JID, oneshot::Sender<Result<(), DatabaseError>>),
/// delete a message from your chat history /// delete a message from your chat history
DeleteMessage(Uuid, oneshot::Sender<Result<(), Reason>>), DeleteMessage(Uuid, oneshot::Sender<Result<(), DatabaseError>>),
/// get a user from your users database /// get a user from your users database
GetUser(JID, oneshot::Sender<Result<User, Reason>>), GetUser(JID, oneshot::Sender<Result<User, DatabaseError>>),
/// add a contact to your roster, with a status of none, no subscriptions. /// add a contact to your roster, with a status of none, no subscriptions.
// TODO: for all these, consider returning with oneshot::Sender<Result<(), Error>> AddContact(JID, oneshot::Sender<Result<(), RosterError>>),
AddContact(JID, oneshot::Sender<Result<(), Reason>>),
/// send a friend request i.e. a subscription request with a subscription pre-approval. if not already added to roster server adds to roster. /// send a friend request i.e. a subscription request with a subscription pre-approval. if not already added to roster server adds to roster.
BuddyRequest(JID, oneshot::Sender<Result<(), Reason>>), BuddyRequest(JID, oneshot::Sender<Result<(), WriteError>>),
/// send a subscription request, without pre-approval. if not already added to roster server adds to roster. /// send a subscription request, without pre-approval. if not already added to roster server adds to roster.
SubscriptionRequest(JID, oneshot::Sender<Result<(), Reason>>), SubscriptionRequest(JID, oneshot::Sender<Result<(), WriteError>>),
/// accept a friend request by accepting a pending subscription and sending a subscription request back. if not already added to roster adds to roster. /// accept a friend request by accepting a pending subscription and sending a subscription request back. if not already added to roster adds to roster.
AcceptBuddyRequest(JID, oneshot::Sender<Result<(), Reason>>), AcceptBuddyRequest(JID, oneshot::Sender<Result<(), WriteError>>),
/// accept a pending subscription and doesn't send a subscription request back. if not already added to roster adds to roster. /// accept a pending subscription and doesn't send a subscription request back. if not already added to roster adds to roster.
AcceptSubscriptionRequest(JID, oneshot::Sender<Result<(), Reason>>), AcceptSubscriptionRequest(JID, oneshot::Sender<Result<(), WriteError>>),
/// unsubscribe to a contact, but don't remove their subscription. /// unsubscribe to a contact, but don't remove their subscription.
UnsubscribeFromContact(JID, oneshot::Sender<Result<(), Reason>>), UnsubscribeFromContact(JID, oneshot::Sender<Result<(), WriteError>>),
/// stop a contact from being subscribed, but stay subscribed to the contact. /// stop a contact from being subscribed, but stay subscribed to the contact.
UnsubscribeContact(JID, oneshot::Sender<Result<(), Reason>>), UnsubscribeContact(JID, oneshot::Sender<Result<(), WriteError>>),
/// remove subscriptions to and from contact, but keep in roster. /// remove subscriptions to and from contact, but keep in roster.
UnfriendContact(JID, oneshot::Sender<Result<(), Reason>>), UnfriendContact(JID, oneshot::Sender<Result<(), WriteError>>),
/// remove a contact from the contact list. will remove subscriptions if not already done then delete contact from roster. /// remove a contact from the contact list. will remove subscriptions if not already done then delete contact from roster.
DeleteContact(JID, oneshot::Sender<Result<(), Reason>>), DeleteContact(JID, oneshot::Sender<Result<(), RosterError>>),
/// update contact. contact details will be overwritten with the contents of the contactupdate struct. /// update contact. contact details will be overwritten with the contents of the contactupdate struct.
UpdateContact(JID, ContactUpdate, oneshot::Sender<Result<(), Reason>>), UpdateContact(JID, ContactUpdate, oneshot::Sender<Result<(), RosterError>>),
/// set online status. if disconnected, will be cached so when client connects, will be sent as the initial presence. /// set online status. if disconnected, will be cached so when client connects, will be sent as the initial presence.
SetStatus(Online, oneshot::Sender<Result<(), StatusError>>), SetStatus(Online, oneshot::Sender<Result<(), StatusError>>),
/// send presence stanza
SendPresence(
Option<JID>,
Presence,
oneshot::Sender<Result<(), WriteError>>,
),
/// send a directed presence (usually to a non-contact). /// send a directed presence (usually to a non-contact).
// TODO: should probably make it so people can add non-contact auto presence sharing in the client (most likely through setting an internal setting) // TODO: should probably make it so people can add non-contact auto presence sharing in the client (most likely through setting an internal setting)
/// send a message to a jid (any kind of jid that can receive a message, e.g. a user or a /// send a message to a jid (any kind of jid that can receive a message, e.g. a user or a
/// chatroom). if disconnected, will be cached so when client connects, message will be sent. /// chatroom). if disconnected, will be cached so when client connects, message will be sent.
SendMessage(JID, Body, oneshot::Sender<Result<(), Reason>>), SendMessage(JID, Body, oneshot::Sender<Result<(), WriteError>>),
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]

View File

@ -12,9 +12,13 @@ use tracing::info;
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
tracing_subscriber::fmt::init(); tracing_subscriber::fmt::init();
let db = SqlitePool::connect("./luz.db").await.unwrap(); let (luz, mut recv) = LuzHandle::new(
let (luz, mut recv) = "test@blos.sm".try_into().unwrap(),
LuzHandle::new("test@blos.sm".try_into().unwrap(), "slayed".to_string(), db); "slayed".to_string(),
"./luz.db",
)
.await
.unwrap();
tokio::spawn(async move { tokio::spawn(async move {
while let Some(msg) = recv.recv().await { while let Some(msg) = recv.recv().await {

View File

@ -91,3 +91,31 @@ impl From<Online> for stanza::client::presence::Presence {
} }
} }
} }
impl From<Offline> for stanza::client::presence::Presence {
fn from(value: Offline) -> Self {
Self {
from: None,
id: None,
to: None,
r#type: Some(stanza::client::presence::PresenceType::Unavailable),
lang: None,
show: None,
status: value.status.map(|status| stanza::client::presence::Status {
lang: None,
status: String1024(status),
}),
priority: None,
errors: Vec::new(),
}
}
}
impl From<Presence> for stanza::client::presence::Presence {
fn from(value: Presence) -> Self {
match value {
Presence::Online(online) => online.into(),
Presence::Offline(offline) => offline.into(),
}
}
}

View File

@ -6,3 +6,4 @@ edition = "2021"
[dependencies] [dependencies]
peanuts = { version = "0.1.0", path = "../../peanuts" } peanuts = { version = "0.1.0", path = "../../peanuts" }
jid = { version = "0.1.0", path = "../jid" } jid = { version = "0.1.0", path = "../jid" }
thiserror = "2.0.11"

View File

@ -13,8 +13,8 @@ pub struct Bind {
impl FromElement for Bind { impl FromElement for Bind {
fn from_element(mut element: peanuts::Element) -> peanuts::element::DeserializeResult<Self> { fn from_element(mut element: peanuts::Element) -> peanuts::element::DeserializeResult<Self> {
element.check_name("bind"); element.check_name("bind")?;
element.check_name(XMLNS); element.check_name(XMLNS)?;
let r#type = element.pop_child_opt()?; let r#type = element.pop_child_opt()?;
@ -61,8 +61,8 @@ pub struct FullJidType(pub JID);
impl FromElement for FullJidType { impl FromElement for FullJidType {
fn from_element(mut element: peanuts::Element) -> peanuts::element::DeserializeResult<Self> { fn from_element(mut element: peanuts::Element) -> peanuts::element::DeserializeResult<Self> {
element.check_name("jid"); element.check_name("jid")?;
element.check_namespace(XMLNS); element.check_namespace(XMLNS)?;
let jid = element.pop_value()?; let jid = element.pop_value()?;

View File

@ -1,14 +1,16 @@
use std::fmt::Display;
use std::str::FromStr; use std::str::FromStr;
use peanuts::element::{FromElement, IntoElement}; use peanuts::element::{FromElement, IntoElement};
use peanuts::{DeserializeError, Element}; use peanuts::{DeserializeError, Element};
use thiserror::Error;
use crate::stanza_error::Error as StanzaError; use crate::stanza_error::Error as StanzaError;
use crate::stanza_error::Text; use crate::stanza_error::Text;
use super::XMLNS; use super::XMLNS;
#[derive(Clone, Debug)] #[derive(Clone, Debug, Error)]
pub struct Error { pub struct Error {
by: Option<String>, by: Option<String>,
r#type: ErrorType, r#type: ErrorType,
@ -17,6 +19,22 @@ pub struct Error {
text: Option<Text>, text: Option<Text>,
} }
impl Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}, {}", self.r#type, self.error)?;
if let Some(text) = &self.text {
if let Some(text) = &text.text {
write!(f, ": {}", text)?;
}
}
if let Some(by) = &self.by {
write!(f, " (error returned by `{}`)", by)?;
}
Ok(())
}
}
impl FromElement for Error { impl FromElement for Error {
fn from_element(mut element: peanuts::Element) -> peanuts::element::DeserializeResult<Self> { fn from_element(mut element: peanuts::Element) -> peanuts::element::DeserializeResult<Self> {
element.check_name("error")?; element.check_name("error")?;
@ -55,6 +73,18 @@ pub enum ErrorType {
Wait, Wait,
} }
impl Display for ErrorType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ErrorType::Auth => f.write_str("auth"),
ErrorType::Cancel => f.write_str("cancel"),
ErrorType::Continue => f.write_str("continue"),
ErrorType::Modify => f.write_str("modify"),
ErrorType::Wait => f.write_str("wait"),
}
}
}
impl FromStr for ErrorType { impl FromStr for ErrorType {
type Err = DeserializeError; type Err = DeserializeError;
@ -69,15 +99,3 @@ impl FromStr for ErrorType {
} }
} }
} }
impl ToString for ErrorType {
fn to_string(&self) -> String {
match self {
ErrorType::Auth => "auth".to_string(),
ErrorType::Cancel => "cancel".to_string(),
ErrorType::Continue => "continue".to_string(),
ErrorType::Modify => "modify".to_string(),
ErrorType::Wait => "wait".to_string(),
}
}
}

View File

@ -15,7 +15,7 @@ use crate::{
use super::XMLNS; use super::XMLNS;
#[derive(Debug)] #[derive(Debug, Clone)]
pub struct Iq { pub struct Iq {
pub from: Option<JID>, pub from: Option<JID>,
pub id: String, pub id: String,

View File

@ -8,7 +8,7 @@ use peanuts::{
use super::XMLNS; use super::XMLNS;
#[derive(Debug)] #[derive(Debug, Clone)]
pub struct Message { pub struct Message {
pub from: Option<JID>, pub from: Option<JID>,
pub id: Option<String>, pub id: Option<String>,

View File

@ -1,7 +1,7 @@
use iq::Iq; use iq::Iq;
use message::Message; use message::Message;
use peanuts::{ use peanuts::{
element::{Content, ContentBuilder, FromContent, FromElement, IntoContent, IntoElement}, element::{Content, ContentBuilder, FromContent, FromElement, IntoContent},
DeserializeError, DeserializeError,
}; };
use presence::Presence; use presence::Presence;
@ -16,7 +16,7 @@ pub mod presence;
pub const XMLNS: &str = "jabber:client"; pub const XMLNS: &str = "jabber:client";
/// TODO: End tag /// TODO: End tag
#[derive(Debug)] #[derive(Debug, Clone)]
pub enum Stanza { pub enum Stanza {
Message(Message), Message(Message),
Presence(Presence), Presence(Presence),

View File

@ -8,7 +8,7 @@ use peanuts::{
use super::{error::Error, XMLNS}; use super::{error::Error, XMLNS};
#[derive(Debug)] #[derive(Debug, Clone)]
pub struct Presence { pub struct Presence {
pub from: Option<JID>, pub from: Option<JID>,
pub id: Option<String>, pub id: Option<String>,

View File

@ -1,9 +1,10 @@
use std::ops::Deref; use std::{fmt::Display, ops::Deref};
use peanuts::{ use peanuts::{
element::{FromElement, IntoElement}, element::{FromElement, IntoElement},
DeserializeError, Element, DeserializeError, Element,
}; };
use thiserror::Error;
pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-sasl"; pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
@ -168,12 +169,48 @@ impl IntoElement for Response {
} }
} }
#[derive(Debug)] #[derive(Error, Debug, Clone)]
pub struct Failure { pub struct Failure {
r#type: Option<FailureType>, r#type: Option<FailureType>,
text: Option<Text>, text: Option<Text>,
} }
impl Display for Failure {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut had_type = false;
let mut had_text = false;
if let Some(r#type) = &self.r#type {
had_type = true;
match r#type {
FailureType::Aborted => f.write_str("aborted"),
FailureType::AccountDisabled => f.write_str("account disabled"),
FailureType::CredentialsExpired => f.write_str("credentials expired"),
FailureType::EncryptionRequired => f.write_str("encryption required"),
FailureType::IncorrectEncoding => f.write_str("incorrect encoding"),
FailureType::InvalidAuthzid => f.write_str("invalid authzid"),
FailureType::InvalidMechanism => f.write_str("invalid mechanism"),
FailureType::MalformedRequest => f.write_str("malformed request"),
FailureType::MechanismTooWeak => f.write_str("mechanism too weak"),
FailureType::NotAuthorized => f.write_str("not authorized"),
FailureType::TemporaryAuthFailure => f.write_str("temporary auth failure"),
}?;
}
if let Some(text) = &self.text {
if let Some(text) = &text.text {
if had_type {
f.write_str(": ")?;
}
f.write_str(text)?;
had_text = true;
}
}
if !had_type && !had_text {
f.write_str("failure")?;
}
Ok(())
}
}
impl FromElement for Failure { impl FromElement for Failure {
fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> { fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> {
element.check_name("failure")?; element.check_name("failure")?;
@ -186,18 +223,29 @@ impl FromElement for Failure {
} }
} }
#[derive(Debug)] #[derive(Error, Debug, Clone)]
pub enum FailureType { pub enum FailureType {
#[error("aborted")]
Aborted, Aborted,
#[error("account disabled")]
AccountDisabled, AccountDisabled,
#[error("credentials expired")]
CredentialsExpired, CredentialsExpired,
#[error("encryption required")]
EncryptionRequired, EncryptionRequired,
#[error("incorrect encoding")]
IncorrectEncoding, IncorrectEncoding,
#[error("invalid authzid")]
InvalidAuthzid, InvalidAuthzid,
#[error("invalid mechanism")]
InvalidMechanism, InvalidMechanism,
#[error("malformed request")]
MalformedRequest, MalformedRequest,
#[error("mechanism too weak")]
MechanismTooWeak, MechanismTooWeak,
#[error("not authorized")]
NotAuthorized, NotAuthorized,
#[error("temporary auth failure")]
TemporaryAuthFailure, TemporaryAuthFailure,
} }
@ -220,8 +268,9 @@ impl FromElement for FailureType {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
pub struct Text { pub struct Text {
#[allow(dead_code)]
lang: Option<String>, lang: Option<String>,
text: Option<String>, text: Option<String>,
} }

View File

@ -4,32 +4,55 @@ use peanuts::{
element::{FromElement, IntoElement}, element::{FromElement, IntoElement},
Element, XML_NS, Element, XML_NS,
}; };
use thiserror::Error;
pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-stanzas"; pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-stanzas";
#[derive(Clone, Debug)] #[derive(Error, Clone, Debug)]
pub enum Error { pub enum Error {
#[error("bad request")]
BadRequest, BadRequest,
#[error("conflict")]
Conflict, Conflict,
#[error("feature not implemented")]
FeatureNotImplemented, FeatureNotImplemented,
#[error("forbidden")]
Forbidden, Forbidden,
#[error("gone: {0:?}")]
Gone(Option<String>), Gone(Option<String>),
#[error("internal server error")]
InternalServerError, InternalServerError,
#[error("item not found")]
ItemNotFound, ItemNotFound,
JidMalformed, #[error("JID malformed")]
JIDMalformed,
#[error("not acceptable")]
NotAcceptable, NotAcceptable,
#[error("not allowed")]
NotAllowed, NotAllowed,
#[error("not authorized")]
NotAuthorized, NotAuthorized,
#[error("policy violation")]
PolicyViolation, PolicyViolation,
#[error("recipient unavailable")]
RecipientUnavailable, RecipientUnavailable,
#[error("redirect: {0:?}")]
Redirect(Option<String>), Redirect(Option<String>),
#[error("registration required")]
RegistrationRequired, RegistrationRequired,
#[error("remote server not found")]
RemoteServerNotFound, RemoteServerNotFound,
#[error("remote server timeout")]
RemoteServerTimeout, RemoteServerTimeout,
#[error("resource constraint")]
ResourceConstraint, ResourceConstraint,
#[error("service unavailable")]
ServiceUnavailable, ServiceUnavailable,
#[error("subscription required")]
SubscriptionRequired, SubscriptionRequired,
#[error("undefined condition")]
UndefinedCondition, UndefinedCondition,
#[error("unexpected request")]
UnexpectedRequest, UnexpectedRequest,
} }
@ -44,7 +67,7 @@ impl FromElement for Error {
(Some(XMLNS), "gone") => return Ok(Error::Gone(element.pop_value_opt()?)), (Some(XMLNS), "gone") => return Ok(Error::Gone(element.pop_value_opt()?)),
(Some(XMLNS), "internal-server-error") => error = Error::InternalServerError, (Some(XMLNS), "internal-server-error") => error = Error::InternalServerError,
(Some(XMLNS), "item-not-found") => error = Error::ItemNotFound, (Some(XMLNS), "item-not-found") => error = Error::ItemNotFound,
(Some(XMLNS), "jid-malformed") => error = Error::JidMalformed, (Some(XMLNS), "jid-malformed") => error = Error::JIDMalformed,
(Some(XMLNS), "not-acceptable") => error = Error::NotAcceptable, (Some(XMLNS), "not-acceptable") => error = Error::NotAcceptable,
(Some(XMLNS), "not-allowed") => error = Error::NotAllowed, (Some(XMLNS), "not-allowed") => error = Error::NotAllowed,
(Some(XMLNS), "not-authorized") => error = Error::NotAuthorized, (Some(XMLNS), "not-authorized") => error = Error::NotAuthorized,
@ -78,7 +101,7 @@ impl IntoElement for Error {
Error::Gone(r) => Element::builder("gone", Some(XMLNS)).push_text_opt(r.clone()), Error::Gone(r) => Element::builder("gone", Some(XMLNS)).push_text_opt(r.clone()),
Error::InternalServerError => Element::builder("internal-server-error", Some(XMLNS)), Error::InternalServerError => Element::builder("internal-server-error", Some(XMLNS)),
Error::ItemNotFound => Element::builder("item-not-found", Some(XMLNS)), Error::ItemNotFound => Element::builder("item-not-found", Some(XMLNS)),
Error::JidMalformed => Element::builder("jid-malformed", Some(XMLNS)), Error::JIDMalformed => Element::builder("jid-malformed", Some(XMLNS)),
Error::NotAcceptable => Element::builder("not-acceptable", Some(XMLNS)), Error::NotAcceptable => Element::builder("not-acceptable", Some(XMLNS)),
Error::NotAllowed => Element::builder("not-allowed", Some(XMLNS)), Error::NotAllowed => Element::builder("not-allowed", Some(XMLNS)),
Error::NotAuthorized => Element::builder("not-authorized", Some(XMLNS)), Error::NotAuthorized => Element::builder("not-authorized", Some(XMLNS)),
@ -102,7 +125,7 @@ impl IntoElement for Error {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Text { pub struct Text {
lang: Option<String>, lang: Option<String>,
text: Option<String>, pub text: Option<String>,
} }
impl FromElement for Text { impl FromElement for Text {

View File

@ -1,7 +1,5 @@
use std::collections::{HashMap, HashSet};
use peanuts::{ use peanuts::{
element::{Content, FromElement, IntoElement, Name, NamespaceDeclaration}, element::{FromElement, IntoElement},
Element, Element,
}; };

View File

@ -1,15 +1,16 @@
use std::collections::{HashMap, HashSet}; use std::fmt::Display;
use jid::JID; use jid::JID;
use peanuts::element::{Content, ElementBuilder, FromElement, IntoElement, NamespaceDeclaration}; use peanuts::element::{ElementBuilder, FromElement, IntoElement};
use peanuts::{element::Name, Element}; use peanuts::Element;
use thiserror::Error;
use crate::bind; use crate::bind;
use super::client;
use super::sasl::{self, Mechanisms}; use super::sasl::{self, Mechanisms};
use super::starttls::{self, StartTls}; use super::starttls::{self, StartTls};
use super::stream_error::{Error as StreamError, Text}; use super::stream_error::{Error as StreamError, Text};
use super::{client, stream_error};
pub const XMLNS: &str = "http://etherx.jabber.org/streams"; pub const XMLNS: &str = "http://etherx.jabber.org/streams";
@ -178,12 +179,24 @@ impl FromElement for Feature {
} }
} }
#[derive(Debug)] #[derive(Error, Debug, Clone)]
pub struct Error { pub struct Error {
error: StreamError, error: StreamError,
text: Option<Text>, text: Option<Text>,
} }
impl Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.error)?;
if let Some(text) = &self.text {
if let Some(text) = &text.text {
write!(f, ": {}", text)?;
}
}
Ok(())
}
}
impl FromElement for Error { impl FromElement for Error {
fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> { fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> {
element.check_name("error")?; element.check_name("error")?;

View File

@ -2,35 +2,61 @@ use peanuts::{
element::{FromElement, IntoElement}, element::{FromElement, IntoElement},
DeserializeError, Element, XML_NS, DeserializeError, Element, XML_NS,
}; };
use thiserror::Error;
pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-streams"; pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-streams";
#[derive(Clone, Debug)] #[derive(Error, Clone, Debug)]
pub enum Error { pub enum Error {
#[error("bad format")]
BadFormat, BadFormat,
#[error("bad namespace prefix")]
BadNamespacePrefix, BadNamespacePrefix,
#[error("conflict")]
Conflict, Conflict,
#[error("connection timeout")]
ConnectionTimeout, ConnectionTimeout,
#[error("host gone")]
HostGone, HostGone,
#[error("host unknown")]
HostUnknown, HostUnknown,
#[error("improper addressing")]
ImproperAddressing, ImproperAddressing,
#[error("internal server error")]
InternalServerError, InternalServerError,
#[error("invalid from")]
InvalidFrom, InvalidFrom,
#[error("invalid id")]
InvalidId, InvalidId,
#[error("invalid namespace")]
InvalidNamespace, InvalidNamespace,
#[error("invalid xml")]
InvalidXml, InvalidXml,
#[error("not authorized")]
NotAuthorized, NotAuthorized,
#[error("not well formed")]
NotWellFormed, NotWellFormed,
#[error("policy violation")]
PolicyViolation, PolicyViolation,
#[error("remote connection failed")]
RemoteConnectionFailed, RemoteConnectionFailed,
#[error("reset")]
Reset, Reset,
#[error("resource constraint")]
ResourceConstraint, ResourceConstraint,
#[error("restricted xml")]
RestrictedXml, RestrictedXml,
#[error("see other host: {0:?}")]
SeeOtherHost(Option<String>), SeeOtherHost(Option<String>),
#[error("system shutdown")]
SystemShutdown, SystemShutdown,
#[error("undefined condition")]
UndefinedCondition, UndefinedCondition,
#[error("unsupported encoding")]
UnsupportedEncoding, UnsupportedEncoding,
#[error("unsupported stanza type")]
UnsupportedStanzaType, UnsupportedStanzaType,
#[error("unsupported version")]
UnsupportedVersion, UnsupportedVersion,
} }
@ -112,7 +138,7 @@ impl IntoElement for Error {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Text { pub struct Text {
text: Option<String>, pub text: Option<String>,
lang: Option<String>, lang: Option<String>,
} }