diff --git a/Cargo.lock b/Cargo.lock index 0ea208c..450693c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -266,6 +266,12 @@ version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4cbbc9d0964165b47557570cce6c952866c2678457aca742aafc9fb771d30270" +[[package]] +name = "base64" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + [[package]] name = "base64" version = "0.21.4" @@ -1279,6 +1285,21 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "jwt" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6204285f77fe7d9784db3fdc449ecce1a0114927a51d5a41c4c7a292011c015f" +dependencies = [ + "base64 0.13.1", + "crypto-common", + "digest", + "hmac", + "serde", + "serde_json", + "sha2", +] + [[package]] name = "keccak" version = "0.1.4" @@ -2311,10 +2332,13 @@ dependencies = [ "dotenv", "env_logger", "futures", + "http", + "jwt", "libipld-core", "lingua", "log", "once_cell", + "reqwest", "rs-car", "scooby", "serde", @@ -2757,11 +2781,11 @@ checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" [[package]] name = "reqwest" -version = "0.11.20" +version = "0.11.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e9ad3fe7488d7e34558a2033d45a0c90b72d97b4f80705666fea71472e2e6a1" +checksum = "046cd98826c46c2ac8ddecae268eb5c2e58628688a5fc7a2643704a73faba95b" dependencies = [ - "base64", + "base64 0.21.4", "bytes", "encoding_rs", "futures-core", @@ -2782,6 +2806,7 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", + "system-configuration", "tokio", "tokio-native-tls", "tower-service", @@ -3235,7 +3260,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "864b869fdf56263f4c95c45483191ea0af340f9f3e3e7b4d57a61c7c87a970db" dependencies = [ "atoi", - "base64", + "base64 0.21.4", "bitflags 2.4.0", "byteorder", "bytes", @@ -3278,7 +3303,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eb7ae0e6a97fb3ba33b23ac2671a5ce6e3cabe003f451abd5a56e7951d975624" dependencies = [ "atoi", - "base64", + "base64 0.21.4", "bitflags 2.4.0", "byteorder", "chrono", @@ -3422,6 +3447,27 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "system-configuration" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "tempfile" version = "3.8.0" diff --git a/Cargo.toml b/Cargo.toml index cbe9610..12fa115 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,10 +18,13 @@ clap = { version = "4.4.6", features = ["derive"] } dotenv = "0.15.0" env_logger = "0.10.0" futures = "0.3.28" +http = "0.2.9" +jwt = "0.16.0" libipld-core = { version = "0.16.0", features = ["serde-codec"] } lingua = "1.5.0" log = "0.4.20" once_cell = "1.18.0" +reqwest = "0.11.22" rs-car = "0.4.1" scooby = "0.5.0" serde = "1.0.188" diff --git a/README.md b/README.md index 0c88f4e..af193fb 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ Not fully complete yet, see roadmap. - [x] Handle missing profiles in the profile classifier - [x] Add a way to mark a profile as being from a certain country manually - [x] Handle reconnecting to websocket somehow -- [ ] Publish the feed +- [x] Publish the feed ## Configuration diff --git a/src/bin/force_profile_country.rs b/src/bin/force_profile_country.rs index 3f41d21..e7aec15 100644 --- a/src/bin/force_profile_country.rs +++ b/src/bin/force_profile_country.rs @@ -28,7 +28,7 @@ async fn main() -> Result<()> { let database_url = env::var("DATABASE_URL").context("DATABASE_URL environment variable must be set")?; - let bluesky = Bluesky::new("https://bsky.social"); + let bluesky = Bluesky::unauthenticated("https://bsky.social"); let database = Database::connect(&database_url).await?; for handle in &args.handle { diff --git a/src/bin/publish_feed.rs b/src/bin/publish_feed.rs index 1d4bcad..d339e41 100644 --- a/src/bin/publish_feed.rs +++ b/src/bin/publish_feed.rs @@ -41,19 +41,21 @@ async fn main() -> Result<()> { let feed_generator_did = format!("did:web:{}", env::var("FEED_GENERATOR_HOSTNAME")?); - let bluesky = Bluesky::new("https://bsky.social"); + println!("Logging in"); - let session = bluesky.login(&handle, &password).await?; + let bluesky = Bluesky::login("https://bsky.social", &handle, &password).await?; let mut avatar = None; if let Some(path) = args.avatar_filename { let bytes = std::fs::read(path)?; avatar = Some(bluesky.upload_blob(bytes).await?); + println!("Uploaded avatar"); } + bluesky .publish_feed( - &session.did, + &bluesky.session().unwrap().did, &feed_generator_did, &args.name, &args.display_name, diff --git a/src/bin/who_am_i.rs b/src/bin/who_am_i.rs index a4a6176..14f856c 100644 --- a/src/bin/who_am_i.rs +++ b/src/bin/who_am_i.rs @@ -2,7 +2,7 @@ extern crate nederlandskie; use std::env; -use anyhow::{Context, Result}; +use anyhow::{Context, Result, anyhow}; use dotenv::dotenv; use nederlandskie::services::Bluesky; @@ -11,15 +11,14 @@ use nederlandskie::services::Bluesky; async fn main() -> Result<()> { dotenv()?; - let bluesky = Bluesky::new("https://bsky.social"); - let handle = env::var("PUBLISHER_BLUESKY_HANDLE") .context("PUBLISHER_BLUESKY_HANDLE environment variable must be set")?; let password = env::var("PUBLISHER_BLUESKY_PASSWORD") .context("PUBLISHER_BLUESKY_PASSWORD environment variable must be set")?; - let session = bluesky.login(&handle, &password).await?; + let bluesky = Bluesky::login("https://bsky.social", &handle, &password).await?; + let session = bluesky.session().ok_or_else(|| anyhow!("Could not log in"))?; println!("{}", session.did); diff --git a/src/main.rs b/src/main.rs index 23c233d..44be1ad 100644 --- a/src/main.rs +++ b/src/main.rs @@ -23,7 +23,7 @@ async fn main() -> Result<()> { info!("Initializing service clients"); let ai = Arc::new(AI::new(&config.chat_gpt_api_key, "https://api.openai.com")); - let bluesky = Arc::new(Bluesky::new("https://bsky.social")); + let bluesky = Arc::new(Bluesky::unauthenticated("https://bsky.social")); let database = Arc::new(Database::connect(&config.database_url).await?); info!("Initializing language detector"); diff --git a/src/services/bluesky.rs b/src/services/bluesky.rs index e8ea666..1abd439 100644 --- a/src/services/bluesky.rs +++ b/src/services/bluesky.rs @@ -1,7 +1,9 @@ mod client; mod decode; mod proto; +mod session; mod streaming; +mod xrpc_client; pub use client::Bluesky; pub use streaming::{CommitDetails, CommitProcessor, Operation}; diff --git a/src/services/bluesky/client.rs b/src/services/bluesky/client.rs index 89bec1d..6ed3ad6 100644 --- a/src/services/bluesky/client.rs +++ b/src/services/bluesky/client.rs @@ -1,18 +1,21 @@ use std::matches; +use std::sync::Arc; +use std::sync::Mutex; use anyhow::{anyhow, Result}; use atrium_api::blob::BlobRef; use atrium_api::client::AtpServiceClient; use atrium_api::client::AtpServiceWrapper; use atrium_api::records::Record; -use atrium_xrpc::client::reqwest::ReqwestClient; use axum::http::StatusCode; use chrono::Utc; use futures::StreamExt; use log::error; use tokio_tungstenite::{connect_async, tungstenite}; +use super::session::Session; use super::streaming::{handle_message, CommitProcessor}; +use super::xrpc_client::AuthenticateableXrpcClient; #[derive(Debug)] pub struct ProfileDetails { @@ -20,27 +23,25 @@ pub struct ProfileDetails { pub description: String, } -#[derive(Debug)] -pub struct SessionDetails { - pub did: String, -} - pub struct Bluesky { - client: AtpServiceClient>, + client: AtpServiceClient>, + session: Option>> } impl Bluesky { - pub fn new(host: &str) -> Self { + pub fn unauthenticated(host: &str) -> Self { Self { - client: AtpServiceClient::new(ReqwestClient::new(host.to_owned())), + client: AtpServiceClient::new(AuthenticateableXrpcClient::new(host.to_owned())), + session: None } } - pub async fn login(&self, handle: &str, password: &str) -> Result { + pub async fn login(host: &str, handle: &str, password: &str) -> Result { use atrium_api::com::atproto::server::create_session::Input; - let result = self - .client + let client = AtpServiceClient::new(AuthenticateableXrpcClient::new(host.to_owned())); + + let result = client .service .com .atproto @@ -51,10 +52,26 @@ impl Bluesky { }) .await?; - Ok(SessionDetails { did: result.did }) + let session = Arc::new(Mutex::new(result.try_into()?)); + + let authenticated_client = AtpServiceClient::new(AuthenticateableXrpcClient::with_session( + host.to_owned(), + session.clone() + )); + + Ok(Self { + client: authenticated_client, + session: Some(session) + }) + } + + pub fn session(&self) -> Option { + self.session.as_ref().and_then(|s| s.lock().ok()).map(|s| s.clone()) } pub async fn upload_blob(&self, blob: Vec) -> Result { + self.ensure_token_valid().await?; + let result = self .client .service @@ -78,6 +95,8 @@ impl Bluesky { ) -> Result<()> { use atrium_api::com::atproto::repo::put_record::Input; + self.ensure_token_valid().await?; + self.client .service .com @@ -88,7 +107,7 @@ impl Bluesky { record: Record::AppBskyFeedGenerator(Box::new( atrium_api::app::bsky::feed::generator::Record { avatar, - created_at: Utc::now().to_string(), + created_at: Utc::now().to_rfc3339(), description: Some(description.to_owned()), description_facets: None, did: feed_generator_did.to_owned(), @@ -183,6 +202,27 @@ impl Bluesky { Ok(()) } + + async fn ensure_token_valid(&self) -> Result<()> { + let access_jwt_exp = + self.session.as_ref().ok_or_else(|| anyhow!("Not authenticated"))?.lock().map_err(|e| anyhow!("session mutex is poisoned: {e}"))?.access_jwt_exp; + + let jwt_expired = Utc::now() > access_jwt_exp; + + if jwt_expired { + let refreshed = self.client.service.com.atproto.server.refresh_session().await?; + + let mut session = self.session + .as_ref() + .ok_or_else(|| anyhow!("Not authenticated"))? + .lock() + .map_err(|e| anyhow!("session mutex is poisoned: {e}"))?; + + *session = refreshed.try_into()?; + } + + Ok(()) + } } fn is_missing_record_error(error: &atrium_xrpc::error::Error) -> bool { diff --git a/src/services/bluesky/session.rs b/src/services/bluesky/session.rs new file mode 100644 index 0000000..1d3aa7d --- /dev/null +++ b/src/services/bluesky/session.rs @@ -0,0 +1,57 @@ +use anyhow::{anyhow, Result}; +use atrium_api::com::atproto::server::create_session::Output as CreateSessionOutput; +use atrium_api::com::atproto::server::refresh_session::Output as RefreshSessionOutput; +use chrono::{DateTime, TimeZone, Utc}; +use jwt::{Header, Token}; +use serde::Deserialize; + +#[derive(Clone, Debug)] +pub struct Session { + pub access_jwt: String, + pub access_jwt_exp: DateTime, + pub refresh_jwt: String, + pub did: String, +} + +#[derive(Deserialize)] +struct AtprotoClaims { + exp: i64, +} + +pub fn get_token_expiration(jwt_string: &str) -> Result> { + let token: Token = Token::parse_unverified(jwt_string)?; + let expiration_time = Utc + .timestamp_millis_opt(token.claims().exp) + .earliest() + .ok_or_else(|| anyhow!("couldn't interpret expiration timestamp"))?; + + Ok(expiration_time) +} + +impl TryInto for CreateSessionOutput { + type Error = anyhow::Error; + + fn try_into(self) -> Result { + let access_jwt_exp = get_token_expiration(&self.access_jwt)?; + Ok(Session { + access_jwt: self.access_jwt, + access_jwt_exp, + refresh_jwt: self.refresh_jwt, + did: self.did, + }) + } +} + +impl TryInto for RefreshSessionOutput { + type Error = anyhow::Error; + + fn try_into(self) -> Result { + let access_jwt_exp = get_token_expiration(&self.access_jwt)?; + Ok(Session { + access_jwt: self.access_jwt, + access_jwt_exp, + refresh_jwt: self.refresh_jwt, + did: self.did, + }) + } +} diff --git a/src/services/bluesky/xrpc_client.rs b/src/services/bluesky/xrpc_client.rs new file mode 100644 index 0000000..cdcf42b --- /dev/null +++ b/src/services/bluesky/xrpc_client.rs @@ -0,0 +1,66 @@ +use async_trait::async_trait; +use atrium_xrpc::{client::reqwest::ReqwestClient, HttpClient, XrpcClient}; +use http::{Request, Response, Method}; +use std::sync::{Arc, Mutex}; + +use super::session::Session; + +pub struct AuthenticateableXrpcClient { + inner: ReqwestClient, + session: Option>>, +} + +impl AuthenticateableXrpcClient { + pub fn new(host: String) -> Self { + Self { + inner: ReqwestClient::new(host), + session: None, + } + } + + pub fn with_session(host: String, session: Arc>) -> Self { + Self { + inner: ReqwestClient::new(host), + session: Some(session), + } + } +} + +#[async_trait] +impl HttpClient for AuthenticateableXrpcClient { + async fn send_http( + &self, + req: Request>, + ) -> Result>, Box> { + let (mut parts, body) = req.into_parts(); + + /* NOTE: This is a huge hack because auth is currently totally broken in atrium-api */ + let is_request_to_refresh_session = parts.method == Method::POST && parts.uri.to_string().ends_with("com.atproto.server.refreshSession"); + if let Some(token) = self.auth(is_request_to_refresh_session) { + parts.headers.insert(http::header::AUTHORIZATION, format!("Bearer {}", token).parse()?); + } + + let req = Request::from_parts(parts, body); + + self.inner.send_http(req).await + } +} + +impl XrpcClient for AuthenticateableXrpcClient { + fn auth(&self, is_refresh: bool) -> Option { + self.session + .as_ref() + .and_then(|session| session.lock().ok()) + .map(|session| { + if is_refresh { + session.refresh_jwt.clone() + } else { + session.access_jwt.clone() + } + }) + } + + fn host(&self) -> &str { + self.inner.host() + } +}