Compare commits

..

No commits in common. "861db1197db6d3f9ee0ce3f6e3434666da05a95c" and "7dc1b1f35d02485b75c0373701dd09cec874ac4d" have entirely different histories.

30 changed files with 403 additions and 673 deletions

View File

@ -30,7 +30,6 @@ futures = "0.3.31"
take_mut = "0.2.2"
pin-project-lite = "0.2.15"
pin-project = "1.1.7"
thiserror = "2.0.11"
[dev-dependencies]
test-log = { version = "0.2", features = ["trace"] }

View File

@ -1,17 +1,29 @@
use std::{
borrow::Borrow,
future::Future,
pin::pin,
sync::Arc,
task::{ready, Poll},
};
use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt};
use jid::ParseError;
use rsasl::config::SASLConfig;
use stanza::{
client::Stanza,
sasl::Mechanisms,
stream::{Feature, Features},
};
use tokio::sync::Mutex;
use crate::{
connection::{Tls, Unencrypted},
jabber_stream::bound_stream::BoundJabberStream,
jabber_stream::bound_stream::{BoundJabberReader, BoundJabberStream},
Connection, Error, JabberStream, Result, JID,
};
pub async fn connect_and_login(
jid: &mut JID,
mut jid: &mut JID,
password: impl AsRef<str>,
server: &mut String,
) -> Result<BoundJabberStream<Tls>> {
@ -19,8 +31,7 @@ pub async fn connect_and_login(
None,
jid.localpart.clone().ok_or(Error::NoLocalpart)?,
password.as_ref().to_string(),
)
.map_err(|e| Error::SASL(e.into()))?;
)?;
let mut conn_state = Connecting::start(&server).await?;
loop {
match conn_state {
@ -109,8 +120,9 @@ pub enum InsecureConnecting {
#[cfg(test)]
mod tests {
use std::time::Duration;
use std::{sync::Arc, time::Duration};
use futures::{SinkExt, StreamExt};
use jid::JID;
use stanza::{
client::{
@ -120,7 +132,7 @@ mod tests {
xep_0199::Ping,
};
use test_log::test;
use tokio::time::sleep;
use tokio::{sync::Mutex, time::sleep};
use tracing::info;
use super::connect_and_login;
@ -128,7 +140,7 @@ mod tests {
#[test(tokio::test)]
async fn login() {
let mut jid: JID = "test@blos.sm".try_into().unwrap();
let _client = connect_and_login(&mut jid, "slayed", &mut "blos.sm".to_string())
let client = connect_and_login(&mut jid, "slayed", &mut "blos.sm".to_string())
.await
.unwrap();
sleep(Duration::from_secs(5)).await

View File

@ -1,7 +1,9 @@
use std::net::{IpAddr, SocketAddr};
use std::str;
use std::str::FromStr;
use std::sync::Arc;
use rsasl::config::SASLConfig;
use tokio::net::TcpStream;
use tokio_native_tls::native_tls::TlsConnector;
// TODO: use rustls

View File

@ -1,58 +1,87 @@
use std::str::Utf8Error;
use std::sync::Arc;
use jid::ParseError;
use rsasl::mechname::MechanismNameError;
use stanza::client::error::Error as ClientError;
use stanza::sasl::Failure;
use stanza::stream::Error as StreamError;
use thiserror::Error;
use tokio::task::JoinError;
#[derive(Error, Debug, Clone)]
#[derive(Debug)]
pub enum Error {
#[error("connection")]
Connection,
#[error("utf8 decode: {0}")]
Utf8Decode(#[from] Utf8Error),
#[error("negotiation")]
Utf8Decode,
Negotiation,
#[error("tls required")]
TlsRequired,
#[error("already connected with tls")]
AlreadyTls,
// TODO: specify unsupported feature
#[error("unsupported feature")]
Unsupported,
#[error("jid missing localpart")]
NoLocalpart,
#[error("received unexpected element: {0:?}")]
AlreadyConnecting,
StreamClosed,
UnexpectedElement(peanuts::Element),
#[error("xml error: {0}")]
XML(#[from] peanuts::Error),
#[error("sasl error: {0}")]
SASL(#[from] SASLError),
#[error("jid error: {0}")]
JID(#[from] ParseError),
#[error("client stanza error: {0}")]
ClientError(#[from] ClientError),
#[error("stream error: {0}")]
StreamError(#[from] StreamError),
#[error("error missing")]
XML(peanuts::Error),
Deserialization(peanuts::DeserializeError),
SASL(SASLError),
JID(ParseError),
Authentication(Failure),
ClientError(ClientError),
StreamError(StreamError),
MissingError,
Disconnected,
Connecting,
JoinError(JoinError),
}
#[derive(Error, Debug, Clone)]
#[derive(Debug)]
pub enum SASLError {
#[error("sasl error: {0}")]
SASL(Arc<rsasl::prelude::SASLError>),
#[error("mechanism error: {0}")]
MechanismName(#[from] MechanismNameError),
#[error("authentication failure: {0}")]
Authentication(#[from] Failure),
SASL(rsasl::prelude::SASLError),
MechanismName(MechanismNameError),
}
impl From<rsasl::prelude::SASLError> for SASLError {
impl From<rsasl::prelude::SASLError> for Error {
fn from(e: rsasl::prelude::SASLError) -> Self {
Self::SASL(Arc::new(e))
Self::SASL(SASLError::SASL(e))
}
}
impl From<JoinError> for Error {
fn from(e: JoinError) -> Self {
Self::JoinError(e)
}
}
impl From<peanuts::DeserializeError> for Error {
fn from(e: peanuts::DeserializeError) -> Self {
Error::Deserialization(e)
}
}
impl From<MechanismNameError> for Error {
fn from(e: MechanismNameError) -> Self {
Self::SASL(SASLError::MechanismName(e))
}
}
impl From<SASLError> for Error {
fn from(e: SASLError) -> Self {
Self::SASL(e)
}
}
impl From<Utf8Error> for Error {
fn from(_e: Utf8Error) -> Self {
Self::Utf8Decode
}
}
impl From<peanuts::Error> for Error {
fn from(e: peanuts::Error) -> Self {
Self::XML(e)
}
}
impl From<ParseError> for Error {
fn from(e: ParseError) -> Self {
Self::JID(e)
}
}

View File

@ -1,8 +1,10 @@
use std::pin::pin;
use std::str::{self, FromStr};
use std::sync::Arc;
use futures::{sink, stream, StreamExt};
use jid::JID;
use peanuts::element::IntoElement;
use peanuts::element::{FromContent, IntoElement};
use peanuts::{Reader, Writer};
use rsasl::prelude::{Mechname, SASLClient, SASLConfig};
use stanza::bind::{Bind, BindType, FullJidType, ResourceType};
@ -133,16 +135,13 @@ where
let sasl = SASLClient::new(sasl_config);
let mut offered_mechs: Vec<&Mechname> = Vec::new();
for mechanism in &mechanisms.mechanisms {
offered_mechs
.push(Mechname::parse(mechanism.as_bytes()).map_err(|e| Error::SASL(e.into()))?)
offered_mechs.push(Mechname::parse(mechanism.as_bytes())?)
}
debug!("{:?}", offered_mechs);
let mut session = sasl
.start_suggested(&offered_mechs)
.map_err(|e| Error::SASL(e.into()))?;
let mut session = sasl.start_suggested(&offered_mechs)?;
let selected_mechanism = session.get_mechname().as_str().to_owned();
debug!("selected mech: {:?}", selected_mechanism);
let mut data: Option<Vec<u8>>;
let mut data: Option<Vec<u8>> = None;
if !session.are_we_first() {
// if not first mention the mechanism then get challenge data
@ -177,7 +176,7 @@ where
ServerResponse::Success(success) => {
data = success.clone().map(|success| success.as_bytes().to_vec())
}
ServerResponse::Failure(failure) => return Err(Error::SASL(failure.into())),
ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)),
}
debug!("we went first");
}
@ -208,7 +207,7 @@ where
ServerResponse::Success(success) => {
data = success.clone().map(|success| success.as_bytes().to_vec())
}
ServerResponse::Failure(failure) => return Err(Error::SASL(failure.into())),
ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)),
}
}
}
@ -410,7 +409,13 @@ impl std::fmt::Debug for JabberStream<Unencrypted> {
#[cfg(test)]
mod tests {
use std::time::Duration;
use super::*;
use crate::connection::Connection;
use futures::sink;
use test_log::test;
use tokio::time::sleep;
#[test(tokio::test)]
async fn start_stream() {

View File

@ -1,6 +1,9 @@
use std::ops::{Deref, DerefMut};
use tokio::io::{AsyncRead, AsyncWrite};
use peanuts::{Reader, Writer};
use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
use crate::Error;
use super::{JabberReader, JabberStream, JabberWriter};

View File

@ -8,6 +8,7 @@ pub mod error;
pub mod jabber_stream;
pub use connection::Connection;
use connection::Tls;
pub use error::Error;
pub use jabber_stream::JabberStream;
pub use jid::JID;

View File

@ -3,8 +3,5 @@ name = "jid"
version = "0.1.0"
edition = "2021"
[features]
sqlx = ["dep:sqlx"]
[dependencies]
sqlx = { version = "0.8.3", features = ["sqlite"], optional = true }
sqlx = { version = "0.8.3", features = ["sqlite"] }

View File

@ -1,9 +1,8 @@
use std::{borrow::Cow, error::Error, fmt::Display, str::FromStr};
use std::{error::Error, fmt::Display, str::FromStr};
#[cfg(feature = "sqlx")]
use sqlx::Sqlite;
#[derive(PartialEq, Debug, Clone, Eq, Hash)]
#[derive(PartialEq, Debug, Clone, sqlx::Type, sqlx::Encode, Eq, Hash)]
pub struct JID {
// TODO: validate localpart (length, char]
pub localpart: Option<String>,
@ -11,36 +10,13 @@ pub struct JID {
pub resourcepart: Option<String>,
}
impl<'a> Into<Cow<'a, str>> for &'a JID {
fn into(self) -> Cow<'a, str> {
let a = self.to_string();
Cow::Owned(a)
}
}
impl Display for JID {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if let Some(localpart) = &self.localpart {
f.write_str(localpart)?;
f.write_str("@")?;
}
f.write_str(&self.domainpart)?;
if let Some(resourcepart) = &self.resourcepart {
f.write_str("/")?;
f.write_str(resourcepart)?;
}
Ok(())
}
}
#[cfg(feature = "sqlx")]
// TODO: feature gate
impl sqlx::Type<Sqlite> for JID {
fn type_info() -> <Sqlite as sqlx::Database>::TypeInfo {
<&str as sqlx::Type<Sqlite>>::type_info()
}
}
#[cfg(feature = "sqlx")]
impl sqlx::Decode<'_, Sqlite> for JID {
fn decode(
value: <Sqlite as sqlx::Database>::ValueRef<'_>,
@ -51,7 +27,6 @@ impl sqlx::Decode<'_, Sqlite> for JID {
}
}
#[cfg(feature = "sqlx")]
impl sqlx::Encode<'_, Sqlite> for JID {
fn encode_by_ref(
&self,
@ -62,24 +37,12 @@ impl sqlx::Encode<'_, Sqlite> for JID {
}
}
#[derive(Debug, Clone)]
pub enum JIDError {
NoResourcePart,
ParseError(ParseError),
}
impl Display for JIDError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
JIDError::NoResourcePart => f.write_str("resourcepart missing"),
JIDError::ParseError(parse_error) => parse_error.fmt(f),
}
}
}
impl Error for JIDError {}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub enum ParseError {
Empty,
Malformed(String),
@ -184,6 +147,21 @@ impl TryFrom<&str> for JID {
}
}
impl std::fmt::Display for JID {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}{}{}",
self.localpart.clone().map(|l| l + "@").unwrap_or_default(),
self.domainpart,
self.resourcepart
.clone()
.map(|r| "/".to_owned() + &r)
.unwrap_or_default()
)
}
}
#[cfg(test)]
mod tests {
use super::*;

1
luz/.gitignore vendored
View File

@ -1,2 +1 @@
luz.db
.sqlx/

View File

@ -7,7 +7,7 @@ edition = "2021"
futures = "0.3.31"
jabber = { version = "0.1.0", path = "../jabber" }
peanuts = { version = "0.1.0", path = "../../peanuts" }
jid = { version = "0.1.0", path = "../jid", features = ["sqlx"] }
jid = { version = "0.1.0", path = "../jid" }
sqlx = { version = "0.8.3", features = ["sqlite", "runtime-tokio", "uuid"] }
stanza = { version = "0.1.0", path = "../stanza" }
tokio = "1.42.0"
@ -16,4 +16,3 @@ tokio-util = "0.7.13"
tracing = "0.1.41"
tracing-subscriber = "0.3.19"
uuid = { version = "1.13.1", features = ["v4"] }
thiserror = "2.0.11"

View File

@ -20,7 +20,7 @@ use write::{WriteControl, WriteControlHandle, WriteHandle, WriteMessage};
use crate::{
db::Db,
error::{Error, ReadError, WriteError},
error::{Error, Reason},
UpdateMessage,
};
@ -36,7 +36,7 @@ pub struct Supervisor {
tokio::task::JoinSet<()>,
mpsc::Sender<SupervisorCommand>,
WriteHandle,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
)>,
sender: mpsc::Sender<UpdateMessage>,
writer_handle: WriteControlHandle,
@ -62,7 +62,7 @@ pub enum State {
tokio::task::JoinSet<()>,
mpsc::Sender<SupervisorCommand>,
WriteHandle,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
),
),
}
@ -77,7 +77,7 @@ impl Supervisor {
JoinSet<()>,
mpsc::Sender<SupervisorCommand>,
WriteHandle,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
)>,
sender: mpsc::Sender<UpdateMessage>,
writer_handle: WriteControlHandle,
@ -180,7 +180,7 @@ impl Supervisor {
// if reconnection failure, respond to all current write messages with lost connection error. the received processes should complete themselves.
write_state.close();
while let Some(msg) = write_state.recv().await {
let _ = msg.respond_to.send(Err(WriteError::LostConnection));
let _ = msg.respond_to.send(Err(Reason::LostConnection));
}
// TODO: is this the correct error?
let _ = self.sender.send(UpdateMessage::Error(Error::LostConnection)).await;
@ -227,9 +227,9 @@ impl Supervisor {
Err(e) => {
// if reconnection failure, respond to all current write messages with lost connection error. the received processes should complete themselves.
write_recv.close();
let _ = write_msg.respond_to.send(Err(WriteError::LostConnection));
let _ = write_msg.respond_to.send(Err(Reason::LostConnection));
while let Some(msg) = write_recv.recv().await {
let _ = msg.respond_to.send(Err(WriteError::LostConnection));
let _ = msg.respond_to.send(Err(Reason::LostConnection));
}
// TODO: is this the correct error to send?
let _ = self.sender.send(UpdateMessage::Error(Error::LostConnection)).await;
@ -278,10 +278,10 @@ impl Supervisor {
// if reconnection failure, respond to all current messages with lost connection error.
write_receiver.close();
if let Some(msg) = retry_msg {
msg.respond_to.send(Err(WriteError::LostConnection));
msg.respond_to.send(Err(Reason::LostConnection));
}
while let Some(msg) = write_receiver.recv().await {
msg.respond_to.send(Err(WriteError::LostConnection));
msg.respond_to.send(Err(Reason::LostConnection));
}
// TODO: is this the correct error?
let _ = self.sender.send(UpdateMessage::Error(Error::LostConnection)).await;
@ -342,7 +342,7 @@ impl SupervisorHandle {
on_shutdown: oneshot::Sender<()>,
jid: Arc<Mutex<JID>>,
password: Arc<String>,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
) -> (WriteHandle, Self) {
let (command_sender, command_receiver) = mpsc::channel(20);
let (writer_error_sender, writer_error_receiver) = oneshot::channel();

View File

@ -18,13 +18,16 @@ use uuid::Uuid;
use crate::{
chat::{Body, Message},
db::Db,
error::{Error, IqError, MessageRecvError, PresenceError, ReadError, RosterError},
error::{Error, IqError, PresenceError, Reason, RecvMessageError},
presence::{Offline, Online, Presence, Show},
roster::Contact,
UpdateMessage,
};
use super::{write::WriteHandle, SupervisorCommand};
use super::{
write::{WriteHandle, WriteMessage},
SupervisorCommand,
};
pub struct Read {
control_receiver: mpsc::Receiver<ReadControl>,
@ -35,7 +38,7 @@ pub struct Read {
JoinSet<()>,
mpsc::Sender<SupervisorCommand>,
WriteHandle,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
)>,
db: Db,
update_sender: mpsc::Sender<UpdateMessage>,
@ -45,7 +48,7 @@ pub struct Read {
disconnecting: bool,
disconnect_timedout: oneshot::Receiver<()>,
// TODO: use proper stanza ids
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
}
impl Read {
@ -58,7 +61,7 @@ impl Read {
JoinSet<()>,
mpsc::Sender<SupervisorCommand>,
WriteHandle,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
)>,
db: Db,
update_sender: mpsc::Sender<UpdateMessage>,
@ -66,9 +69,9 @@ impl Read {
supervisor_control: mpsc::Sender<SupervisorCommand>,
write_handle: WriteHandle,
tasks: JoinSet<()>,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
) -> Self {
let (_send, recv) = oneshot::channel();
let (send, recv) = oneshot::channel();
Self {
control_receiver,
stream,
@ -159,7 +162,7 @@ impl Read {
// when it aborts, must clear iq map no matter what
let mut iqs = self.pending_iqs.lock().await;
for (_id, sender) in iqs.drain() {
let _ = sender.send(Err(ReadError::LostConnection));
let _ = sender.send(Err(Reason::LostConnection));
}
}
}
@ -175,7 +178,7 @@ async fn handle_stanza(
db: Db,
supervisor_control: mpsc::Sender<SupervisorCommand>,
write_handle: WriteHandle,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
) {
match stanza {
Stanza::Message(stanza_message) => {
@ -204,9 +207,7 @@ async fn handle_stanza(
if let Err(e) = result {
tracing::error!("messagecreate");
let _ = update_sender
.send(UpdateMessage::Error(Error::MessageRecv(
MessageRecvError::MessageHistory(e.into()),
)))
.send(UpdateMessage::Error(Error::CacheUpdate(e.into())))
.await;
}
let _ = update_sender
@ -214,8 +215,8 @@ async fn handle_stanza(
.await;
} else {
let _ = update_sender
.send(UpdateMessage::Error(Error::MessageRecv(
MessageRecvError::MissingFrom,
.send(UpdateMessage::Error(Error::RecvMessage(
RecvMessageError::MissingFrom,
)))
.await;
}
@ -228,16 +229,9 @@ async fn handle_stanza(
stanza::client::presence::PresenceType::Error => {
// TODO: is there any other information that should go with the error? also MUST have an error, otherwise it's a different error. maybe it shoulnd't be an option.
let _ = update_sender
.send(UpdateMessage::Error(Error::Presence(
// TODO: ughhhhhhhhhhhhh these stanza errors should probably just have an option, and custom display
PresenceError::StanzaError(
presence
.errors
.first()
.cloned()
.expect("error MUST have error"),
),
)))
.send(UpdateMessage::Error(Error::Presence(PresenceError::Error(
Reason::Stanza(presence.errors.first().cloned()),
))))
.await;
}
// should not happen (error to server)
@ -335,8 +329,8 @@ async fn handle_stanza(
let contact: Contact = item.into();
if let Err(e) = db.upsert_contact(contact.clone()).await {
let _ = update_sender
.send(UpdateMessage::Error(Error::Roster(
RosterError::Cache(e.into()),
.send(UpdateMessage::Error(Error::CacheUpdate(
e.into(),
)))
.await;
}
@ -387,7 +381,7 @@ pub enum ReadControl {
JoinSet<()>,
mpsc::Sender<SupervisorCommand>,
WriteHandle,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
)>,
),
}
@ -420,13 +414,13 @@ impl ReadControlHandle {
JoinSet<()>,
mpsc::Sender<SupervisorCommand>,
WriteHandle,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
)>,
db: Db,
sender: mpsc::Sender<UpdateMessage>,
supervisor_control: mpsc::Sender<SupervisorCommand>,
jabber_write: WriteHandle,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
) -> Self {
let (control_sender, control_receiver) = mpsc::channel(20);
@ -457,14 +451,14 @@ impl ReadControlHandle {
JoinSet<()>,
mpsc::Sender<SupervisorCommand>,
WriteHandle,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
)>,
db: Db,
sender: mpsc::Sender<UpdateMessage>,
supervisor_control: mpsc::Sender<SupervisorCommand>,
jabber_write: WriteHandle,
tasks: JoinSet<()>,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
) -> Self {
let (control_sender, control_receiver) = mpsc::channel(20);

View File

@ -7,7 +7,7 @@ use tokio::{
task::JoinHandle,
};
use crate::error::WriteError;
use crate::error::{Error, Reason};
// actor that receives jabber stanzas to write, and if there is an error, sends a message back to the supervisor then aborts, so the supervisor can spawn a new stream.
pub struct Write {
@ -17,10 +17,9 @@ pub struct Write {
on_crash: oneshot::Sender<(WriteMessage, mpsc::Receiver<WriteMessage>)>,
}
#[derive(Debug)]
pub struct WriteMessage {
pub stanza: Stanza,
pub respond_to: oneshot::Sender<Result<(), WriteError>>,
pub respond_to: oneshot::Sender<Result<(), Reason>>,
}
pub enum WriteControl {
@ -85,9 +84,9 @@ impl Write {
Err(e) => match &e {
peanuts::Error::ReadError(_error) => {
// if connection lost during disconnection, just send lost connection error to the write requests
let _ = msg.respond_to.send(Err(WriteError::LostConnection));
let _ = msg.respond_to.send(Err(Reason::LostConnection));
while let Some(msg) = self.stanza_receiver.recv().await {
let _ = msg.respond_to.send(Err(WriteError::LostConnection));
let _ = msg.respond_to.send(Err(Reason::LostConnection));
}
break;
}
@ -141,16 +140,16 @@ pub struct WriteHandle {
}
impl WriteHandle {
pub async fn write(&self, stanza: Stanza) -> Result<(), WriteError> {
pub async fn write(&self, stanza: Stanza) -> Result<(), Reason> {
let (send, recv) = oneshot::channel();
self.send(WriteMessage {
stanza,
respond_to: send,
})
.await
.map_err(|e| WriteError::Actor(e.into()))?;
.map_err(|_| Reason::ChannelSend)?;
// TODO: timeout
recv.await.map_err(|e| WriteError::Actor(e.into()))?
recv.await?
}
}

View File

@ -1,158 +1,138 @@
use std::sync::Arc;
use stanza::client::Stanza;
use thiserror::Error;
use tokio::sync::{mpsc::error::SendError, oneshot::error::RecvError};
use tokio::sync::oneshot::{self};
#[derive(Debug, Error, Clone)]
#[derive(Debug)]
pub enum Error {
#[error("already connected")]
AlreadyConnected,
// TODO: change to Connecting(ConnectingError)
#[error("connecting: {0}")]
Connecting(#[from] ConnectionError),
#[error("presence: {0}")]
Presence(#[from] PresenceError),
#[error("set status: {0}")]
SetStatus(#[from] StatusError),
// TODO: have different ones for get/update/set
#[error("roster: {0}")]
Roster(RosterError),
#[error("stream error: {0}")]
Stream(#[from] stanza::stream::Error),
#[error("message send error: {0}")]
MessageSend(MessageSendError),
#[error("message receive error: {0}")]
MessageRecv(MessageRecvError),
#[error("already disconnected")]
Connection(ConnectionError),
Presence(PresenceError),
SetStatus(Reason),
Roster(Reason),
Stream(stanza::stream::Error),
SendMessage(Reason),
RecvMessage(RecvMessageError),
AlreadyDisconnected,
#[error("lost connection")]
LostConnection,
// TODO: Display for Content
#[error("received unrecognized/unsupported content: {0:?}")]
// TODO: should all cache update errors include the context?
CacheUpdate(Reason),
UnrecognizedContent(peanuts::element::Content),
#[error("iq receive error: {0}")]
Iq(IqError),
#[error("disconnected")]
Disconnected,
Cloned,
}
#[derive(Debug, Error, Clone)]
pub enum MessageSendError {
#[error("could not add to message history: {0}")]
MessageHistory(#[from] DatabaseError),
// TODO: this is horrifying, maybe just use tracing to forward error events???
impl Clone for Error {
fn clone(&self) -> Self {
Error::Cloned
}
}
#[derive(Debug, Error, Clone)]
#[derive(Debug)]
pub enum PresenceError {
#[error("unsupported")]
Error(Reason),
Unsupported,
#[error("missing from")]
MissingFrom,
#[error("stanza error: {0}")]
StanzaError(#[from] stanza::client::error::Error),
}
#[derive(Debug, Error, Clone)]
// TODO: should probably have all iq query related errors here, including read, write, stanza error, etc.
#[derive(Debug)]
pub enum IqError {
#[error("no iq with id matching `{0}`")]
NoMatchingId(String),
}
#[derive(Debug, Error, Clone)]
pub enum MessageRecvError {
#[error("could not add to message history: {0}")]
MessageHistory(#[from] DatabaseError),
#[error("missing from")]
#[derive(Debug)]
pub enum RecvMessageError {
MissingFrom,
}
#[derive(Debug, Clone, Error)]
#[derive(Debug, Clone)]
pub enum ConnectionError {
#[error("connection failed: {0}")]
ConnectionFailed(#[from] jabber::Error),
#[error("failed roster retreival: {0}")]
RosterRetreival(#[from] RosterError),
#[error("failed to send available presence: {0}")]
SendPresence(#[from] WriteError),
#[error("cached status: {0}")]
StatusCacheError(#[from] DatabaseError),
ConnectionFailed(Reason),
RosterRetreival(Reason),
SendPresence(Reason),
NoCachedStatus(Reason),
}
#[derive(Debug, Error, Clone)]
pub enum RosterError {
#[error("cache: {0}")]
Cache(#[from] DatabaseError),
#[error("stream write: {0}")]
Write(#[from] WriteError),
// TODO: display for stanza, to show as xml, same for read error types.
#[error("unexpected reply: {0:?}")]
UnexpectedStanza(Stanza),
#[error("stream read: {0}")]
Read(#[from] ReadError),
#[error("stanza error: {0}")]
StanzaError(#[from] stanza::client::error::Error),
}
#[derive(Debug)]
pub struct RosterError(pub Reason);
#[derive(Debug, Error, Clone)]
#[error("database error: {0}")]
pub struct DatabaseError(Arc<sqlx::Error>);
impl From<sqlx::Error> for DatabaseError {
fn from(e: sqlx::Error) -> Self {
Self(Arc::new(e))
impl From<RosterError> for Error {
fn from(e: RosterError) -> Self {
Self::Roster(e.0)
}
}
#[derive(Debug, Error, Clone)]
pub enum StatusError {
#[error("cache: {0}")]
Cache(#[from] DatabaseError),
#[error("stream write: {0}")]
Write(#[from] WriteError),
impl From<RosterError> for ConnectionError {
fn from(e: RosterError) -> Self {
Self::RosterRetreival(e.0)
}
}
#[derive(Debug, Error, Clone)]
pub enum WriteError {
#[error("xml: {0}")]
XML(#[from] peanuts::Error),
#[error("lost connection")]
LostConnection,
// TODO: should this be in writeerror or separate?
#[error("actor: {0}")]
Actor(#[from] ActorError),
#[error("disconnected")]
Disconnected,
pub struct StatusError(pub Reason);
impl From<StatusError> for Error {
fn from(e: StatusError) -> Self {
Error::SetStatus(e.0)
}
}
// TODO: separate peanuts read and write error?
#[derive(Debug, Error, Clone)]
pub enum ReadError {
#[error("xml: {0}")]
XML(#[from] peanuts::Error),
#[error("lost connection")]
LostConnection,
impl From<StatusError> for ConnectionError {
fn from(e: StatusError) -> Self {
Self::SendPresence(e.0)
}
}
#[derive(Debug, Error, Clone)]
pub enum ActorError {
#[error("receive timed out")]
#[derive(Debug)]
pub enum Reason {
// TODO: organisastion of error into internal error thing
Timeout,
#[error("could not send message to actor, channel closed")]
Send,
#[error("could not receive message from actor, channel closed")]
Receive,
Stream(stanza::stream_error::Error),
Stanza(Option<stanza::client::error::Error>),
Jabber(jabber::Error),
XML(peanuts::Error),
SQL(sqlx::Error),
// JID(jid::ParseError),
LostConnection,
OneshotRecv(oneshot::error::RecvError),
UnexpectedStanza(Stanza),
Disconnected,
ChannelSend,
Cloned,
}
impl<T> From<SendError<T>> for ActorError {
fn from(_e: SendError<T>) -> Self {
Self::Send
// TODO: same here
impl Clone for Reason {
fn clone(&self) -> Self {
Reason::Cloned
}
}
impl From<RecvError> for ActorError {
fn from(_e: RecvError) -> Self {
Self::Receive
impl From<oneshot::error::RecvError> for Reason {
fn from(e: oneshot::error::RecvError) -> Reason {
Self::OneshotRecv(e)
}
}
impl From<peanuts::Error> for Reason {
fn from(e: peanuts::Error) -> Self {
Self::XML(e)
}
}
// impl From<jid::ParseError> for Reason {
// fn from(e: jid::ParseError) -> Self {
// Self::JID(e)
// }
// }
impl From<sqlx::Error> for Reason {
fn from(e: sqlx::Error) -> Self {
Self::SQL(e)
}
}
impl From<jabber::Error> for Reason {
fn from(e: jabber::Error) -> Self {
Self::Jabber(e)
}
}

View File

@ -7,8 +7,7 @@ use std::{
use chat::{Body, Chat, Message};
use connection::{write::WriteMessage, SupervisorSender};
use db::Db;
use error::{ConnectionError, DatabaseError, ReadError, RosterError, StatusError, WriteError};
use futures::{future::Fuse, FutureExt};
use error::{ConnectionError, Reason, RosterError, StatusError};
use jabber::JID;
use presence::{Offline, Online, Presence};
use roster::{Contact, ContactUpdate};
@ -21,7 +20,7 @@ use tokio::{
sync::{mpsc, oneshot, Mutex},
task::JoinSet,
};
use tracing::{debug, info};
use tracing::{debug, info, Instrument};
use user::User;
use uuid::Uuid;
@ -44,11 +43,11 @@ pub struct Luz {
// TODO: use a dyn passwordprovider trait to avoid storing password in memory
password: Arc<String>,
connected: Arc<Mutex<Option<(WriteHandle, SupervisorHandle)>>>,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
db: Db,
sender: mpsc::Sender<UpdateMessage>,
/// if connection was shut down due to e.g. server shutdown, supervisor must be able to mark client as disconnected
connection_supervisor_shutdown: Fuse<oneshot::Receiver<()>>,
connection_supervisor_shutdown: oneshot::Receiver<()>,
// TODO: will need to have an auto reconnect state as well (e.g. in case server shut down, to try and reconnect later)
// TODO: will grow forever at this point, maybe not required as tasks will naturally shut down anyway?
tasks: JoinSet<()>,
@ -61,7 +60,7 @@ impl Luz {
jid: Arc<Mutex<JID>>,
password: String,
connected: Arc<Mutex<Option<(WriteHandle, SupervisorHandle)>>>,
connection_supervisor_shutdown: Fuse<oneshot::Receiver<()>>,
connection_supervisor_shutdown: oneshot::Receiver<()>,
db: SqlitePool,
sender: mpsc::Sender<UpdateMessage>,
) -> Self {
@ -83,8 +82,9 @@ impl Luz {
loop {
let msg = tokio::select! {
// this is okay, as when created the supervisor (and connection) doesn't exist, but a bit messy
// THIS IS NOT OKAY LOLLLL - apparently fusing is the best option???
// THIS IS NOT OKAY LOLLLL
_ = &mut self.connection_supervisor_shutdown => {
info!("got this");
*self.connected.lock().await = None;
continue;
}
@ -130,7 +130,6 @@ impl Luz {
self.password.clone(),
self.pending_iqs.clone(),
);
let shutdown_recv = shutdown_recv.fuse();
self.connection_supervisor_shutdown = shutdown_recv;
// TODO: get roster and send initial presence
let (send, recv) = oneshot::channel();
@ -159,8 +158,8 @@ impl Luz {
let _ = self
.sender
.send(UpdateMessage::Error(
Error::Connecting(
ConnectionError::StatusCacheError(
Error::Connection(
ConnectionError::NoCachedStatus(
e.into(),
),
),
@ -170,11 +169,7 @@ impl Luz {
}
};
let (send, recv) = oneshot::channel();
CommandMessage::SendPresence(
None,
Presence::Online(online.clone()),
send,
)
CommandMessage::SetStatus(online.clone(), send)
.handle_online(
writer.clone(),
supervisor.sender(),
@ -202,13 +197,13 @@ impl Luz {
let _ = self
.sender
.send(UpdateMessage::Error(
Error::Connecting(e.into()),
Error::Connection(e.into()),
))
.await;
}
},
Err(e) => {
let _ = self.sender.send(UpdateMessage::Error(Error::Connecting(ConnectionError::SendPresence(WriteError::Actor(e.into()))))).await;
let _ = self.sender.send(UpdateMessage::Error(Error::Connection(ConnectionError::SendPresence(e.into())))).await;
}
}
}
@ -216,7 +211,7 @@ impl Luz {
let _ = self
.sender
.send(UpdateMessage::Error(
Error::Connecting(e.into()),
Error::Connection(e.into()),
))
.await;
}
@ -225,12 +220,8 @@ impl Luz {
Err(e) => {
let _ = self
.sender
.send(UpdateMessage::Error(Error::Connecting(
ConnectionError::RosterRetreival(
RosterError::Write(WriteError::Actor(
e.into(),
)),
),
.send(UpdateMessage::Error(Error::Connection(
ConnectionError::RosterRetreival(e.into()),
)))
.await;
}
@ -238,7 +229,7 @@ impl Luz {
}
Err(e) => {
let _ =
self.sender.send(UpdateMessage::Error(Error::Connecting(
self.sender.send(UpdateMessage::Error(Error::Connection(
ConnectionError::ConnectionFailed(e.into()),
)));
}
@ -246,7 +237,7 @@ impl Luz {
}
};
}
CommandMessage::Disconnect(offline) => {
CommandMessage::Disconnect(_offline) => {
match self.connected.lock().await.as_mut() {
None => {
let _ = self
@ -256,19 +247,15 @@ impl Luz {
}
mut c => {
// TODO: send unavailable presence
if let Some((write_handle, supervisor_handle)) = c.take() {
let offline_presence: stanza::client::presence::Presence =
offline.clone().into();
let stanza = Stanza::Presence(offline_presence);
// TODO: timeout and error check
write_handle.write(stanza).await;
if let Some((_write_handle, supervisor_handle)) = c.take() {
let _ = supervisor_handle.send(SupervisorCommand::Disconnect).await;
let _ = self.sender.send(UpdateMessage::Offline(offline)).await;
c = None;
} else {
unreachable!()
};
}
}
info!("lock released")
}
_ => {
match self.connected.lock().await.as_ref() {
@ -294,7 +281,7 @@ impl Luz {
impl CommandMessage {
pub async fn handle_offline(
self,
mut self,
jid: Arc<Mutex<JID>>,
db: Db,
update_sender: mpsc::Sender<UpdateMessage>,
@ -309,7 +296,7 @@ impl CommandMessage {
let _ = sender.send(Ok(roster));
}
Err(e) => {
let _ = sender.send(Err(RosterError::Cache(e.into())));
let _ = sender.send(Err(RosterError(e.into())));
}
}
}
@ -339,48 +326,45 @@ impl CommandMessage {
}
// TODO: offline queue to modify roster
CommandMessage::AddContact(jid, sender) => {
sender.send(Err(RosterError::Write(WriteError::Disconnected)));
sender.send(Err(Reason::Disconnected));
}
CommandMessage::BuddyRequest(jid, sender) => {
sender.send(Err(WriteError::Disconnected));
sender.send(Err(Reason::Disconnected));
}
CommandMessage::SubscriptionRequest(jid, sender) => {
sender.send(Err(WriteError::Disconnected));
sender.send(Err(Reason::Disconnected));
}
CommandMessage::AcceptBuddyRequest(jid, sender) => {
sender.send(Err(WriteError::Disconnected));
sender.send(Err(Reason::Disconnected));
}
CommandMessage::AcceptSubscriptionRequest(jid, sender) => {
sender.send(Err(WriteError::Disconnected));
sender.send(Err(Reason::Disconnected));
}
CommandMessage::UnsubscribeFromContact(jid, sender) => {
sender.send(Err(WriteError::Disconnected));
sender.send(Err(Reason::Disconnected));
}
CommandMessage::UnsubscribeContact(jid, sender) => {
sender.send(Err(WriteError::Disconnected));
sender.send(Err(Reason::Disconnected));
}
CommandMessage::UnfriendContact(jid, sender) => {
sender.send(Err(WriteError::Disconnected));
sender.send(Err(Reason::Disconnected));
}
CommandMessage::DeleteContact(jid, sender) => {
sender.send(Err(RosterError::Write(WriteError::Disconnected)));
sender.send(Err(Reason::Disconnected));
}
CommandMessage::UpdateContact(jid, contact_update, sender) => {
sender.send(Err(RosterError::Write(WriteError::Disconnected)));
sender.send(Err(Reason::Disconnected));
}
CommandMessage::SetStatus(online, sender) => {
let result = db
.upsert_cached_status(online)
.await
.map_err(|e| StatusError::Cache(e.into()));
.map_err(|e| StatusError(e.into()));
sender.send(result);
}
// TODO: offline message queue
CommandMessage::SendMessage(jid, body, sender) => {
sender.send(Err(WriteError::Disconnected));
}
CommandMessage::SendPresence(jid, presence, sender) => {
sender.send(Err(WriteError::Disconnected));
sender.send(Err(Reason::Disconnected));
}
}
}
@ -393,7 +377,7 @@ impl CommandMessage {
client_jid: Arc<Mutex<JID>>,
db: Db,
update_sender: mpsc::Sender<UpdateMessage>,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, ReadError>>>>>,
pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Reason>>>>>,
) {
match self {
CommandMessage::Connect => unreachable!(),
@ -435,12 +419,11 @@ impl CommandMessage {
Ok(Ok(())) => info!("roster request sent"),
Ok(Err(e)) => {
// TODO: log errors if fail to send
let _ = result_sender.send(Err(RosterError::Write(e.into())));
let _ = result_sender.send(Err(RosterError(e.into())));
return;
}
Err(e) => {
let _ = result_sender
.send(Err(RosterError::Write(WriteError::Actor(e.into()))));
let _ = result_sender.send(Err(RosterError(e.into())));
return;
}
};
@ -460,41 +443,23 @@ impl CommandMessage {
items.into_iter().map(|item| item.into()).collect();
if let Err(e) = db.replace_cached_roster(contacts.clone()).await {
update_sender
.send(UpdateMessage::Error(Error::Roster(RosterError::Cache(
e.into(),
))))
.send(UpdateMessage::Error(Error::CacheUpdate(e.into())))
.await;
};
result_sender.send(Ok(contacts));
return;
}
ref s @ Stanza::Iq(Iq {
from: _,
ref id,
to: _,
r#type,
lang: _,
query: _,
ref errors,
}) if *id == iq_id && r#type == IqType::Error => {
if let Some(error) = errors.first() {
result_sender.send(Err(RosterError::StanzaError(error.clone())));
} else {
result_sender.send(Err(RosterError::UnexpectedStanza(s.clone())));
}
return;
}
s => {
result_sender.send(Err(RosterError::UnexpectedStanza(s)));
result_sender.send(Err(RosterError(Reason::UnexpectedStanza(s))));
return;
}
},
Ok(Err(e)) => {
result_sender.send(Err(RosterError::Read(e)));
result_sender.send(Err(RosterError(e.into())));
return;
}
Err(e) => {
result_sender.send(Err(RosterError::Write(WriteError::Actor(e.into()))));
result_sender.send(Err(RosterError(e.into())));
return;
}
}
@ -555,8 +520,8 @@ impl CommandMessage {
}
// TODO: write_handle send helper function
let result = write_handle.write(set_stanza).await;
if let Err(e) = result {
sender.send(Err(RosterError::Write(e)));
if let Err(_) = result {
sender.send(result);
return;
}
let iq_result = recv.await;
@ -575,24 +540,24 @@ impl CommandMessage {
sender.send(Ok(()));
return;
}
ref s @ Stanza::Iq(Iq {
Stanza::Iq(Iq {
from: _,
ref id,
id,
to: _,
r#type,
lang: _,
query: _,
ref errors,
}) if *id == iq_id && r#type == IqType::Error => {
errors,
}) if id == iq_id && r#type == IqType::Error => {
if let Some(error) = errors.first() {
sender.send(Err(RosterError::StanzaError(error.clone())));
sender.send(Err(Reason::Stanza(Some(error.clone()))));
} else {
sender.send(Err(RosterError::UnexpectedStanza(s.clone())));
sender.send(Err(Reason::Stanza(None)));
}
return;
}
s => {
sender.send(Err(RosterError::UnexpectedStanza(s)));
sender.send(Err(Reason::UnexpectedStanza(s)));
return;
}
},
@ -602,7 +567,7 @@ impl CommandMessage {
}
},
Err(e) => {
sender.send(Err(RosterError::Write(WriteError::Actor(e.into()))));
sender.send(Err(e.into()));
return;
}
}
@ -800,8 +765,8 @@ impl CommandMessage {
pending_iqs.lock().await.insert(iq_id.clone(), send);
}
let result = write_handle.write(set_stanza).await;
if let Err(e) = result {
sender.send(Err(RosterError::Write(e)));
if let Err(_) = result {
sender.send(result);
return;
}
let iq_result = recv.await;
@ -820,24 +785,24 @@ impl CommandMessage {
sender.send(Ok(()));
return;
}
ref s @ Stanza::Iq(Iq {
Stanza::Iq(Iq {
from: _,
ref id,
id,
to: _,
r#type,
lang: _,
query: _,
ref errors,
}) if *id == iq_id && r#type == IqType::Error => {
errors,
}) if id == iq_id && r#type == IqType::Error => {
if let Some(error) = errors.first() {
sender.send(Err(RosterError::StanzaError(error.clone())));
sender.send(Err(Reason::Stanza(Some(error.clone()))));
} else {
sender.send(Err(RosterError::UnexpectedStanza(s.clone())));
sender.send(Err(Reason::Stanza(None)));
}
return;
}
s => {
sender.send(Err(RosterError::UnexpectedStanza(s)));
sender.send(Err(Reason::UnexpectedStanza(s)));
return;
}
},
@ -847,7 +812,7 @@ impl CommandMessage {
}
},
Err(e) => {
sender.send(Err(RosterError::Write(WriteError::Actor(e.into()))));
sender.send(Err(e.into()));
return;
}
}
@ -888,8 +853,8 @@ impl CommandMessage {
pending_iqs.lock().await.insert(iq_id.clone(), send);
}
let result = write_handle.write(set_stanza).await;
if let Err(e) = result {
sender.send(Err(RosterError::Write(e)));
if let Err(_) = result {
sender.send(result);
return;
}
let iq_result = recv.await;
@ -908,24 +873,24 @@ impl CommandMessage {
sender.send(Ok(()));
return;
}
ref s @ Stanza::Iq(Iq {
Stanza::Iq(Iq {
from: _,
ref id,
id,
to: _,
r#type,
lang: _,
query: _,
ref errors,
}) if *id == iq_id && r#type == IqType::Error => {
errors,
}) if id == iq_id && r#type == IqType::Error => {
if let Some(error) = errors.first() {
sender.send(Err(RosterError::StanzaError(error.clone())));
sender.send(Err(Reason::Stanza(Some(error.clone()))));
} else {
sender.send(Err(RosterError::UnexpectedStanza(s.clone())));
sender.send(Err(Reason::Stanza(None)));
}
return;
}
s => {
sender.send(Err(RosterError::UnexpectedStanza(s)));
sender.send(Err(Reason::UnexpectedStanza(s)));
return;
}
},
@ -935,7 +900,7 @@ impl CommandMessage {
}
},
Err(e) => {
sender.send(Err(RosterError::Write(WriteError::Actor(e.into()))));
sender.send(Err(e.into()));
return;
}
}
@ -944,16 +909,13 @@ impl CommandMessage {
let result = db.upsert_cached_status(online.clone()).await;
if let Err(e) = result {
let _ = update_sender
.send(UpdateMessage::Error(Error::SetStatus(StatusError::Cache(
e.into(),
))))
.send(UpdateMessage::Error(Error::CacheUpdate(e.into())))
.await;
}
let result = write_handle
.write(Stanza::Presence(online.into()))
.await
.map_err(|e| StatusError::Write(e));
// .map_err(|e| StatusError::Write(e));
.map_err(|e| StatusError(e));
let _ = sender.send(result);
}
// TODO: offline message queue
@ -989,9 +951,7 @@ impl CommandMessage {
};
if let Err(e) = db.create_message(message, jid).await.map_err(|e| e.into())
{
let _ = update_sender.send(UpdateMessage::Error(Error::MessageSend(
error::MessageSendError::MessageHistory(e),
)));
let _ = update_sender.send(UpdateMessage::Error(Error::CacheUpdate(e)));
}
let _ = sender.send(Ok(()));
}
@ -1000,15 +960,6 @@ impl CommandMessage {
}
}
}
CommandMessage::SendPresence(jid, presence, sender) => {
let mut presence: stanza::client::presence::Presence = presence.into();
if let Some(jid) = jid {
presence.to = Some(jid);
};
let result = write_handle.write(Stanza::Presence(presence)).await;
// .map_err(|e| StatusError::Write(e));
let _ = sender.send(result);
}
}
}
}
@ -1043,18 +994,16 @@ impl DerefMut for LuzHandle {
}
impl LuzHandle {
// TODO: database creation separate
pub async fn new(
jid: JID,
password: String,
db: &str,
) -> Result<(Self, mpsc::Receiver<UpdateMessage>), DatabaseError> {
) -> Result<(Self, mpsc::Receiver<UpdateMessage>), Reason> {
let db = SqlitePool::connect(db).await?;
let (command_sender, command_receiver) = mpsc::channel(20);
let (update_sender, update_receiver) = mpsc::channel(20);
// might be bad, first supervisor shutdown notification oneshot is never used (disgusting)
let (sup_send, sup_recv) = oneshot::channel();
let mut sup_recv = sup_recv.fuse();
let actor = Luz::new(
command_sender.clone(),
@ -1075,59 +1024,8 @@ impl LuzHandle {
update_receiver,
))
}
pub async fn connect(&self) {
self.send(CommandMessage::Connect).await;
}
pub async fn disconnect(&self, offline: Offline) {
self.send(CommandMessage::Disconnect(offline)).await;
}
// pub async fn get_roster(&self) -> Result<Vec<Contact>, RosterError> {
// let (send, recv) = oneshot::channel();
// self.send(CommandMessage::GetRoster(send)).await.map_err(|e| RosterError::)?;
// Ok(recv.await?)
// }
// pub async fn get_chats(&self) -> Result<Vec<Chat>, Error> {}
// pub async fn get_chat(&self, jid: JID) -> Result<Chat, Error> {}
// pub async fn get_messages(&self, jid: JID) -> Result<Vec<Message>, Error> {}
// pub async fn delete_chat(&self, jid: JID) -> Result<(), Error> {}
// pub async fn delete_message(&self, id: Uuid) -> Result<(), Error> {}
// pub async fn get_user(&self, jid: JID) -> Result<User, Error> {}
// pub async fn add_contact(&self, jid: JID) -> Result<(), Error> {}
// pub async fn buddy_request(&self, jid: JID) -> Result<(), Error> {}
// pub async fn subscription_request(&self, jid: JID) -> Result<(), Error> {}
// pub async fn accept_buddy_request(&self, jid: JID) -> Result<(), Error> {}
// pub async fn accept_subscription_request(&self, jid: JID) -> Result<(), Error> {}
// pub async fn unsubscribe_from_contact(&self, jid: JID) -> Result<(), Error> {}
// pub async fn unsubscribe_contact(&self, jid: JID) -> Result<(), Error> {}
// pub async fn unfriend_contact(&self, jid: JID) -> Result<(), Error> {}
// pub async fn delete_contact(&self, jid: JID) -> Result<(), Error> {}
// pub async fn update_contact(&self, jid: JID, update: ContactUpdate) -> Result<(), Error> {}
// pub async fn set_status(&self, online: Online) -> Result<(), Error> {}
// pub async fn send_message(&self, jid: JID, body: Body) -> Result<(), Error> {}
}
// TODO: generate methods for each with a macro
pub enum CommandMessage {
// TODO: login invisible xep-0186
/// connect to XMPP chat server. gets roster and publishes initial presence.
@ -1138,51 +1036,46 @@ pub enum CommandMessage {
GetRoster(oneshot::Sender<Result<Vec<Contact>, RosterError>>),
/// get all chats. chat will include 10 messages in their message Vec (enough for chat previews)
// TODO: paging and filtering
GetChats(oneshot::Sender<Result<Vec<Chat>, DatabaseError>>),
GetChats(oneshot::Sender<Result<Vec<Chat>, Reason>>),
/// get a specific chat by jid
GetChat(JID, oneshot::Sender<Result<Chat, DatabaseError>>),
GetChat(JID, oneshot::Sender<Result<Chat, Reason>>),
/// get message history for chat (does appropriate mam things)
// TODO: paging and filtering
GetMessages(JID, oneshot::Sender<Result<Vec<Message>, DatabaseError>>),
GetMessages(JID, oneshot::Sender<Result<Vec<Message>, Reason>>),
/// delete a chat from your chat history, along with all the corresponding messages
DeleteChat(JID, oneshot::Sender<Result<(), DatabaseError>>),
DeleteChat(JID, oneshot::Sender<Result<(), Reason>>),
/// delete a message from your chat history
DeleteMessage(Uuid, oneshot::Sender<Result<(), DatabaseError>>),
DeleteMessage(Uuid, oneshot::Sender<Result<(), Reason>>),
/// get a user from your users database
GetUser(JID, oneshot::Sender<Result<User, DatabaseError>>),
GetUser(JID, oneshot::Sender<Result<User, Reason>>),
/// add a contact to your roster, with a status of none, no subscriptions.
AddContact(JID, oneshot::Sender<Result<(), RosterError>>),
// TODO: for all these, consider returning with oneshot::Sender<Result<(), Error>>
AddContact(JID, oneshot::Sender<Result<(), Reason>>),
/// send a friend request i.e. a subscription request with a subscription pre-approval. if not already added to roster server adds to roster.
BuddyRequest(JID, oneshot::Sender<Result<(), WriteError>>),
BuddyRequest(JID, oneshot::Sender<Result<(), Reason>>),
/// send a subscription request, without pre-approval. if not already added to roster server adds to roster.
SubscriptionRequest(JID, oneshot::Sender<Result<(), WriteError>>),
SubscriptionRequest(JID, oneshot::Sender<Result<(), Reason>>),
/// accept a friend request by accepting a pending subscription and sending a subscription request back. if not already added to roster adds to roster.
AcceptBuddyRequest(JID, oneshot::Sender<Result<(), WriteError>>),
AcceptBuddyRequest(JID, oneshot::Sender<Result<(), Reason>>),
/// accept a pending subscription and doesn't send a subscription request back. if not already added to roster adds to roster.
AcceptSubscriptionRequest(JID, oneshot::Sender<Result<(), WriteError>>),
AcceptSubscriptionRequest(JID, oneshot::Sender<Result<(), Reason>>),
/// unsubscribe to a contact, but don't remove their subscription.
UnsubscribeFromContact(JID, oneshot::Sender<Result<(), WriteError>>),
UnsubscribeFromContact(JID, oneshot::Sender<Result<(), Reason>>),
/// stop a contact from being subscribed, but stay subscribed to the contact.
UnsubscribeContact(JID, oneshot::Sender<Result<(), WriteError>>),
UnsubscribeContact(JID, oneshot::Sender<Result<(), Reason>>),
/// remove subscriptions to and from contact, but keep in roster.
UnfriendContact(JID, oneshot::Sender<Result<(), WriteError>>),
UnfriendContact(JID, oneshot::Sender<Result<(), Reason>>),
/// remove a contact from the contact list. will remove subscriptions if not already done then delete contact from roster.
DeleteContact(JID, oneshot::Sender<Result<(), RosterError>>),
DeleteContact(JID, oneshot::Sender<Result<(), Reason>>),
/// update contact. contact details will be overwritten with the contents of the contactupdate struct.
UpdateContact(JID, ContactUpdate, oneshot::Sender<Result<(), RosterError>>),
UpdateContact(JID, ContactUpdate, oneshot::Sender<Result<(), Reason>>),
/// set online status. if disconnected, will be cached so when client connects, will be sent as the initial presence.
SetStatus(Online, oneshot::Sender<Result<(), StatusError>>),
/// send presence stanza
SendPresence(
Option<JID>,
Presence,
oneshot::Sender<Result<(), WriteError>>,
),
/// send a directed presence (usually to a non-contact).
// TODO: should probably make it so people can add non-contact auto presence sharing in the client (most likely through setting an internal setting)
/// send a message to a jid (any kind of jid that can receive a message, e.g. a user or a
/// chatroom). if disconnected, will be cached so when client connects, message will be sent.
SendMessage(JID, Body, oneshot::Sender<Result<(), WriteError>>),
SendMessage(JID, Body, oneshot::Sender<Result<(), Reason>>),
}
#[derive(Debug, Clone)]

View File

@ -12,13 +12,9 @@ use tracing::info;
#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();
let (luz, mut recv) = LuzHandle::new(
"test@blos.sm".try_into().unwrap(),
"slayed".to_string(),
"./luz.db",
)
.await
.unwrap();
let db = SqlitePool::connect("./luz.db").await.unwrap();
let (luz, mut recv) =
LuzHandle::new("test@blos.sm".try_into().unwrap(), "slayed".to_string(), db);
tokio::spawn(async move {
while let Some(msg) = recv.recv().await {

View File

@ -91,31 +91,3 @@ impl From<Online> for stanza::client::presence::Presence {
}
}
}
impl From<Offline> for stanza::client::presence::Presence {
fn from(value: Offline) -> Self {
Self {
from: None,
id: None,
to: None,
r#type: Some(stanza::client::presence::PresenceType::Unavailable),
lang: None,
show: None,
status: value.status.map(|status| stanza::client::presence::Status {
lang: None,
status: String1024(status),
}),
priority: None,
errors: Vec::new(),
}
}
}
impl From<Presence> for stanza::client::presence::Presence {
fn from(value: Presence) -> Self {
match value {
Presence::Online(online) => online.into(),
Presence::Offline(offline) => offline.into(),
}
}
}

View File

@ -6,4 +6,3 @@ edition = "2021"
[dependencies]
peanuts = { version = "0.1.0", path = "../../peanuts" }
jid = { version = "0.1.0", path = "../jid" }
thiserror = "2.0.11"

View File

@ -13,8 +13,8 @@ pub struct Bind {
impl FromElement for Bind {
fn from_element(mut element: peanuts::Element) -> peanuts::element::DeserializeResult<Self> {
element.check_name("bind")?;
element.check_name(XMLNS)?;
element.check_name("bind");
element.check_name(XMLNS);
let r#type = element.pop_child_opt()?;
@ -61,8 +61,8 @@ pub struct FullJidType(pub JID);
impl FromElement for FullJidType {
fn from_element(mut element: peanuts::Element) -> peanuts::element::DeserializeResult<Self> {
element.check_name("jid")?;
element.check_namespace(XMLNS)?;
element.check_name("jid");
element.check_namespace(XMLNS);
let jid = element.pop_value()?;

View File

@ -1,16 +1,14 @@
use std::fmt::Display;
use std::str::FromStr;
use peanuts::element::{FromElement, IntoElement};
use peanuts::{DeserializeError, Element};
use thiserror::Error;
use crate::stanza_error::Error as StanzaError;
use crate::stanza_error::Text;
use super::XMLNS;
#[derive(Clone, Debug, Error)]
#[derive(Clone, Debug)]
pub struct Error {
by: Option<String>,
r#type: ErrorType,
@ -19,22 +17,6 @@ pub struct Error {
text: Option<Text>,
}
impl Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}, {}", self.r#type, self.error)?;
if let Some(text) = &self.text {
if let Some(text) = &text.text {
write!(f, ": {}", text)?;
}
}
if let Some(by) = &self.by {
write!(f, " (error returned by `{}`)", by)?;
}
Ok(())
}
}
impl FromElement for Error {
fn from_element(mut element: peanuts::Element) -> peanuts::element::DeserializeResult<Self> {
element.check_name("error")?;
@ -73,18 +55,6 @@ pub enum ErrorType {
Wait,
}
impl Display for ErrorType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ErrorType::Auth => f.write_str("auth"),
ErrorType::Cancel => f.write_str("cancel"),
ErrorType::Continue => f.write_str("continue"),
ErrorType::Modify => f.write_str("modify"),
ErrorType::Wait => f.write_str("wait"),
}
}
}
impl FromStr for ErrorType {
type Err = DeserializeError;
@ -99,3 +69,15 @@ impl FromStr for ErrorType {
}
}
}
impl ToString for ErrorType {
fn to_string(&self) -> String {
match self {
ErrorType::Auth => "auth".to_string(),
ErrorType::Cancel => "cancel".to_string(),
ErrorType::Continue => "continue".to_string(),
ErrorType::Modify => "modify".to_string(),
ErrorType::Wait => "wait".to_string(),
}
}
}

View File

@ -15,7 +15,7 @@ use crate::{
use super::XMLNS;
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct Iq {
pub from: Option<JID>,
pub id: String,

View File

@ -8,7 +8,7 @@ use peanuts::{
use super::XMLNS;
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct Message {
pub from: Option<JID>,
pub id: Option<String>,

View File

@ -1,7 +1,7 @@
use iq::Iq;
use message::Message;
use peanuts::{
element::{Content, ContentBuilder, FromContent, FromElement, IntoContent},
element::{Content, ContentBuilder, FromContent, FromElement, IntoContent, IntoElement},
DeserializeError,
};
use presence::Presence;
@ -16,7 +16,7 @@ pub mod presence;
pub const XMLNS: &str = "jabber:client";
/// TODO: End tag
#[derive(Debug, Clone)]
#[derive(Debug)]
pub enum Stanza {
Message(Message),
Presence(Presence),

View File

@ -8,7 +8,7 @@ use peanuts::{
use super::{error::Error, XMLNS};
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct Presence {
pub from: Option<JID>,
pub id: Option<String>,

View File

@ -1,10 +1,9 @@
use std::{fmt::Display, ops::Deref};
use std::ops::Deref;
use peanuts::{
element::{FromElement, IntoElement},
DeserializeError, Element,
};
use thiserror::Error;
pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
@ -169,48 +168,12 @@ impl IntoElement for Response {
}
}
#[derive(Error, Debug, Clone)]
#[derive(Debug)]
pub struct Failure {
r#type: Option<FailureType>,
text: Option<Text>,
}
impl Display for Failure {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut had_type = false;
let mut had_text = false;
if let Some(r#type) = &self.r#type {
had_type = true;
match r#type {
FailureType::Aborted => f.write_str("aborted"),
FailureType::AccountDisabled => f.write_str("account disabled"),
FailureType::CredentialsExpired => f.write_str("credentials expired"),
FailureType::EncryptionRequired => f.write_str("encryption required"),
FailureType::IncorrectEncoding => f.write_str("incorrect encoding"),
FailureType::InvalidAuthzid => f.write_str("invalid authzid"),
FailureType::InvalidMechanism => f.write_str("invalid mechanism"),
FailureType::MalformedRequest => f.write_str("malformed request"),
FailureType::MechanismTooWeak => f.write_str("mechanism too weak"),
FailureType::NotAuthorized => f.write_str("not authorized"),
FailureType::TemporaryAuthFailure => f.write_str("temporary auth failure"),
}?;
}
if let Some(text) = &self.text {
if let Some(text) = &text.text {
if had_type {
f.write_str(": ")?;
}
f.write_str(text)?;
had_text = true;
}
}
if !had_type && !had_text {
f.write_str("failure")?;
}
Ok(())
}
}
impl FromElement for Failure {
fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> {
element.check_name("failure")?;
@ -223,29 +186,18 @@ impl FromElement for Failure {
}
}
#[derive(Error, Debug, Clone)]
#[derive(Debug)]
pub enum FailureType {
#[error("aborted")]
Aborted,
#[error("account disabled")]
AccountDisabled,
#[error("credentials expired")]
CredentialsExpired,
#[error("encryption required")]
EncryptionRequired,
#[error("incorrect encoding")]
IncorrectEncoding,
#[error("invalid authzid")]
InvalidAuthzid,
#[error("invalid mechanism")]
InvalidMechanism,
#[error("malformed request")]
MalformedRequest,
#[error("mechanism too weak")]
MechanismTooWeak,
#[error("not authorized")]
NotAuthorized,
#[error("temporary auth failure")]
TemporaryAuthFailure,
}
@ -268,9 +220,8 @@ impl FromElement for FailureType {
}
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct Text {
#[allow(dead_code)]
lang: Option<String>,
text: Option<String>,
}

View File

@ -4,55 +4,32 @@ use peanuts::{
element::{FromElement, IntoElement},
Element, XML_NS,
};
use thiserror::Error;
pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-stanzas";
#[derive(Error, Clone, Debug)]
#[derive(Clone, Debug)]
pub enum Error {
#[error("bad request")]
BadRequest,
#[error("conflict")]
Conflict,
#[error("feature not implemented")]
FeatureNotImplemented,
#[error("forbidden")]
Forbidden,
#[error("gone: {0:?}")]
Gone(Option<String>),
#[error("internal server error")]
InternalServerError,
#[error("item not found")]
ItemNotFound,
#[error("JID malformed")]
JIDMalformed,
#[error("not acceptable")]
JidMalformed,
NotAcceptable,
#[error("not allowed")]
NotAllowed,
#[error("not authorized")]
NotAuthorized,
#[error("policy violation")]
PolicyViolation,
#[error("recipient unavailable")]
RecipientUnavailable,
#[error("redirect: {0:?}")]
Redirect(Option<String>),
#[error("registration required")]
RegistrationRequired,
#[error("remote server not found")]
RemoteServerNotFound,
#[error("remote server timeout")]
RemoteServerTimeout,
#[error("resource constraint")]
ResourceConstraint,
#[error("service unavailable")]
ServiceUnavailable,
#[error("subscription required")]
SubscriptionRequired,
#[error("undefined condition")]
UndefinedCondition,
#[error("unexpected request")]
UnexpectedRequest,
}
@ -67,7 +44,7 @@ impl FromElement for Error {
(Some(XMLNS), "gone") => return Ok(Error::Gone(element.pop_value_opt()?)),
(Some(XMLNS), "internal-server-error") => error = Error::InternalServerError,
(Some(XMLNS), "item-not-found") => error = Error::ItemNotFound,
(Some(XMLNS), "jid-malformed") => error = Error::JIDMalformed,
(Some(XMLNS), "jid-malformed") => error = Error::JidMalformed,
(Some(XMLNS), "not-acceptable") => error = Error::NotAcceptable,
(Some(XMLNS), "not-allowed") => error = Error::NotAllowed,
(Some(XMLNS), "not-authorized") => error = Error::NotAuthorized,
@ -101,7 +78,7 @@ impl IntoElement for Error {
Error::Gone(r) => Element::builder("gone", Some(XMLNS)).push_text_opt(r.clone()),
Error::InternalServerError => Element::builder("internal-server-error", Some(XMLNS)),
Error::ItemNotFound => Element::builder("item-not-found", Some(XMLNS)),
Error::JIDMalformed => Element::builder("jid-malformed", Some(XMLNS)),
Error::JidMalformed => Element::builder("jid-malformed", Some(XMLNS)),
Error::NotAcceptable => Element::builder("not-acceptable", Some(XMLNS)),
Error::NotAllowed => Element::builder("not-allowed", Some(XMLNS)),
Error::NotAuthorized => Element::builder("not-authorized", Some(XMLNS)),
@ -125,7 +102,7 @@ impl IntoElement for Error {
#[derive(Clone, Debug)]
pub struct Text {
lang: Option<String>,
pub text: Option<String>,
text: Option<String>,
}
impl FromElement for Text {

View File

@ -1,5 +1,7 @@
use std::collections::{HashMap, HashSet};
use peanuts::{
element::{FromElement, IntoElement},
element::{Content, FromElement, IntoElement, Name, NamespaceDeclaration},
Element,
};

View File

@ -1,16 +1,15 @@
use std::fmt::Display;
use std::collections::{HashMap, HashSet};
use jid::JID;
use peanuts::element::{ElementBuilder, FromElement, IntoElement};
use peanuts::Element;
use thiserror::Error;
use peanuts::element::{Content, ElementBuilder, FromElement, IntoElement, NamespaceDeclaration};
use peanuts::{element::Name, Element};
use crate::bind;
use super::client;
use super::sasl::{self, Mechanisms};
use super::starttls::{self, StartTls};
use super::stream_error::{Error as StreamError, Text};
use super::{client, stream_error};
pub const XMLNS: &str = "http://etherx.jabber.org/streams";
@ -179,24 +178,12 @@ impl FromElement for Feature {
}
}
#[derive(Error, Debug, Clone)]
#[derive(Debug)]
pub struct Error {
error: StreamError,
text: Option<Text>,
}
impl Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.error)?;
if let Some(text) = &self.text {
if let Some(text) = &text.text {
write!(f, ": {}", text)?;
}
}
Ok(())
}
}
impl FromElement for Error {
fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> {
element.check_name("error")?;

View File

@ -2,61 +2,35 @@ use peanuts::{
element::{FromElement, IntoElement},
DeserializeError, Element, XML_NS,
};
use thiserror::Error;
pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-streams";
#[derive(Error, Clone, Debug)]
#[derive(Clone, Debug)]
pub enum Error {
#[error("bad format")]
BadFormat,
#[error("bad namespace prefix")]
BadNamespacePrefix,
#[error("conflict")]
Conflict,
#[error("connection timeout")]
ConnectionTimeout,
#[error("host gone")]
HostGone,
#[error("host unknown")]
HostUnknown,
#[error("improper addressing")]
ImproperAddressing,
#[error("internal server error")]
InternalServerError,
#[error("invalid from")]
InvalidFrom,
#[error("invalid id")]
InvalidId,
#[error("invalid namespace")]
InvalidNamespace,
#[error("invalid xml")]
InvalidXml,
#[error("not authorized")]
NotAuthorized,
#[error("not well formed")]
NotWellFormed,
#[error("policy violation")]
PolicyViolation,
#[error("remote connection failed")]
RemoteConnectionFailed,
#[error("reset")]
Reset,
#[error("resource constraint")]
ResourceConstraint,
#[error("restricted xml")]
RestrictedXml,
#[error("see other host: {0:?}")]
SeeOtherHost(Option<String>),
#[error("system shutdown")]
SystemShutdown,
#[error("undefined condition")]
UndefinedCondition,
#[error("unsupported encoding")]
UnsupportedEncoding,
#[error("unsupported stanza type")]
UnsupportedStanzaType,
#[error("unsupported version")]
UnsupportedVersion,
}
@ -138,7 +112,7 @@ impl IntoElement for Error {
#[derive(Clone, Debug)]
pub struct Text {
pub text: Option<String>,
text: Option<String>,
lang: Option<String>,
}