Compare commits

..

No commits in common. "861db1197db6d3f9ee0ce3f6e3434666da05a95c" and "7dc1b1f35d02485b75c0373701dd09cec874ac4d" have entirely different histories.

30 changed files with 403 additions and 673 deletions

View File

@ -30,7 +30,6 @@ futures = "0.3.31"
take_mut = "0.2.2" take_mut = "0.2.2"
pin-project-lite = "0.2.15" pin-project-lite = "0.2.15"
pin-project = "1.1.7" pin-project = "1.1.7"
thiserror = "2.0.11"
[dev-dependencies] [dev-dependencies]
test-log = { version = "0.2", features = ["trace"] } test-log = { version = "0.2", features = ["trace"] }

View File

@ -1,17 +1,29 @@
use std::{
borrow::Borrow,
future::Future,
pin::pin,
sync::Arc,
task::{ready, Poll},
};
use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt};
use jid::ParseError;
use rsasl::config::SASLConfig; use rsasl::config::SASLConfig;
use stanza::{ use stanza::{
client::Stanza,
sasl::Mechanisms, sasl::Mechanisms,
stream::{Feature, Features}, stream::{Feature, Features},
}; };
use tokio::sync::Mutex;
use crate::{ use crate::{
connection::{Tls, Unencrypted}, connection::{Tls, Unencrypted},
jabber_stream::bound_stream::BoundJabberStream, jabber_stream::bound_stream::{BoundJabberReader, BoundJabberStream},
Connection, Error, JabberStream, Result, JID, Connection, Error, JabberStream, Result, JID,
}; };
pub async fn connect_and_login( pub async fn connect_and_login(
jid: &mut JID, mut jid: &mut JID,
password: impl AsRef<str>, password: impl AsRef<str>,
server: &mut String, server: &mut String,
) -> Result<BoundJabberStream<Tls>> { ) -> Result<BoundJabberStream<Tls>> {
@ -19,8 +31,7 @@ pub async fn connect_and_login(
None, None,
jid.localpart.clone().ok_or(Error::NoLocalpart)?, jid.localpart.clone().ok_or(Error::NoLocalpart)?,
password.as_ref().to_string(), password.as_ref().to_string(),
) )?;
.map_err(|e| Error::SASL(e.into()))?;
let mut conn_state = Connecting::start(&server).await?; let mut conn_state = Connecting::start(&server).await?;
loop { loop {
match conn_state { match conn_state {
@ -109,8 +120,9 @@ pub enum InsecureConnecting {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::time::Duration; use std::{sync::Arc, time::Duration};
use futures::{SinkExt, StreamExt};
use jid::JID; use jid::JID;
use stanza::{ use stanza::{
client::{ client::{
@ -120,7 +132,7 @@ mod tests {
xep_0199::Ping, xep_0199::Ping,
}; };
use test_log::test; use test_log::test;
use tokio::time::sleep; use tokio::{sync::Mutex, time::sleep};
use tracing::info; use tracing::info;
use super::connect_and_login; use super::connect_and_login;
@ -128,7 +140,7 @@ mod tests {
#[test(tokio::test)] #[test(tokio::test)]
async fn login() { async fn login() {
let mut jid: JID = "test@blos.sm".try_into().unwrap(); let mut jid: JID = "test@blos.sm".try_into().unwrap();
let _client = connect_and_login(&mut jid, "slayed", &mut "blos.sm".to_string()) let client = connect_and_login(&mut jid, "slayed", &mut "blos.sm".to_string())
.await .await
.unwrap(); .unwrap();
sleep(Duration::from_secs(5)).await sleep(Duration::from_secs(5)).await

View File

@ -1,7 +1,9 @@
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::str; use std::str;
use std::str::FromStr; use std::str::FromStr;
use std::sync::Arc;
use rsasl::config::SASLConfig;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_native_tls::native_tls::TlsConnector; use tokio_native_tls::native_tls::TlsConnector;
// TODO: use rustls // TODO: use rustls

View File

@ -1,58 +1,87 @@
use std::str::Utf8Error; use std::str::Utf8Error;
use std::sync::Arc;
use jid::ParseError; use jid::ParseError;
use rsasl::mechname::MechanismNameError; use rsasl::mechname::MechanismNameError;
use stanza::client::error::Error as ClientError; use stanza::client::error::Error as ClientError;
use stanza::sasl::Failure; use stanza::sasl::Failure;
use stanza::stream::Error as StreamError; use stanza::stream::Error as StreamError;
use thiserror::Error; use tokio::task::JoinError;
#[derive(Error, Debug, Clone)] #[derive(Debug)]
pub enum Error { pub enum Error {
#[error("connection")]
Connection, Connection,
#[error("utf8 decode: {0}")] Utf8Decode,
Utf8Decode(#[from] Utf8Error),
#[error("negotiation")]
Negotiation, Negotiation,
#[error("tls required")]
TlsRequired, TlsRequired,
#[error("already connected with tls")]
AlreadyTls, AlreadyTls,
// TODO: specify unsupported feature
#[error("unsupported feature")]
Unsupported, Unsupported,
#[error("jid missing localpart")]
NoLocalpart, NoLocalpart,
#[error("received unexpected element: {0:?}")] AlreadyConnecting,
StreamClosed,
UnexpectedElement(peanuts::Element), UnexpectedElement(peanuts::Element),
#[error("xml error: {0}")] XML(peanuts::Error),
XML(#[from] peanuts::Error), Deserialization(peanuts::DeserializeError),
#[error("sasl error: {0}")] SASL(SASLError),
SASL(#[from] SASLError), JID(ParseError),
#[error("jid error: {0}")] Authentication(Failure),
JID(#[from] ParseError), ClientError(ClientError),
#[error("client stanza error: {0}")] StreamError(StreamError),
ClientError(#[from] ClientError),
#[error("stream error: {0}")]
StreamError(#[from] StreamError),
#[error("error missing")]
MissingError, MissingError,
Disconnected,
Connecting,
JoinError(JoinError),
} }
#[derive(Error, Debug, Clone)] #[derive(Debug)]
pub enum SASLError { pub enum SASLError {
#[error("sasl error: {0}")] SASL(rsasl::prelude::SASLError),
SASL(Arc<rsasl::prelude::SASLError>), MechanismName(MechanismNameError),
#[error("mechanism error: {0}")]
MechanismName(#[from] MechanismNameError),
#[error("authentication failure: {0}")]
Authentication(#[from] Failure),
} }
impl From<rsasl::prelude::SASLError> for SASLError { impl From<rsasl::prelude::SASLError> for Error {
fn from(e: rsasl::prelude::SASLError) -> Self { fn from(e: rsasl::prelude::SASLError) -> Self {
Self::SASL(Arc::new(e)) Self::SASL(SASLError::SASL(e))
}
}
impl From<JoinError> for Error {
fn from(e: JoinError) -> Self {
Self::JoinError(e)
}
}
impl From<peanuts::DeserializeError> for Error {
fn from(e: peanuts::DeserializeError) -> Self {
Error::Deserialization(e)
}
}
impl From<MechanismNameError> for Error {
fn from(e: MechanismNameError) -> Self {
Self::SASL(SASLError::MechanismName(e))
}
}
impl From<SASLError> for Error {
fn from(e: SASLError) -> Self {
Self::SASL(e)
}
}
impl From<Utf8Error> for Error {
fn from(_e: Utf8Error) -> Self {
Self::Utf8Decode
}
}
impl From<peanuts::Error> for Error {
fn from(e: peanuts::Error) -> Self {
Self::XML(e)
}
}
impl From<ParseError> for Error {
fn from(e: ParseError) -> Self {
Self::JID(e)
} }
} }

View File

@ -1,8 +1,10 @@
use std::pin::pin;
use std::str::{self, FromStr}; use std::str::{self, FromStr};
use std::sync::Arc; use std::sync::Arc;
use futures::{sink, stream, StreamExt};
use jid::JID; use jid::JID;
use peanuts::element::IntoElement; use peanuts::element::{FromContent, IntoElement};
use peanuts::{Reader, Writer}; use peanuts::{Reader, Writer};
use rsasl::prelude::{Mechname, SASLClient, SASLConfig}; use rsasl::prelude::{Mechname, SASLClient, SASLConfig};
use stanza::bind::{Bind, BindType, FullJidType, ResourceType}; use stanza::bind::{Bind, BindType, FullJidType, ResourceType};
@ -133,16 +135,13 @@ where
let sasl = SASLClient::new(sasl_config); let sasl = SASLClient::new(sasl_config);
let mut offered_mechs: Vec<&Mechname> = Vec::new(); let mut offered_mechs: Vec<&Mechname> = Vec::new();
for mechanism in &mechanisms.mechanisms { for mechanism in &mechanisms.mechanisms {
offered_mechs offered_mechs.push(Mechname::parse(mechanism.as_bytes())?)
.push(Mechname::parse(mechanism.as_bytes()).map_err(|e| Error::SASL(e.into()))?)
} }
debug!("{:?}", offered_mechs); debug!("{:?}", offered_mechs);
let mut session = sasl let mut session = sasl.start_suggested(&offered_mechs)?;
.start_suggested(&offered_mechs)
.map_err(|e| Error::SASL(e.into()))?;
let selected_mechanism = session.get_mechname().as_str().to_owned(); let selected_mechanism = session.get_mechname().as_str().to_owned();
debug!("selected mech: {:?}", selected_mechanism); debug!("selected mech: {:?}", selected_mechanism);
let mut data: Option<Vec<u8>>; let mut data: Option<Vec<u8>> = None;
if !session.are_we_first() { if !session.are_we_first() {
// if not first mention the mechanism then get challenge data // if not first mention the mechanism then get challenge data
@ -177,7 +176,7 @@ where
ServerResponse::Success(success) => { ServerResponse::Success(success) => {
data = success.clone().map(|success| success.as_bytes().to_vec()) data = success.clone().map(|success| success.as_bytes().to_vec())
} }
ServerResponse::Failure(failure) => return Err(Error::SASL(failure.into())), ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)),
} }
debug!("we went first"); debug!("we went first");
} }
@ -208,7 +207,7 @@ where
ServerResponse::Success(success) => { ServerResponse::Success(success) => {
data = success.clone().map(|success| success.as_bytes().to_vec()) data = success.clone().map(|success| success.as_bytes().to_vec())
} }
ServerResponse::Failure(failure) => return Err(Error::SASL(failure.into())), ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)),
} }
} }
} }
@ -410,7 +409,13 @@ impl std::fmt::Debug for JabberStream<Unencrypted> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::time::Duration;
use super::*;
use crate::connection::Connection;
use futures::sink;
use test_log::test; use test_log::test;
use tokio::time::sleep;
#[test(tokio::test)] #[test(tokio::test)]
async fn start_stream() { async fn start_stream() {

View File

@ -1,6 +1,9 @@
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
use tokio::io::{AsyncRead, AsyncWrite}; use peanuts::{Reader, Writer};
use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
use crate::Error;
use super::{JabberReader, JabberStream, JabberWriter}; use super::{JabberReader, JabberStream, JabberWriter};

View File

@ -8,6 +8,7 @@ pub mod error;
pub mod jabber_stream; pub mod jabber_stream;
pub use connection::Connection; pub use connection::Connection;
use connection::Tls;
pub use error::Error; pub use error::Error;
pub use jabber_stream::JabberStream; pub use jabber_stream::JabberStream;
pub use jid::JID; pub use jid::JID;

View File

@ -3,8 +3,5 @@ name = "jid"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
[features]
sqlx = ["dep:sqlx"]
[dependencies] [dependencies]
sqlx = { version = "0.8.3", features = ["sqlite"], optional = true } sqlx = { version = "0.8.3", features = ["sqlite"] }

View File

@ -1,9 +1,8 @@
use std::{borrow::Cow, error::Error, fmt::Display, str::FromStr}; use std::{error::Error, fmt::Display, str::FromStr};
#[cfg(feature = "sqlx")]
use sqlx::Sqlite; use sqlx::Sqlite;
#[derive(PartialEq, Debug, Clone, Eq, Hash)] #[derive(PartialEq, Debug, Clone, sqlx::Type, sqlx::Encode, Eq, Hash)]
pub struct JID { pub struct JID {
// TODO: validate localpart (length, char] // TODO: validate localpart (length, char]
pub localpart: Option<String>, pub localpart: Option<String>,
@ -11,36 +10,13 @@ pub struct JID {
pub resourcepart: Option<String>, pub resourcepart: Option<String>,
} }
impl<'a> Into<Cow<'a, str>> for &'a JID { // TODO: feature gate
fn into(self) -> Cow<'a, str> {
let a = self.to_string();
Cow::Owned(a)
}
}
impl Display for JID {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if let Some(localpart) = &self.localpart {
f.write_str(localpart)?;
f.write_str("@")?;
}
f.write_str(&self.domainpart)?;
if let Some(resourcepart) = &self.resourcepart {
f.write_str("/")?;
f.write_str(resourcepart)?;
}
Ok(())
}
}
#[cfg(feature = "sqlx")]
impl sqlx::Type<Sqlite> for JID { impl sqlx::Type<Sqlite> for JID {
fn type_info() -> <Sqlite as sqlx::Database>::TypeInfo { fn type_info() -> <Sqlite as sqlx::Database>::TypeInfo {
<&str as sqlx::Type<Sqlite>>::type_info() <&str as sqlx::Type<Sqlite>>::type_info()
} }
} }
#[cfg(feature = "sqlx")]
impl sqlx::Decode<'_, Sqlite> for JID { impl sqlx::Decode<'_, Sqlite> for JID {
fn decode( fn decode(
value: <Sqlite as sqlx::Database>::ValueRef<'_>, value: <Sqlite as sqlx::Database>::ValueRef<'_>,
@ -51,7 +27,6 @@ impl sqlx::Decode<'_, Sqlite> for JID {
} }
} }
#[cfg(feature = "sqlx")]
impl sqlx::Encode<'_, Sqlite> for JID { impl sqlx::Encode<'_, Sqlite> for JID {
fn encode_by_ref( fn encode_by_ref(
&self, &self,
@ -62,24 +37,12 @@ impl sqlx::Encode<'_, Sqlite> for JID {
} }
} }
#[derive(Debug, Clone)]
pub enum JIDError { pub enum JIDError {
NoResourcePart, NoResourcePart,
ParseError(ParseError), ParseError(ParseError),
} }
impl Display for JIDError { #[derive(Debug)]
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
JIDError::NoResourcePart => f.write_str("resourcepart missing"),
JIDError::ParseError(parse_error) => parse_error.fmt(f),
}
}
}
impl Error for JIDError {}
#[derive(Debug, Clone)]
pub enum ParseError { pub enum ParseError {
Empty, Empty,
Malformed(String), Malformed(String),
@ -184,6 +147,21 @@ impl TryFrom<&str> for JID {
} }
} }
impl std::fmt::Display for JID {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}{}{}",
self.localpart.clone().map(|l| l + "@").unwrap_or_default(),
self.domainpart,
self.resourcepart
.clone()
.map(|r| "/".to_owned() + &r)
.unwrap_or_default()
)
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

1
luz/.gitignore vendored
View File

@ -1,2 +1 @@
luz.db luz.db
.sqlx/

View File

@ -7,7 +7,7 @@ edition = "2021"
futures = "0.3.31" futures = "0.3.31"
jabber = { version = "0.1.0", path = "../jabber" } jabber = { version = "0.1.0", path = "../jabber" }
peanuts = { version = "0.1.0", path = "../../peanuts" } peanuts = { version = "0.1.0", path = "../../peanuts" }
jid = { version = "0.1.0", path = "../jid", features = ["sqlx"] } jid = { version = "0.1.0", path = "../jid" }
sqlx = { version = "0.8.3", features = ["sqlite", "runtime-tokio", "uuid"] } sqlx = { version = "0.8.3", features = ["sqlite", "runtime-tokio", "uuid"] }
stanza = { version = "0.1.0", path = "../stanza" } stanza = { version = "0.1.0", path = "../stanza" }
tokio = "1.42.0" tokio = "1.42.0"
@ -16,4 +16,3 @@ tokio-util = "0.7.13"
tracing = "0.1.41" tracing = "0.1.41"
tracing-subscriber = "0.3.19" tracing-subscriber = "0.3.19"
uuid = { version = "1.13.1", features = ["v4"] } uuid = { version = "1.13.1", features = ["v4"] }
thiserror = "2.0.11"

View File

@ -20,7 +20,7 @@ use write::{WriteControl, WriteControlHandle, WriteHandle, WriteMessage};
use crate::{ use crate::{
db::Db, db::Db,
error::{Error, ReadError, WriteError}, error::{Error, Reason},
UpdateMessage, UpdateMessage,
}; };
@ -36,7 +36,7 @@ pub struct Supervisor {
tokio::task::JoinSet<()>, tokio::task::JoinSet<()>,
mpsc::Sender<SupervisorCommand>, mpsc::Sender<SupervisorCommand>,
WriteHandle, WriteHandle,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>, Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
)>, )>,
sender: mpsc::Sender<UpdateMessage>, sender: mpsc::Sender<UpdateMessage>,
writer_handle: WriteControlHandle, writer_handle: WriteControlHandle,
@ -62,7 +62,7 @@ pub enum State {
tokio::task::JoinSet<()>, tokio::task::JoinSet<()>,
mpsc::Sender<SupervisorCommand>, mpsc::Sender<SupervisorCommand>,
WriteHandle, WriteHandle,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>, Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
), ),
), ),
} }
@ -77,7 +77,7 @@ impl Supervisor {
JoinSet<()>, JoinSet<()>,
mpsc::Sender<SupervisorCommand>, mpsc::Sender<SupervisorCommand>,
WriteHandle, WriteHandle,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>, Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
)>, )>,
sender: mpsc::Sender<UpdateMessage>, sender: mpsc::Sender<UpdateMessage>,
writer_handle: WriteControlHandle, writer_handle: WriteControlHandle,
@ -180,7 +180,7 @@ impl Supervisor {
// if reconnection failure, respond to all current write messages with lost connection error. the received processes should complete themselves. // if reconnection failure, respond to all current write messages with lost connection error. the received processes should complete themselves.
write_state.close(); write_state.close();
while let Some(msg) = write_state.recv().await { while let Some(msg) = write_state.recv().await {
let _ = msg.respond_to.send(Err(WriteError::LostConnection)); let _ = msg.respond_to.send(Err(Reason::LostConnection));
} }
// TODO: is this the correct error? // TODO: is this the correct error?
let _ = self.sender.send(UpdateMessage::Error(Error::LostConnection)).await; let _ = self.sender.send(UpdateMessage::Error(Error::LostConnection)).await;
@ -227,9 +227,9 @@ impl Supervisor {
Err(e) => { Err(e) => {
// if reconnection failure, respond to all current write messages with lost connection error. the received processes should complete themselves. // if reconnection failure, respond to all current write messages with lost connection error. the received processes should complete themselves.
write_recv.close(); write_recv.close();
let _ = write_msg.respond_to.send(Err(WriteError::LostConnection)); let _ = write_msg.respond_to.send(Err(Reason::LostConnection));
while let Some(msg) = write_recv.recv().await { while let Some(msg) = write_recv.recv().await {
let _ = msg.respond_to.send(Err(WriteError::LostConnection)); let _ = msg.respond_to.send(Err(Reason::LostConnection));
} }
// TODO: is this the correct error to send? // TODO: is this the correct error to send?
let _ = self.sender.send(UpdateMessage::Error(Error::LostConnection)).await; let _ = self.sender.send(UpdateMessage::Error(Error::LostConnection)).await;
@ -278,10 +278,10 @@ impl Supervisor {
// if reconnection failure, respond to all current messages with lost connection error. // if reconnection failure, respond to all current messages with lost connection error.
write_receiver.close(); write_receiver.close();
if let Some(msg) = retry_msg { if let Some(msg) = retry_msg {
msg.respond_to.send(Err(WriteError::LostConnection)); msg.respond_to.send(Err(Reason::LostConnection));
} }
while let Some(msg) = write_receiver.recv().await { while let Some(msg) = write_receiver.recv().await {
msg.respond_to.send(Err(WriteError::LostConnection)); msg.respond_to.send(Err(Reason::LostConnection));
} }
// TODO: is this the correct error? // TODO: is this the correct error?
let _ = self.sender.send(UpdateMessage::Error(Error::LostConnection)).await; let _ = self.sender.send(UpdateMessage::Error(Error::LostConnection)).await;
@ -342,7 +342,7 @@ impl SupervisorHandle {
on_shutdown: oneshot::Sender<()>, on_shutdown: oneshot::Sender<()>,
jid: Arc<Mutex<JID>>, jid: Arc<Mutex<JID>>,
password: Arc<String>, password: Arc<String>,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>, pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
) -> (WriteHandle, Self) { ) -> (WriteHandle, Self) {
let (command_sender, command_receiver) = mpsc::channel(20); let (command_sender, command_receiver) = mpsc::channel(20);
let (writer_error_sender, writer_error_receiver) = oneshot::channel(); let (writer_error_sender, writer_error_receiver) = oneshot::channel();

View File

@ -18,13 +18,16 @@ use uuid::Uuid;
use crate::{ use crate::{
chat::{Body, Message}, chat::{Body, Message},
db::Db, db::Db,
error::{Error, IqError, MessageRecvError, PresenceError, ReadError, RosterError}, error::{Error, IqError, PresenceError, Reason, RecvMessageError},
presence::{Offline, Online, Presence, Show}, presence::{Offline, Online, Presence, Show},
roster::Contact, roster::Contact,
UpdateMessage, UpdateMessage,
}; };
use super::{write::WriteHandle, SupervisorCommand}; use super::{
write::{WriteHandle, WriteMessage},
SupervisorCommand,
};
pub struct Read { pub struct Read {
control_receiver: mpsc::Receiver<ReadControl>, control_receiver: mpsc::Receiver<ReadControl>,
@ -35,7 +38,7 @@ pub struct Read {
JoinSet<()>, JoinSet<()>,
mpsc::Sender<SupervisorCommand>, mpsc::Sender<SupervisorCommand>,
WriteHandle, WriteHandle,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>, Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
)>, )>,
db: Db, db: Db,
update_sender: mpsc::Sender<UpdateMessage>, update_sender: mpsc::Sender<UpdateMessage>,
@ -45,7 +48,7 @@ pub struct Read {
disconnecting: bool, disconnecting: bool,
disconnect_timedout: oneshot::Receiver<()>, disconnect_timedout: oneshot::Receiver<()>,
// TODO: use proper stanza ids // TODO: use proper stanza ids
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>, pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
} }
impl Read { impl Read {
@ -58,7 +61,7 @@ impl Read {
JoinSet<()>, JoinSet<()>,
mpsc::Sender<SupervisorCommand>, mpsc::Sender<SupervisorCommand>,
WriteHandle, WriteHandle,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>, Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
)>, )>,
db: Db, db: Db,
update_sender: mpsc::Sender<UpdateMessage>, update_sender: mpsc::Sender<UpdateMessage>,
@ -66,9 +69,9 @@ impl Read {
supervisor_control: mpsc::Sender<SupervisorCommand>, supervisor_control: mpsc::Sender<SupervisorCommand>,
write_handle: WriteHandle, write_handle: WriteHandle,
tasks: JoinSet<()>, tasks: JoinSet<()>,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>, pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
) -> Self { ) -> Self {
let (_send, recv) = oneshot::channel(); let (send, recv) = oneshot::channel();
Self { Self {
control_receiver, control_receiver,
stream, stream,
@ -159,7 +162,7 @@ impl Read {
// when it aborts, must clear iq map no matter what // when it aborts, must clear iq map no matter what
let mut iqs = self.pending_iqs.lock().await; let mut iqs = self.pending_iqs.lock().await;
for (_id, sender) in iqs.drain() { for (_id, sender) in iqs.drain() {
let _ = sender.send(Err(ReadError::LostConnection)); let _ = sender.send(Err(Reason::LostConnection));
} }
} }
} }
@ -175,7 +178,7 @@ async fn handle_stanza(
db: Db, db: Db,
supervisor_control: mpsc::Sender<SupervisorCommand>, supervisor_control: mpsc::Sender<SupervisorCommand>,
write_handle: WriteHandle, write_handle: WriteHandle,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>, pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
) { ) {
match stanza { match stanza {
Stanza::Message(stanza_message) => { Stanza::Message(stanza_message) => {
@ -204,9 +207,7 @@ async fn handle_stanza(
if let Err(e) = result { if let Err(e) = result {
tracing::error!("messagecreate"); tracing::error!("messagecreate");
let _ = update_sender let _ = update_sender
.send(UpdateMessage::Error(Error::MessageRecv( .send(UpdateMessage::Error(Error::CacheUpdate(e.into())))
MessageRecvError::MessageHistory(e.into()),
)))
.await; .await;
} }
let _ = update_sender let _ = update_sender
@ -214,8 +215,8 @@ async fn handle_stanza(
.await; .await;
} else { } else {
let _ = update_sender let _ = update_sender
.send(UpdateMessage::Error(Error::MessageRecv( .send(UpdateMessage::Error(Error::RecvMessage(
MessageRecvError::MissingFrom, RecvMessageError::MissingFrom,
))) )))
.await; .await;
} }
@ -228,16 +229,9 @@ async fn handle_stanza(
stanza::client::presence::PresenceType::Error => { stanza::client::presence::PresenceType::Error => {
// TODO: is there any other information that should go with the error? also MUST have an error, otherwise it's a different error. maybe it shoulnd't be an option. // TODO: is there any other information that should go with the error? also MUST have an error, otherwise it's a different error. maybe it shoulnd't be an option.
let _ = update_sender let _ = update_sender
.send(UpdateMessage::Error(Error::Presence( .send(UpdateMessage::Error(Error::Presence(PresenceError::Error(
// TODO: ughhhhhhhhhhhhh these stanza errors should probably just have an option, and custom display Reason::Stanza(presence.errors.first().cloned()),
PresenceError::StanzaError( ))))
presence
.errors
.first()
.cloned()
.expect("error MUST have error"),
),
)))
.await; .await;
} }
// should not happen (error to server) // should not happen (error to server)
@ -335,8 +329,8 @@ async fn handle_stanza(
let contact: Contact = item.into(); let contact: Contact = item.into();
if let Err(e) = db.upsert_contact(contact.clone()).await { if let Err(e) = db.upsert_contact(contact.clone()).await {
let _ = update_sender let _ = update_sender
.send(UpdateMessage::Error(Error::Roster( .send(UpdateMessage::Error(Error::CacheUpdate(
RosterError::Cache(e.into()), e.into(),
))) )))
.await; .await;
} }
@ -387,7 +381,7 @@ pub enum ReadControl {
JoinSet<()>, JoinSet<()>,
mpsc::Sender<SupervisorCommand>, mpsc::Sender<SupervisorCommand>,
WriteHandle, WriteHandle,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>, Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
)>, )>,
), ),
} }
@ -420,13 +414,13 @@ impl ReadControlHandle {
JoinSet<()>, JoinSet<()>,
mpsc::Sender<SupervisorCommand>, mpsc::Sender<SupervisorCommand>,
WriteHandle, WriteHandle,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>, Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
)>, )>,
db: Db, db: Db,
sender: mpsc::Sender<UpdateMessage>, sender: mpsc::Sender<UpdateMessage>,
supervisor_control: mpsc::Sender<SupervisorCommand>, supervisor_control: mpsc::Sender<SupervisorCommand>,
jabber_write: WriteHandle, jabber_write: WriteHandle,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>, pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
) -> Self { ) -> Self {
let (control_sender, control_receiver) = mpsc::channel(20); let (control_sender, control_receiver) = mpsc::channel(20);
@ -457,14 +451,14 @@ impl ReadControlHandle {
JoinSet<()>, JoinSet<()>,
mpsc::Sender<SupervisorCommand>, mpsc::Sender<SupervisorCommand>,
WriteHandle, WriteHandle,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>, Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
)>, )>,
db: Db, db: Db,
sender: mpsc::Sender<UpdateMessage>, sender: mpsc::Sender<UpdateMessage>,
supervisor_control: mpsc::Sender<SupervisorCommand>, supervisor_control: mpsc::Sender<SupervisorCommand>,
jabber_write: WriteHandle, jabber_write: WriteHandle,
tasks: JoinSet<()>, tasks: JoinSet<()>,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>, pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
) -> Self { ) -> Self {
let (control_sender, control_receiver) = mpsc::channel(20); let (control_sender, control_receiver) = mpsc::channel(20);

View File

@ -7,7 +7,7 @@ use tokio::{
task::JoinHandle, task::JoinHandle,
}; };
use crate::error::WriteError; use crate::error::{Error, Reason};
// actor that receives jabber stanzas to write, and if there is an error, sends a message back to the supervisor then aborts, so the supervisor can spawn a new stream. // actor that receives jabber stanzas to write, and if there is an error, sends a message back to the supervisor then aborts, so the supervisor can spawn a new stream.
pub struct Write { pub struct Write {
@ -17,10 +17,9 @@ pub struct Write {
on_crash: oneshot::Sender<(WriteMessage, mpsc::Receiver<WriteMessage>)>, on_crash: oneshot::Sender<(WriteMessage, mpsc::Receiver<WriteMessage>)>,
} }
#[derive(Debug)]
pub struct WriteMessage { pub struct WriteMessage {
pub stanza: Stanza, pub stanza: Stanza,
pub respond_to: oneshot::Sender<Result<(), WriteError>>, pub respond_to: oneshot::Sender<Result<(), Reason>>,
} }
pub enum WriteControl { pub enum WriteControl {
@ -85,9 +84,9 @@ impl Write {
Err(e) => match &e { Err(e) => match &e {
peanuts::Error::ReadError(_error) => { peanuts::Error::ReadError(_error) => {
// if connection lost during disconnection, just send lost connection error to the write requests // if connection lost during disconnection, just send lost connection error to the write requests
let _ = msg.respond_to.send(Err(WriteError::LostConnection)); let _ = msg.respond_to.send(Err(Reason::LostConnection));
while let Some(msg) = self.stanza_receiver.recv().await { while let Some(msg) = self.stanza_receiver.recv().await {
let _ = msg.respond_to.send(Err(WriteError::LostConnection)); let _ = msg.respond_to.send(Err(Reason::LostConnection));
} }
break; break;
} }
@ -141,16 +140,16 @@ pub struct WriteHandle {
} }
impl WriteHandle { impl WriteHandle {
pub async fn write(&self, stanza: Stanza) -> Result<(), WriteError> { pub async fn write(&self, stanza: Stanza) -> Result<(), Reason> {
let (send, recv) = oneshot::channel(); let (send, recv) = oneshot::channel();
self.send(WriteMessage { self.send(WriteMessage {
stanza, stanza,
respond_to: send, respond_to: send,
}) })
.await .await
.map_err(|e| WriteError::Actor(e.into()))?; .map_err(|_| Reason::ChannelSend)?;
// TODO: timeout // TODO: timeout
recv.await.map_err(|e| WriteError::Actor(e.into()))? recv.await?
} }
} }

View File

@ -1,158 +1,138 @@
use std::sync::Arc;
use stanza::client::Stanza; use stanza::client::Stanza;
use thiserror::Error; use tokio::sync::oneshot::{self};
use tokio::sync::{mpsc::error::SendError, oneshot::error::RecvError};
#[derive(Debug, Error, Clone)] #[derive(Debug)]
pub enum Error { pub enum Error {
#[error("already connected")]
AlreadyConnected, AlreadyConnected,
// TODO: change to Connecting(ConnectingError) // TODO: change to Connecting(ConnectingError)
#[error("connecting: {0}")] Connection(ConnectionError),
Connecting(#[from] ConnectionError), Presence(PresenceError),
#[error("presence: {0}")] SetStatus(Reason),
Presence(#[from] PresenceError), Roster(Reason),
#[error("set status: {0}")] Stream(stanza::stream::Error),
SetStatus(#[from] StatusError), SendMessage(Reason),
// TODO: have different ones for get/update/set RecvMessage(RecvMessageError),
#[error("roster: {0}")]
Roster(RosterError),
#[error("stream error: {0}")]
Stream(#[from] stanza::stream::Error),
#[error("message send error: {0}")]
MessageSend(MessageSendError),
#[error("message receive error: {0}")]
MessageRecv(MessageRecvError),
#[error("already disconnected")]
AlreadyDisconnected, AlreadyDisconnected,
#[error("lost connection")]
LostConnection, LostConnection,
// TODO: Display for Content // TODO: should all cache update errors include the context?
#[error("received unrecognized/unsupported content: {0:?}")] CacheUpdate(Reason),
UnrecognizedContent(peanuts::element::Content), UnrecognizedContent(peanuts::element::Content),
#[error("iq receive error: {0}")]
Iq(IqError), Iq(IqError),
#[error("disconnected")] Cloned,
Disconnected,
} }
#[derive(Debug, Error, Clone)] // TODO: this is horrifying, maybe just use tracing to forward error events???
pub enum MessageSendError { impl Clone for Error {
#[error("could not add to message history: {0}")] fn clone(&self) -> Self {
MessageHistory(#[from] DatabaseError), Error::Cloned
}
} }
#[derive(Debug, Error, Clone)] #[derive(Debug)]
pub enum PresenceError { pub enum PresenceError {
#[error("unsupported")] Error(Reason),
Unsupported, Unsupported,
#[error("missing from")]
MissingFrom, MissingFrom,
#[error("stanza error: {0}")]
StanzaError(#[from] stanza::client::error::Error),
} }
#[derive(Debug, Error, Clone)] #[derive(Debug)]
// TODO: should probably have all iq query related errors here, including read, write, stanza error, etc.
pub enum IqError { pub enum IqError {
#[error("no iq with id matching `{0}`")]
NoMatchingId(String), NoMatchingId(String),
} }
#[derive(Debug, Error, Clone)] #[derive(Debug)]
pub enum MessageRecvError { pub enum RecvMessageError {
#[error("could not add to message history: {0}")]
MessageHistory(#[from] DatabaseError),
#[error("missing from")]
MissingFrom, MissingFrom,
} }
#[derive(Debug, Clone, Error)] #[derive(Debug, Clone)]
pub enum ConnectionError { pub enum ConnectionError {
#[error("connection failed: {0}")] ConnectionFailed(Reason),
ConnectionFailed(#[from] jabber::Error), RosterRetreival(Reason),
#[error("failed roster retreival: {0}")] SendPresence(Reason),
RosterRetreival(#[from] RosterError), NoCachedStatus(Reason),
#[error("failed to send available presence: {0}")]
SendPresence(#[from] WriteError),
#[error("cached status: {0}")]
StatusCacheError(#[from] DatabaseError),
} }
#[derive(Debug, Error, Clone)] #[derive(Debug)]
pub enum RosterError { pub struct RosterError(pub Reason);
#[error("cache: {0}")]
Cache(#[from] DatabaseError),
#[error("stream write: {0}")]
Write(#[from] WriteError),
// TODO: display for stanza, to show as xml, same for read error types.
#[error("unexpected reply: {0:?}")]
UnexpectedStanza(Stanza),
#[error("stream read: {0}")]
Read(#[from] ReadError),
#[error("stanza error: {0}")]
StanzaError(#[from] stanza::client::error::Error),
}
#[derive(Debug, Error, Clone)] impl From<RosterError> for Error {
#[error("database error: {0}")] fn from(e: RosterError) -> Self {
pub struct DatabaseError(Arc<sqlx::Error>); Self::Roster(e.0)
impl From<sqlx::Error> for DatabaseError {
fn from(e: sqlx::Error) -> Self {
Self(Arc::new(e))
} }
} }
#[derive(Debug, Error, Clone)] impl From<RosterError> for ConnectionError {
pub enum StatusError { fn from(e: RosterError) -> Self {
#[error("cache: {0}")] Self::RosterRetreival(e.0)
Cache(#[from] DatabaseError), }
#[error("stream write: {0}")]
Write(#[from] WriteError),
} }
#[derive(Debug, Error, Clone)] pub struct StatusError(pub Reason);
pub enum WriteError {
#[error("xml: {0}")] impl From<StatusError> for Error {
XML(#[from] peanuts::Error), fn from(e: StatusError) -> Self {
#[error("lost connection")] Error::SetStatus(e.0)
LostConnection, }
// TODO: should this be in writeerror or separate?
#[error("actor: {0}")]
Actor(#[from] ActorError),
#[error("disconnected")]
Disconnected,
} }
// TODO: separate peanuts read and write error? impl From<StatusError> for ConnectionError {
#[derive(Debug, Error, Clone)] fn from(e: StatusError) -> Self {
pub enum ReadError { Self::SendPresence(e.0)
#[error("xml: {0}")] }
XML(#[from] peanuts::Error),
#[error("lost connection")]
LostConnection,
} }
#[derive(Debug, Error, Clone)] #[derive(Debug)]
pub enum ActorError { pub enum Reason {
#[error("receive timed out")] // TODO: organisastion of error into internal error thing
Timeout, Timeout,
#[error("could not send message to actor, channel closed")] Stream(stanza::stream_error::Error),
Send, Stanza(Option<stanza::client::error::Error>),
#[error("could not receive message from actor, channel closed")] Jabber(jabber::Error),
Receive, XML(peanuts::Error),
SQL(sqlx::Error),
// JID(jid::ParseError),
LostConnection,
OneshotRecv(oneshot::error::RecvError),
UnexpectedStanza(Stanza),
Disconnected,
ChannelSend,
Cloned,
} }
impl<T> From<SendError<T>> for ActorError { // TODO: same here
fn from(_e: SendError<T>) -> Self { impl Clone for Reason {
Self::Send fn clone(&self) -> Self {
Reason::Cloned
} }
} }
impl From<RecvError> for ActorError { impl From<oneshot::error::RecvError> for Reason {
fn from(_e: RecvError) -> Self { fn from(e: oneshot::error::RecvError) -> Reason {
Self::Receive Self::OneshotRecv(e)
}
}
impl From<peanuts::Error> for Reason {
fn from(e: peanuts::Error) -> Self {
Self::XML(e)
}
}
// impl From<jid::ParseError> for Reason {
// fn from(e: jid::ParseError) -> Self {
// Self::JID(e)
// }
// }
impl From<sqlx::Error> for Reason {
fn from(e: sqlx::Error) -> Self {
Self::SQL(e)
}
}
impl From<jabber::Error> for Reason {
fn from(e: jabber::Error) -> Self {
Self::Jabber(e)
} }
} }

View File

@ -7,8 +7,7 @@ use std::{
use chat::{Body, Chat, Message}; use chat::{Body, Chat, Message};
use connection::{write::WriteMessage, SupervisorSender}; use connection::{write::WriteMessage, SupervisorSender};
use db::Db; use db::Db;
use error::{ConnectionError, DatabaseError, ReadError, RosterError, StatusError, WriteError}; use error::{ConnectionError, Reason, RosterError, StatusError};
use futures::{future::Fuse, FutureExt};
use jabber::JID; use jabber::JID;
use presence::{Offline, Online, Presence}; use presence::{Offline, Online, Presence};
use roster::{Contact, ContactUpdate}; use roster::{Contact, ContactUpdate};
@ -21,7 +20,7 @@ use tokio::{
sync::{mpsc, oneshot, Mutex}, sync::{mpsc, oneshot, Mutex},
task::JoinSet, task::JoinSet,
}; };
use tracing::{debug, info}; use tracing::{debug, info, Instrument};
use user::User; use user::User;
use uuid::Uuid; use uuid::Uuid;
@ -44,11 +43,11 @@ pub struct Luz {
// TODO: use a dyn passwordprovider trait to avoid storing password in memory // TODO: use a dyn passwordprovider trait to avoid storing password in memory
password: Arc<String>, password: Arc<String>,
connected: Arc<Mutex<Option<(WriteHandle, SupervisorHandle)>>>, connected: Arc<Mutex<Option<(WriteHandle, SupervisorHandle)>>>,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>, pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
db: Db, db: Db,
sender: mpsc::Sender<UpdateMessage>, sender: mpsc::Sender<UpdateMessage>,
/// if connection was shut down due to e.g. server shutdown, supervisor must be able to mark client as disconnected /// if connection was shut down due to e.g. server shutdown, supervisor must be able to mark client as disconnected
connection_supervisor_shutdown: Fuse<oneshot::Receiver<()>>, connection_supervisor_shutdown: oneshot::Receiver<()>,
// TODO: will need to have an auto reconnect state as well (e.g. in case server shut down, to try and reconnect later) // TODO: will need to have an auto reconnect state as well (e.g. in case server shut down, to try and reconnect later)
// TODO: will grow forever at this point, maybe not required as tasks will naturally shut down anyway? // TODO: will grow forever at this point, maybe not required as tasks will naturally shut down anyway?
tasks: JoinSet<()>, tasks: JoinSet<()>,
@ -61,7 +60,7 @@ impl Luz {
jid: Arc<Mutex<JID>>, jid: Arc<Mutex<JID>>,
password: String, password: String,
connected: Arc<Mutex<Option<(WriteHandle, SupervisorHandle)>>>, connected: Arc<Mutex<Option<(WriteHandle, SupervisorHandle)>>>,
connection_supervisor_shutdown: Fuse<oneshot::Receiver<()>>, connection_supervisor_shutdown: oneshot::Receiver<()>,
db: SqlitePool, db: SqlitePool,
sender: mpsc::Sender<UpdateMessage>, sender: mpsc::Sender<UpdateMessage>,
) -> Self { ) -> Self {
@ -83,8 +82,9 @@ impl Luz {
loop { loop {
let msg = tokio::select! { let msg = tokio::select! {
// this is okay, as when created the supervisor (and connection) doesn't exist, but a bit messy // this is okay, as when created the supervisor (and connection) doesn't exist, but a bit messy
// THIS IS NOT OKAY LOLLLL - apparently fusing is the best option??? // THIS IS NOT OKAY LOLLLL
_ = &mut self.connection_supervisor_shutdown => { _ = &mut self.connection_supervisor_shutdown => {
info!("got this");
*self.connected.lock().await = None; *self.connected.lock().await = None;
continue; continue;
} }
@ -130,7 +130,6 @@ impl Luz {
self.password.clone(), self.password.clone(),
self.pending_iqs.clone(), self.pending_iqs.clone(),
); );
let shutdown_recv = shutdown_recv.fuse();
self.connection_supervisor_shutdown = shutdown_recv; self.connection_supervisor_shutdown = shutdown_recv;
// TODO: get roster and send initial presence // TODO: get roster and send initial presence
let (send, recv) = oneshot::channel(); let (send, recv) = oneshot::channel();
@ -159,8 +158,8 @@ impl Luz {
let _ = self let _ = self
.sender .sender
.send(UpdateMessage::Error( .send(UpdateMessage::Error(
Error::Connecting( Error::Connection(
ConnectionError::StatusCacheError( ConnectionError::NoCachedStatus(
e.into(), e.into(),
), ),
), ),
@ -170,20 +169,16 @@ impl Luz {
} }
}; };
let (send, recv) = oneshot::channel(); let (send, recv) = oneshot::channel();
CommandMessage::SendPresence( CommandMessage::SetStatus(online.clone(), send)
None, .handle_online(
Presence::Online(online.clone()), writer.clone(),
send, supervisor.sender(),
) self.jid.clone(),
.handle_online( self.db.clone(),
writer.clone(), self.sender.clone(),
supervisor.sender(), self.pending_iqs.clone(),
self.jid.clone(), )
self.db.clone(), .await;
self.sender.clone(),
self.pending_iqs.clone(),
)
.await;
let set_status = recv.await; let set_status = recv.await;
match set_status { match set_status {
Ok(s) => match s { Ok(s) => match s {
@ -202,13 +197,13 @@ impl Luz {
let _ = self let _ = self
.sender .sender
.send(UpdateMessage::Error( .send(UpdateMessage::Error(
Error::Connecting(e.into()), Error::Connection(e.into()),
)) ))
.await; .await;
} }
}, },
Err(e) => { Err(e) => {
let _ = self.sender.send(UpdateMessage::Error(Error::Connecting(ConnectionError::SendPresence(WriteError::Actor(e.into()))))).await; let _ = self.sender.send(UpdateMessage::Error(Error::Connection(ConnectionError::SendPresence(e.into())))).await;
} }
} }
} }
@ -216,7 +211,7 @@ impl Luz {
let _ = self let _ = self
.sender .sender
.send(UpdateMessage::Error( .send(UpdateMessage::Error(
Error::Connecting(e.into()), Error::Connection(e.into()),
)) ))
.await; .await;
} }
@ -225,12 +220,8 @@ impl Luz {
Err(e) => { Err(e) => {
let _ = self let _ = self
.sender .sender
.send(UpdateMessage::Error(Error::Connecting( .send(UpdateMessage::Error(Error::Connection(
ConnectionError::RosterRetreival( ConnectionError::RosterRetreival(e.into()),
RosterError::Write(WriteError::Actor(
e.into(),
)),
),
))) )))
.await; .await;
} }
@ -238,7 +229,7 @@ impl Luz {
} }
Err(e) => { Err(e) => {
let _ = let _ =
self.sender.send(UpdateMessage::Error(Error::Connecting( self.sender.send(UpdateMessage::Error(Error::Connection(
ConnectionError::ConnectionFailed(e.into()), ConnectionError::ConnectionFailed(e.into()),
))); )));
} }
@ -246,7 +237,7 @@ impl Luz {
} }
}; };
} }
CommandMessage::Disconnect(offline) => { CommandMessage::Disconnect(_offline) => {
match self.connected.lock().await.as_mut() { match self.connected.lock().await.as_mut() {
None => { None => {
let _ = self let _ = self
@ -256,19 +247,15 @@ impl Luz {
} }
mut c => { mut c => {
// TODO: send unavailable presence // TODO: send unavailable presence
if let Some((write_handle, supervisor_handle)) = c.take() { if let Some((_write_handle, supervisor_handle)) = c.take() {
let offline_presence: stanza::client::presence::Presence =
offline.clone().into();
let stanza = Stanza::Presence(offline_presence);
// TODO: timeout and error check
write_handle.write(stanza).await;
let _ = supervisor_handle.send(SupervisorCommand::Disconnect).await; let _ = supervisor_handle.send(SupervisorCommand::Disconnect).await;
let _ = self.sender.send(UpdateMessage::Offline(offline)).await; c = None;
} else { } else {
unreachable!() unreachable!()
}; };
} }
} }
info!("lock released")
} }
_ => { _ => {
match self.connected.lock().await.as_ref() { match self.connected.lock().await.as_ref() {
@ -294,7 +281,7 @@ impl Luz {
impl CommandMessage { impl CommandMessage {
pub async fn handle_offline( pub async fn handle_offline(
self, mut self,
jid: Arc<Mutex<JID>>, jid: Arc<Mutex<JID>>,
db: Db, db: Db,
update_sender: mpsc::Sender<UpdateMessage>, update_sender: mpsc::Sender<UpdateMessage>,
@ -309,7 +296,7 @@ impl CommandMessage {
let _ = sender.send(Ok(roster)); let _ = sender.send(Ok(roster));
} }
Err(e) => { Err(e) => {
let _ = sender.send(Err(RosterError::Cache(e.into()))); let _ = sender.send(Err(RosterError(e.into())));
} }
} }
} }
@ -339,48 +326,45 @@ impl CommandMessage {
} }
// TODO: offline queue to modify roster // TODO: offline queue to modify roster
CommandMessage::AddContact(jid, sender) => { CommandMessage::AddContact(jid, sender) => {
sender.send(Err(RosterError::Write(WriteError::Disconnected))); sender.send(Err(Reason::Disconnected));
} }
CommandMessage::BuddyRequest(jid, sender) => { CommandMessage::BuddyRequest(jid, sender) => {
sender.send(Err(WriteError::Disconnected)); sender.send(Err(Reason::Disconnected));
} }
CommandMessage::SubscriptionRequest(jid, sender) => { CommandMessage::SubscriptionRequest(jid, sender) => {
sender.send(Err(WriteError::Disconnected)); sender.send(Err(Reason::Disconnected));
} }
CommandMessage::AcceptBuddyRequest(jid, sender) => { CommandMessage::AcceptBuddyRequest(jid, sender) => {
sender.send(Err(WriteError::Disconnected)); sender.send(Err(Reason::Disconnected));
} }
CommandMessage::AcceptSubscriptionRequest(jid, sender) => { CommandMessage::AcceptSubscriptionRequest(jid, sender) => {
sender.send(Err(WriteError::Disconnected)); sender.send(Err(Reason::Disconnected));
} }
CommandMessage::UnsubscribeFromContact(jid, sender) => { CommandMessage::UnsubscribeFromContact(jid, sender) => {
sender.send(Err(WriteError::Disconnected)); sender.send(Err(Reason::Disconnected));
} }
CommandMessage::UnsubscribeContact(jid, sender) => { CommandMessage::UnsubscribeContact(jid, sender) => {
sender.send(Err(WriteError::Disconnected)); sender.send(Err(Reason::Disconnected));
} }
CommandMessage::UnfriendContact(jid, sender) => { CommandMessage::UnfriendContact(jid, sender) => {
sender.send(Err(WriteError::Disconnected)); sender.send(Err(Reason::Disconnected));
} }
CommandMessage::DeleteContact(jid, sender) => { CommandMessage::DeleteContact(jid, sender) => {
sender.send(Err(RosterError::Write(WriteError::Disconnected))); sender.send(Err(Reason::Disconnected));
} }
CommandMessage::UpdateContact(jid, contact_update, sender) => { CommandMessage::UpdateContact(jid, contact_update, sender) => {
sender.send(Err(RosterError::Write(WriteError::Disconnected))); sender.send(Err(Reason::Disconnected));
} }
CommandMessage::SetStatus(online, sender) => { CommandMessage::SetStatus(online, sender) => {
let result = db let result = db
.upsert_cached_status(online) .upsert_cached_status(online)
.await .await
.map_err(|e| StatusError::Cache(e.into())); .map_err(|e| StatusError(e.into()));
sender.send(result); sender.send(result);
} }
// TODO: offline message queue // TODO: offline message queue
CommandMessage::SendMessage(jid, body, sender) => { CommandMessage::SendMessage(jid, body, sender) => {
sender.send(Err(WriteError::Disconnected)); sender.send(Err(Reason::Disconnected));
}
CommandMessage::SendPresence(jid, presence, sender) => {
sender.send(Err(WriteError::Disconnected));
} }
} }
} }
@ -393,7 +377,7 @@ impl CommandMessage {
client_jid: Arc<Mutex<JID>>, client_jid: Arc<Mutex<JID>>,
db: Db, db: Db,
update_sender: mpsc::Sender<UpdateMessage>, update_sender: mpsc::Sender<UpdateMessage>,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>, pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
) { ) {
match self { match self {
CommandMessage::Connect => unreachable!(), CommandMessage::Connect => unreachable!(),
@ -435,12 +419,11 @@ impl CommandMessage {
Ok(Ok(())) => info!("roster request sent"), Ok(Ok(())) => info!("roster request sent"),
Ok(Err(e)) => { Ok(Err(e)) => {
// TODO: log errors if fail to send // TODO: log errors if fail to send
let _ = result_sender.send(Err(RosterError::Write(e.into()))); let _ = result_sender.send(Err(RosterError(e.into())));
return; return;
} }
Err(e) => { Err(e) => {
let _ = result_sender let _ = result_sender.send(Err(RosterError(e.into())));
.send(Err(RosterError::Write(WriteError::Actor(e.into()))));
return; return;
} }
}; };
@ -460,41 +443,23 @@ impl CommandMessage {
items.into_iter().map(|item| item.into()).collect(); items.into_iter().map(|item| item.into()).collect();
if let Err(e) = db.replace_cached_roster(contacts.clone()).await { if let Err(e) = db.replace_cached_roster(contacts.clone()).await {
update_sender update_sender
.send(UpdateMessage::Error(Error::Roster(RosterError::Cache( .send(UpdateMessage::Error(Error::CacheUpdate(e.into())))
e.into(),
))))
.await; .await;
}; };
result_sender.send(Ok(contacts)); result_sender.send(Ok(contacts));
return; return;
} }
ref s @ Stanza::Iq(Iq {
from: _,
ref id,
to: _,
r#type,
lang: _,
query: _,
ref errors,
}) if *id == iq_id && r#type == IqType::Error => {
if let Some(error) = errors.first() {
result_sender.send(Err(RosterError::StanzaError(error.clone())));
} else {
result_sender.send(Err(RosterError::UnexpectedStanza(s.clone())));
}
return;
}
s => { s => {
result_sender.send(Err(RosterError::UnexpectedStanza(s))); result_sender.send(Err(RosterError(Reason::UnexpectedStanza(s))));
return; return;
} }
}, },
Ok(Err(e)) => { Ok(Err(e)) => {
result_sender.send(Err(RosterError::Read(e))); result_sender.send(Err(RosterError(e.into())));
return; return;
} }
Err(e) => { Err(e) => {
result_sender.send(Err(RosterError::Write(WriteError::Actor(e.into())))); result_sender.send(Err(RosterError(e.into())));
return; return;
} }
} }
@ -555,8 +520,8 @@ impl CommandMessage {
} }
// TODO: write_handle send helper function // TODO: write_handle send helper function
let result = write_handle.write(set_stanza).await; let result = write_handle.write(set_stanza).await;
if let Err(e) = result { if let Err(_) = result {
sender.send(Err(RosterError::Write(e))); sender.send(result);
return; return;
} }
let iq_result = recv.await; let iq_result = recv.await;
@ -575,24 +540,24 @@ impl CommandMessage {
sender.send(Ok(())); sender.send(Ok(()));
return; return;
} }
ref s @ Stanza::Iq(Iq { Stanza::Iq(Iq {
from: _, from: _,
ref id, id,
to: _, to: _,
r#type, r#type,
lang: _, lang: _,
query: _, query: _,
ref errors, errors,
}) if *id == iq_id && r#type == IqType::Error => { }) if id == iq_id && r#type == IqType::Error => {
if let Some(error) = errors.first() { if let Some(error) = errors.first() {
sender.send(Err(RosterError::StanzaError(error.clone()))); sender.send(Err(Reason::Stanza(Some(error.clone()))));
} else { } else {
sender.send(Err(RosterError::UnexpectedStanza(s.clone()))); sender.send(Err(Reason::Stanza(None)));
} }
return; return;
} }
s => { s => {
sender.send(Err(RosterError::UnexpectedStanza(s))); sender.send(Err(Reason::UnexpectedStanza(s)));
return; return;
} }
}, },
@ -602,7 +567,7 @@ impl CommandMessage {
} }
}, },
Err(e) => { Err(e) => {
sender.send(Err(RosterError::Write(WriteError::Actor(e.into())))); sender.send(Err(e.into()));
return; return;
} }
} }
@ -800,8 +765,8 @@ impl CommandMessage {
pending_iqs.lock().await.insert(iq_id.clone(), send); pending_iqs.lock().await.insert(iq_id.clone(), send);
} }
let result = write_handle.write(set_stanza).await; let result = write_handle.write(set_stanza).await;
if let Err(e) = result { if let Err(_) = result {
sender.send(Err(RosterError::Write(e))); sender.send(result);
return; return;
} }
let iq_result = recv.await; let iq_result = recv.await;
@ -820,24 +785,24 @@ impl CommandMessage {
sender.send(Ok(())); sender.send(Ok(()));
return; return;
} }
ref s @ Stanza::Iq(Iq { Stanza::Iq(Iq {
from: _, from: _,
ref id, id,
to: _, to: _,
r#type, r#type,
lang: _, lang: _,
query: _, query: _,
ref errors, errors,
}) if *id == iq_id && r#type == IqType::Error => { }) if id == iq_id && r#type == IqType::Error => {
if let Some(error) = errors.first() { if let Some(error) = errors.first() {
sender.send(Err(RosterError::StanzaError(error.clone()))); sender.send(Err(Reason::Stanza(Some(error.clone()))));
} else { } else {
sender.send(Err(RosterError::UnexpectedStanza(s.clone()))); sender.send(Err(Reason::Stanza(None)));
} }
return; return;
} }
s => { s => {
sender.send(Err(RosterError::UnexpectedStanza(s))); sender.send(Err(Reason::UnexpectedStanza(s)));
return; return;
} }
}, },
@ -847,7 +812,7 @@ impl CommandMessage {
} }
}, },
Err(e) => { Err(e) => {
sender.send(Err(RosterError::Write(WriteError::Actor(e.into())))); sender.send(Err(e.into()));
return; return;
} }
} }
@ -888,8 +853,8 @@ impl CommandMessage {
pending_iqs.lock().await.insert(iq_id.clone(), send); pending_iqs.lock().await.insert(iq_id.clone(), send);
} }
let result = write_handle.write(set_stanza).await; let result = write_handle.write(set_stanza).await;
if let Err(e) = result { if let Err(_) = result {
sender.send(Err(RosterError::Write(e))); sender.send(result);
return; return;
} }
let iq_result = recv.await; let iq_result = recv.await;
@ -908,24 +873,24 @@ impl CommandMessage {
sender.send(Ok(())); sender.send(Ok(()));
return; return;
} }
ref s @ Stanza::Iq(Iq { Stanza::Iq(Iq {
from: _, from: _,
ref id, id,
to: _, to: _,
r#type, r#type,
lang: _, lang: _,
query: _, query: _,
ref errors, errors,
}) if *id == iq_id && r#type == IqType::Error => { }) if id == iq_id && r#type == IqType::Error => {
if let Some(error) = errors.first() { if let Some(error) = errors.first() {
sender.send(Err(RosterError::StanzaError(error.clone()))); sender.send(Err(Reason::Stanza(Some(error.clone()))));
} else { } else {
sender.send(Err(RosterError::UnexpectedStanza(s.clone()))); sender.send(Err(Reason::Stanza(None)));
} }
return; return;
} }
s => { s => {
sender.send(Err(RosterError::UnexpectedStanza(s))); sender.send(Err(Reason::UnexpectedStanza(s)));
return; return;
} }
}, },
@ -935,7 +900,7 @@ impl CommandMessage {
} }
}, },
Err(e) => { Err(e) => {
sender.send(Err(RosterError::Write(WriteError::Actor(e.into())))); sender.send(Err(e.into()));
return; return;
} }
} }
@ -944,16 +909,13 @@ impl CommandMessage {
let result = db.upsert_cached_status(online.clone()).await; let result = db.upsert_cached_status(online.clone()).await;
if let Err(e) = result { if let Err(e) = result {
let _ = update_sender let _ = update_sender
.send(UpdateMessage::Error(Error::SetStatus(StatusError::Cache( .send(UpdateMessage::Error(Error::CacheUpdate(e.into())))
e.into(),
))))
.await; .await;
} }
let result = write_handle let result = write_handle
.write(Stanza::Presence(online.into())) .write(Stanza::Presence(online.into()))
.await .await
.map_err(|e| StatusError::Write(e)); .map_err(|e| StatusError(e));
// .map_err(|e| StatusError::Write(e));
let _ = sender.send(result); let _ = sender.send(result);
} }
// TODO: offline message queue // TODO: offline message queue
@ -989,9 +951,7 @@ impl CommandMessage {
}; };
if let Err(e) = db.create_message(message, jid).await.map_err(|e| e.into()) if let Err(e) = db.create_message(message, jid).await.map_err(|e| e.into())
{ {
let _ = update_sender.send(UpdateMessage::Error(Error::MessageSend( let _ = update_sender.send(UpdateMessage::Error(Error::CacheUpdate(e)));
error::MessageSendError::MessageHistory(e),
)));
} }
let _ = sender.send(Ok(())); let _ = sender.send(Ok(()));
} }
@ -1000,15 +960,6 @@ impl CommandMessage {
} }
} }
} }
CommandMessage::SendPresence(jid, presence, sender) => {
let mut presence: stanza::client::presence::Presence = presence.into();
if let Some(jid) = jid {
presence.to = Some(jid);
};
let result = write_handle.write(Stanza::Presence(presence)).await;
// .map_err(|e| StatusError::Write(e));
let _ = sender.send(result);
}
} }
} }
} }
@ -1043,18 +994,16 @@ impl DerefMut for LuzHandle {
} }
impl LuzHandle { impl LuzHandle {
// TODO: database creation separate
pub async fn new( pub async fn new(
jid: JID, jid: JID,
password: String, password: String,
db: &str, db: &str,
) -> Result<(Self, mpsc::Receiver<UpdateMessage>), DatabaseError> { ) -> Result<(Self, mpsc::Receiver<UpdateMessage>), Reason> {
let db = SqlitePool::connect(db).await?; let db = SqlitePool::connect(db).await?;
let (command_sender, command_receiver) = mpsc::channel(20); let (command_sender, command_receiver) = mpsc::channel(20);
let (update_sender, update_receiver) = mpsc::channel(20); let (update_sender, update_receiver) = mpsc::channel(20);
// might be bad, first supervisor shutdown notification oneshot is never used (disgusting) // might be bad, first supervisor shutdown notification oneshot is never used (disgusting)
let (sup_send, sup_recv) = oneshot::channel(); let (sup_send, sup_recv) = oneshot::channel();
let mut sup_recv = sup_recv.fuse();
let actor = Luz::new( let actor = Luz::new(
command_sender.clone(), command_sender.clone(),
@ -1075,59 +1024,8 @@ impl LuzHandle {
update_receiver, update_receiver,
)) ))
} }
pub async fn connect(&self) {
self.send(CommandMessage::Connect).await;
}
pub async fn disconnect(&self, offline: Offline) {
self.send(CommandMessage::Disconnect(offline)).await;
}
// pub async fn get_roster(&self) -> Result<Vec<Contact>, RosterError> {
// let (send, recv) = oneshot::channel();
// self.send(CommandMessage::GetRoster(send)).await.map_err(|e| RosterError::)?;
// Ok(recv.await?)
// }
// pub async fn get_chats(&self) -> Result<Vec<Chat>, Error> {}
// pub async fn get_chat(&self, jid: JID) -> Result<Chat, Error> {}
// pub async fn get_messages(&self, jid: JID) -> Result<Vec<Message>, Error> {}
// pub async fn delete_chat(&self, jid: JID) -> Result<(), Error> {}
// pub async fn delete_message(&self, id: Uuid) -> Result<(), Error> {}
// pub async fn get_user(&self, jid: JID) -> Result<User, Error> {}
// pub async fn add_contact(&self, jid: JID) -> Result<(), Error> {}
// pub async fn buddy_request(&self, jid: JID) -> Result<(), Error> {}
// pub async fn subscription_request(&self, jid: JID) -> Result<(), Error> {}
// pub async fn accept_buddy_request(&self, jid: JID) -> Result<(), Error> {}
// pub async fn accept_subscription_request(&self, jid: JID) -> Result<(), Error> {}
// pub async fn unsubscribe_from_contact(&self, jid: JID) -> Result<(), Error> {}
// pub async fn unsubscribe_contact(&self, jid: JID) -> Result<(), Error> {}
// pub async fn unfriend_contact(&self, jid: JID) -> Result<(), Error> {}
// pub async fn delete_contact(&self, jid: JID) -> Result<(), Error> {}
// pub async fn update_contact(&self, jid: JID, update: ContactUpdate) -> Result<(), Error> {}
// pub async fn set_status(&self, online: Online) -> Result<(), Error> {}
// pub async fn send_message(&self, jid: JID, body: Body) -> Result<(), Error> {}
} }
// TODO: generate methods for each with a macro
pub enum CommandMessage { pub enum CommandMessage {
// TODO: login invisible xep-0186 // TODO: login invisible xep-0186
/// connect to XMPP chat server. gets roster and publishes initial presence. /// connect to XMPP chat server. gets roster and publishes initial presence.
@ -1138,51 +1036,46 @@ pub enum CommandMessage {
GetRoster(oneshot::Sender<Result<Vec<Contact>, RosterError>>), GetRoster(oneshot::Sender<Result<Vec<Contact>, RosterError>>),
/// get all chats. chat will include 10 messages in their message Vec (enough for chat previews) /// get all chats. chat will include 10 messages in their message Vec (enough for chat previews)
// TODO: paging and filtering // TODO: paging and filtering
GetChats(oneshot::Sender<Result<Vec<Chat>, DatabaseError>>), GetChats(oneshot::Sender<Result<Vec<Chat>, Reason>>),
/// get a specific chat by jid /// get a specific chat by jid
GetChat(JID, oneshot::Sender<Result<Chat, DatabaseError>>), GetChat(JID, oneshot::Sender<Result<Chat, Reason>>),
/// get message history for chat (does appropriate mam things) /// get message history for chat (does appropriate mam things)
// TODO: paging and filtering // TODO: paging and filtering
GetMessages(JID, oneshot::Sender<Result<Vec<Message>, DatabaseError>>), GetMessages(JID, oneshot::Sender<Result<Vec<Message>, Reason>>),
/// delete a chat from your chat history, along with all the corresponding messages /// delete a chat from your chat history, along with all the corresponding messages
DeleteChat(JID, oneshot::Sender<Result<(), DatabaseError>>), DeleteChat(JID, oneshot::Sender<Result<(), Reason>>),
/// delete a message from your chat history /// delete a message from your chat history
DeleteMessage(Uuid, oneshot::Sender<Result<(), DatabaseError>>), DeleteMessage(Uuid, oneshot::Sender<Result<(), Reason>>),
/// get a user from your users database /// get a user from your users database
GetUser(JID, oneshot::Sender<Result<User, DatabaseError>>), GetUser(JID, oneshot::Sender<Result<User, Reason>>),
/// add a contact to your roster, with a status of none, no subscriptions. /// add a contact to your roster, with a status of none, no subscriptions.
AddContact(JID, oneshot::Sender<Result<(), RosterError>>), // TODO: for all these, consider returning with oneshot::Sender<Result<(), Error>>
AddContact(JID, oneshot::Sender<Result<(), Reason>>),
/// send a friend request i.e. a subscription request with a subscription pre-approval. if not already added to roster server adds to roster. /// send a friend request i.e. a subscription request with a subscription pre-approval. if not already added to roster server adds to roster.
BuddyRequest(JID, oneshot::Sender<Result<(), WriteError>>), BuddyRequest(JID, oneshot::Sender<Result<(), Reason>>),
/// send a subscription request, without pre-approval. if not already added to roster server adds to roster. /// send a subscription request, without pre-approval. if not already added to roster server adds to roster.
SubscriptionRequest(JID, oneshot::Sender<Result<(), WriteError>>), SubscriptionRequest(JID, oneshot::Sender<Result<(), Reason>>),
/// accept a friend request by accepting a pending subscription and sending a subscription request back. if not already added to roster adds to roster. /// accept a friend request by accepting a pending subscription and sending a subscription request back. if not already added to roster adds to roster.
AcceptBuddyRequest(JID, oneshot::Sender<Result<(), WriteError>>), AcceptBuddyRequest(JID, oneshot::Sender<Result<(), Reason>>),
/// accept a pending subscription and doesn't send a subscription request back. if not already added to roster adds to roster. /// accept a pending subscription and doesn't send a subscription request back. if not already added to roster adds to roster.
AcceptSubscriptionRequest(JID, oneshot::Sender<Result<(), WriteError>>), AcceptSubscriptionRequest(JID, oneshot::Sender<Result<(), Reason>>),
/// unsubscribe to a contact, but don't remove their subscription. /// unsubscribe to a contact, but don't remove their subscription.
UnsubscribeFromContact(JID, oneshot::Sender<Result<(), WriteError>>), UnsubscribeFromContact(JID, oneshot::Sender<Result<(), Reason>>),
/// stop a contact from being subscribed, but stay subscribed to the contact. /// stop a contact from being subscribed, but stay subscribed to the contact.
UnsubscribeContact(JID, oneshot::Sender<Result<(), WriteError>>), UnsubscribeContact(JID, oneshot::Sender<Result<(), Reason>>),
/// remove subscriptions to and from contact, but keep in roster. /// remove subscriptions to and from contact, but keep in roster.
UnfriendContact(JID, oneshot::Sender<Result<(), WriteError>>), UnfriendContact(JID, oneshot::Sender<Result<(), Reason>>),
/// remove a contact from the contact list. will remove subscriptions if not already done then delete contact from roster. /// remove a contact from the contact list. will remove subscriptions if not already done then delete contact from roster.
DeleteContact(JID, oneshot::Sender<Result<(), RosterError>>), DeleteContact(JID, oneshot::Sender<Result<(), Reason>>),
/// update contact. contact details will be overwritten with the contents of the contactupdate struct. /// update contact. contact details will be overwritten with the contents of the contactupdate struct.
UpdateContact(JID, ContactUpdate, oneshot::Sender<Result<(), RosterError>>), UpdateContact(JID, ContactUpdate, oneshot::Sender<Result<(), Reason>>),
/// set online status. if disconnected, will be cached so when client connects, will be sent as the initial presence. /// set online status. if disconnected, will be cached so when client connects, will be sent as the initial presence.
SetStatus(Online, oneshot::Sender<Result<(), StatusError>>), SetStatus(Online, oneshot::Sender<Result<(), StatusError>>),
/// send presence stanza
SendPresence(
Option<JID>,
Presence,
oneshot::Sender<Result<(), WriteError>>,
),
/// send a directed presence (usually to a non-contact). /// send a directed presence (usually to a non-contact).
// TODO: should probably make it so people can add non-contact auto presence sharing in the client (most likely through setting an internal setting) // TODO: should probably make it so people can add non-contact auto presence sharing in the client (most likely through setting an internal setting)
/// send a message to a jid (any kind of jid that can receive a message, e.g. a user or a /// send a message to a jid (any kind of jid that can receive a message, e.g. a user or a
/// chatroom). if disconnected, will be cached so when client connects, message will be sent. /// chatroom). if disconnected, will be cached so when client connects, message will be sent.
SendMessage(JID, Body, oneshot::Sender<Result<(), WriteError>>), SendMessage(JID, Body, oneshot::Sender<Result<(), Reason>>),
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]

View File

@ -12,13 +12,9 @@ use tracing::info;
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
tracing_subscriber::fmt::init(); tracing_subscriber::fmt::init();
let (luz, mut recv) = LuzHandle::new( let db = SqlitePool::connect("./luz.db").await.unwrap();
"test@blos.sm".try_into().unwrap(), let (luz, mut recv) =
"slayed".to_string(), LuzHandle::new("test@blos.sm".try_into().unwrap(), "slayed".to_string(), db);
"./luz.db",
)
.await
.unwrap();
tokio::spawn(async move { tokio::spawn(async move {
while let Some(msg) = recv.recv().await { while let Some(msg) = recv.recv().await {

View File

@ -91,31 +91,3 @@ impl From<Online> for stanza::client::presence::Presence {
} }
} }
} }
impl From<Offline> for stanza::client::presence::Presence {
fn from(value: Offline) -> Self {
Self {
from: None,
id: None,
to: None,
r#type: Some(stanza::client::presence::PresenceType::Unavailable),
lang: None,
show: None,
status: value.status.map(|status| stanza::client::presence::Status {
lang: None,
status: String1024(status),
}),
priority: None,
errors: Vec::new(),
}
}
}
impl From<Presence> for stanza::client::presence::Presence {
fn from(value: Presence) -> Self {
match value {
Presence::Online(online) => online.into(),
Presence::Offline(offline) => offline.into(),
}
}
}

View File

@ -6,4 +6,3 @@ edition = "2021"
[dependencies] [dependencies]
peanuts = { version = "0.1.0", path = "../../peanuts" } peanuts = { version = "0.1.0", path = "../../peanuts" }
jid = { version = "0.1.0", path = "../jid" } jid = { version = "0.1.0", path = "../jid" }
thiserror = "2.0.11"

View File

@ -13,8 +13,8 @@ pub struct Bind {
impl FromElement for Bind { impl FromElement for Bind {
fn from_element(mut element: peanuts::Element) -> peanuts::element::DeserializeResult<Self> { fn from_element(mut element: peanuts::Element) -> peanuts::element::DeserializeResult<Self> {
element.check_name("bind")?; element.check_name("bind");
element.check_name(XMLNS)?; element.check_name(XMLNS);
let r#type = element.pop_child_opt()?; let r#type = element.pop_child_opt()?;
@ -61,8 +61,8 @@ pub struct FullJidType(pub JID);
impl FromElement for FullJidType { impl FromElement for FullJidType {
fn from_element(mut element: peanuts::Element) -> peanuts::element::DeserializeResult<Self> { fn from_element(mut element: peanuts::Element) -> peanuts::element::DeserializeResult<Self> {
element.check_name("jid")?; element.check_name("jid");
element.check_namespace(XMLNS)?; element.check_namespace(XMLNS);
let jid = element.pop_value()?; let jid = element.pop_value()?;

View File

@ -1,16 +1,14 @@
use std::fmt::Display;
use std::str::FromStr; use std::str::FromStr;
use peanuts::element::{FromElement, IntoElement}; use peanuts::element::{FromElement, IntoElement};
use peanuts::{DeserializeError, Element}; use peanuts::{DeserializeError, Element};
use thiserror::Error;
use crate::stanza_error::Error as StanzaError; use crate::stanza_error::Error as StanzaError;
use crate::stanza_error::Text; use crate::stanza_error::Text;
use super::XMLNS; use super::XMLNS;
#[derive(Clone, Debug, Error)] #[derive(Clone, Debug)]
pub struct Error { pub struct Error {
by: Option<String>, by: Option<String>,
r#type: ErrorType, r#type: ErrorType,
@ -19,22 +17,6 @@ pub struct Error {
text: Option<Text>, text: Option<Text>,
} }
impl Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}, {}", self.r#type, self.error)?;
if let Some(text) = &self.text {
if let Some(text) = &text.text {
write!(f, ": {}", text)?;
}
}
if let Some(by) = &self.by {
write!(f, " (error returned by `{}`)", by)?;
}
Ok(())
}
}
impl FromElement for Error { impl FromElement for Error {
fn from_element(mut element: peanuts::Element) -> peanuts::element::DeserializeResult<Self> { fn from_element(mut element: peanuts::Element) -> peanuts::element::DeserializeResult<Self> {
element.check_name("error")?; element.check_name("error")?;
@ -73,18 +55,6 @@ pub enum ErrorType {
Wait, Wait,
} }
impl Display for ErrorType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ErrorType::Auth => f.write_str("auth"),
ErrorType::Cancel => f.write_str("cancel"),
ErrorType::Continue => f.write_str("continue"),
ErrorType::Modify => f.write_str("modify"),
ErrorType::Wait => f.write_str("wait"),
}
}
}
impl FromStr for ErrorType { impl FromStr for ErrorType {
type Err = DeserializeError; type Err = DeserializeError;
@ -99,3 +69,15 @@ impl FromStr for ErrorType {
} }
} }
} }
impl ToString for ErrorType {
fn to_string(&self) -> String {
match self {
ErrorType::Auth => "auth".to_string(),
ErrorType::Cancel => "cancel".to_string(),
ErrorType::Continue => "continue".to_string(),
ErrorType::Modify => "modify".to_string(),
ErrorType::Wait => "wait".to_string(),
}
}
}

View File

@ -15,7 +15,7 @@ use crate::{
use super::XMLNS; use super::XMLNS;
#[derive(Debug, Clone)] #[derive(Debug)]
pub struct Iq { pub struct Iq {
pub from: Option<JID>, pub from: Option<JID>,
pub id: String, pub id: String,

View File

@ -8,7 +8,7 @@ use peanuts::{
use super::XMLNS; use super::XMLNS;
#[derive(Debug, Clone)] #[derive(Debug)]
pub struct Message { pub struct Message {
pub from: Option<JID>, pub from: Option<JID>,
pub id: Option<String>, pub id: Option<String>,

View File

@ -1,7 +1,7 @@
use iq::Iq; use iq::Iq;
use message::Message; use message::Message;
use peanuts::{ use peanuts::{
element::{Content, ContentBuilder, FromContent, FromElement, IntoContent}, element::{Content, ContentBuilder, FromContent, FromElement, IntoContent, IntoElement},
DeserializeError, DeserializeError,
}; };
use presence::Presence; use presence::Presence;
@ -16,7 +16,7 @@ pub mod presence;
pub const XMLNS: &str = "jabber:client"; pub const XMLNS: &str = "jabber:client";
/// TODO: End tag /// TODO: End tag
#[derive(Debug, Clone)] #[derive(Debug)]
pub enum Stanza { pub enum Stanza {
Message(Message), Message(Message),
Presence(Presence), Presence(Presence),

View File

@ -8,7 +8,7 @@ use peanuts::{
use super::{error::Error, XMLNS}; use super::{error::Error, XMLNS};
#[derive(Debug, Clone)] #[derive(Debug)]
pub struct Presence { pub struct Presence {
pub from: Option<JID>, pub from: Option<JID>,
pub id: Option<String>, pub id: Option<String>,

View File

@ -1,10 +1,9 @@
use std::{fmt::Display, ops::Deref}; use std::ops::Deref;
use peanuts::{ use peanuts::{
element::{FromElement, IntoElement}, element::{FromElement, IntoElement},
DeserializeError, Element, DeserializeError, Element,
}; };
use thiserror::Error;
pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-sasl"; pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
@ -169,48 +168,12 @@ impl IntoElement for Response {
} }
} }
#[derive(Error, Debug, Clone)] #[derive(Debug)]
pub struct Failure { pub struct Failure {
r#type: Option<FailureType>, r#type: Option<FailureType>,
text: Option<Text>, text: Option<Text>,
} }
impl Display for Failure {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut had_type = false;
let mut had_text = false;
if let Some(r#type) = &self.r#type {
had_type = true;
match r#type {
FailureType::Aborted => f.write_str("aborted"),
FailureType::AccountDisabled => f.write_str("account disabled"),
FailureType::CredentialsExpired => f.write_str("credentials expired"),
FailureType::EncryptionRequired => f.write_str("encryption required"),
FailureType::IncorrectEncoding => f.write_str("incorrect encoding"),
FailureType::InvalidAuthzid => f.write_str("invalid authzid"),
FailureType::InvalidMechanism => f.write_str("invalid mechanism"),
FailureType::MalformedRequest => f.write_str("malformed request"),
FailureType::MechanismTooWeak => f.write_str("mechanism too weak"),
FailureType::NotAuthorized => f.write_str("not authorized"),
FailureType::TemporaryAuthFailure => f.write_str("temporary auth failure"),
}?;
}
if let Some(text) = &self.text {
if let Some(text) = &text.text {
if had_type {
f.write_str(": ")?;
}
f.write_str(text)?;
had_text = true;
}
}
if !had_type && !had_text {
f.write_str("failure")?;
}
Ok(())
}
}
impl FromElement for Failure { impl FromElement for Failure {
fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> { fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> {
element.check_name("failure")?; element.check_name("failure")?;
@ -223,29 +186,18 @@ impl FromElement for Failure {
} }
} }
#[derive(Error, Debug, Clone)] #[derive(Debug)]
pub enum FailureType { pub enum FailureType {
#[error("aborted")]
Aborted, Aborted,
#[error("account disabled")]
AccountDisabled, AccountDisabled,
#[error("credentials expired")]
CredentialsExpired, CredentialsExpired,
#[error("encryption required")]
EncryptionRequired, EncryptionRequired,
#[error("incorrect encoding")]
IncorrectEncoding, IncorrectEncoding,
#[error("invalid authzid")]
InvalidAuthzid, InvalidAuthzid,
#[error("invalid mechanism")]
InvalidMechanism, InvalidMechanism,
#[error("malformed request")]
MalformedRequest, MalformedRequest,
#[error("mechanism too weak")]
MechanismTooWeak, MechanismTooWeak,
#[error("not authorized")]
NotAuthorized, NotAuthorized,
#[error("temporary auth failure")]
TemporaryAuthFailure, TemporaryAuthFailure,
} }
@ -268,9 +220,8 @@ impl FromElement for FailureType {
} }
} }
#[derive(Debug, Clone)] #[derive(Debug)]
pub struct Text { pub struct Text {
#[allow(dead_code)]
lang: Option<String>, lang: Option<String>,
text: Option<String>, text: Option<String>,
} }

View File

@ -4,55 +4,32 @@ use peanuts::{
element::{FromElement, IntoElement}, element::{FromElement, IntoElement},
Element, XML_NS, Element, XML_NS,
}; };
use thiserror::Error;
pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-stanzas"; pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-stanzas";
#[derive(Error, Clone, Debug)] #[derive(Clone, Debug)]
pub enum Error { pub enum Error {
#[error("bad request")]
BadRequest, BadRequest,
#[error("conflict")]
Conflict, Conflict,
#[error("feature not implemented")]
FeatureNotImplemented, FeatureNotImplemented,
#[error("forbidden")]
Forbidden, Forbidden,
#[error("gone: {0:?}")]
Gone(Option<String>), Gone(Option<String>),
#[error("internal server error")]
InternalServerError, InternalServerError,
#[error("item not found")]
ItemNotFound, ItemNotFound,
#[error("JID malformed")] JidMalformed,
JIDMalformed,
#[error("not acceptable")]
NotAcceptable, NotAcceptable,
#[error("not allowed")]
NotAllowed, NotAllowed,
#[error("not authorized")]
NotAuthorized, NotAuthorized,
#[error("policy violation")]
PolicyViolation, PolicyViolation,
#[error("recipient unavailable")]
RecipientUnavailable, RecipientUnavailable,
#[error("redirect: {0:?}")]
Redirect(Option<String>), Redirect(Option<String>),
#[error("registration required")]
RegistrationRequired, RegistrationRequired,
#[error("remote server not found")]
RemoteServerNotFound, RemoteServerNotFound,
#[error("remote server timeout")]
RemoteServerTimeout, RemoteServerTimeout,
#[error("resource constraint")]
ResourceConstraint, ResourceConstraint,
#[error("service unavailable")]
ServiceUnavailable, ServiceUnavailable,
#[error("subscription required")]
SubscriptionRequired, SubscriptionRequired,
#[error("undefined condition")]
UndefinedCondition, UndefinedCondition,
#[error("unexpected request")]
UnexpectedRequest, UnexpectedRequest,
} }
@ -67,7 +44,7 @@ impl FromElement for Error {
(Some(XMLNS), "gone") => return Ok(Error::Gone(element.pop_value_opt()?)), (Some(XMLNS), "gone") => return Ok(Error::Gone(element.pop_value_opt()?)),
(Some(XMLNS), "internal-server-error") => error = Error::InternalServerError, (Some(XMLNS), "internal-server-error") => error = Error::InternalServerError,
(Some(XMLNS), "item-not-found") => error = Error::ItemNotFound, (Some(XMLNS), "item-not-found") => error = Error::ItemNotFound,
(Some(XMLNS), "jid-malformed") => error = Error::JIDMalformed, (Some(XMLNS), "jid-malformed") => error = Error::JidMalformed,
(Some(XMLNS), "not-acceptable") => error = Error::NotAcceptable, (Some(XMLNS), "not-acceptable") => error = Error::NotAcceptable,
(Some(XMLNS), "not-allowed") => error = Error::NotAllowed, (Some(XMLNS), "not-allowed") => error = Error::NotAllowed,
(Some(XMLNS), "not-authorized") => error = Error::NotAuthorized, (Some(XMLNS), "not-authorized") => error = Error::NotAuthorized,
@ -101,7 +78,7 @@ impl IntoElement for Error {
Error::Gone(r) => Element::builder("gone", Some(XMLNS)).push_text_opt(r.clone()), Error::Gone(r) => Element::builder("gone", Some(XMLNS)).push_text_opt(r.clone()),
Error::InternalServerError => Element::builder("internal-server-error", Some(XMLNS)), Error::InternalServerError => Element::builder("internal-server-error", Some(XMLNS)),
Error::ItemNotFound => Element::builder("item-not-found", Some(XMLNS)), Error::ItemNotFound => Element::builder("item-not-found", Some(XMLNS)),
Error::JIDMalformed => Element::builder("jid-malformed", Some(XMLNS)), Error::JidMalformed => Element::builder("jid-malformed", Some(XMLNS)),
Error::NotAcceptable => Element::builder("not-acceptable", Some(XMLNS)), Error::NotAcceptable => Element::builder("not-acceptable", Some(XMLNS)),
Error::NotAllowed => Element::builder("not-allowed", Some(XMLNS)), Error::NotAllowed => Element::builder("not-allowed", Some(XMLNS)),
Error::NotAuthorized => Element::builder("not-authorized", Some(XMLNS)), Error::NotAuthorized => Element::builder("not-authorized", Some(XMLNS)),
@ -125,7 +102,7 @@ impl IntoElement for Error {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Text { pub struct Text {
lang: Option<String>, lang: Option<String>,
pub text: Option<String>, text: Option<String>,
} }
impl FromElement for Text { impl FromElement for Text {

View File

@ -1,5 +1,7 @@
use std::collections::{HashMap, HashSet};
use peanuts::{ use peanuts::{
element::{FromElement, IntoElement}, element::{Content, FromElement, IntoElement, Name, NamespaceDeclaration},
Element, Element,
}; };

View File

@ -1,16 +1,15 @@
use std::fmt::Display; use std::collections::{HashMap, HashSet};
use jid::JID; use jid::JID;
use peanuts::element::{ElementBuilder, FromElement, IntoElement}; use peanuts::element::{Content, ElementBuilder, FromElement, IntoElement, NamespaceDeclaration};
use peanuts::Element; use peanuts::{element::Name, Element};
use thiserror::Error;
use crate::bind; use crate::bind;
use super::client;
use super::sasl::{self, Mechanisms}; use super::sasl::{self, Mechanisms};
use super::starttls::{self, StartTls}; use super::starttls::{self, StartTls};
use super::stream_error::{Error as StreamError, Text}; use super::stream_error::{Error as StreamError, Text};
use super::{client, stream_error};
pub const XMLNS: &str = "http://etherx.jabber.org/streams"; pub const XMLNS: &str = "http://etherx.jabber.org/streams";
@ -179,24 +178,12 @@ impl FromElement for Feature {
} }
} }
#[derive(Error, Debug, Clone)] #[derive(Debug)]
pub struct Error { pub struct Error {
error: StreamError, error: StreamError,
text: Option<Text>, text: Option<Text>,
} }
impl Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.error)?;
if let Some(text) = &self.text {
if let Some(text) = &text.text {
write!(f, ": {}", text)?;
}
}
Ok(())
}
}
impl FromElement for Error { impl FromElement for Error {
fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> { fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> {
element.check_name("error")?; element.check_name("error")?;

View File

@ -2,61 +2,35 @@ use peanuts::{
element::{FromElement, IntoElement}, element::{FromElement, IntoElement},
DeserializeError, Element, XML_NS, DeserializeError, Element, XML_NS,
}; };
use thiserror::Error;
pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-streams"; pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-streams";
#[derive(Error, Clone, Debug)] #[derive(Clone, Debug)]
pub enum Error { pub enum Error {
#[error("bad format")]
BadFormat, BadFormat,
#[error("bad namespace prefix")]
BadNamespacePrefix, BadNamespacePrefix,
#[error("conflict")]
Conflict, Conflict,
#[error("connection timeout")]
ConnectionTimeout, ConnectionTimeout,
#[error("host gone")]
HostGone, HostGone,
#[error("host unknown")]
HostUnknown, HostUnknown,
#[error("improper addressing")]
ImproperAddressing, ImproperAddressing,
#[error("internal server error")]
InternalServerError, InternalServerError,
#[error("invalid from")]
InvalidFrom, InvalidFrom,
#[error("invalid id")]
InvalidId, InvalidId,
#[error("invalid namespace")]
InvalidNamespace, InvalidNamespace,
#[error("invalid xml")]
InvalidXml, InvalidXml,
#[error("not authorized")]
NotAuthorized, NotAuthorized,
#[error("not well formed")]
NotWellFormed, NotWellFormed,
#[error("policy violation")]
PolicyViolation, PolicyViolation,
#[error("remote connection failed")]
RemoteConnectionFailed, RemoteConnectionFailed,
#[error("reset")]
Reset, Reset,
#[error("resource constraint")]
ResourceConstraint, ResourceConstraint,
#[error("restricted xml")]
RestrictedXml, RestrictedXml,
#[error("see other host: {0:?}")]
SeeOtherHost(Option<String>), SeeOtherHost(Option<String>),
#[error("system shutdown")]
SystemShutdown, SystemShutdown,
#[error("undefined condition")]
UndefinedCondition, UndefinedCondition,
#[error("unsupported encoding")]
UnsupportedEncoding, UnsupportedEncoding,
#[error("unsupported stanza type")]
UnsupportedStanzaType, UnsupportedStanzaType,
#[error("unsupported version")]
UnsupportedVersion, UnsupportedVersion,
} }
@ -138,7 +112,7 @@ impl IntoElement for Error {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Text { pub struct Text {
pub text: Option<String>, text: Option<String>,
lang: Option<String>, lang: Option<String>,
} }