diff --git a/src/ai.rs b/src/ai.rs deleted file mode 100644 index baea564..0000000 --- a/src/ai.rs +++ /dev/null @@ -1,38 +0,0 @@ -use anyhow::Result; -use chat_gpt_lib_rs::{ChatGPTClient, ChatInput, Message, Model, Role}; - -pub type AI = ChatGPTClient; - -pub fn make_ai_client() -> AI { - // TODO: Take key from env vars - let api_key = "fake-api-key"; - let base_url = "https://api.openai.com"; - return ChatGPTClient::new(api_key, base_url); -} - -pub async fn infer_country_of_living( - ai: &AI, - display_name: &str, - description: &str, -) -> Result { - let chat_input = ChatInput { - model: Model::Gpt3_5Turbo, - messages: vec![ - Message { - role: Role::System, - // TODO: Lol, prompt injection much? - content: "You are a tool that attempts to guess where a person is likely to be from based on their name and short bio. Please respond with two-letter country code only. Use lowercase letters.".to_string(), - }, - Message { - role: Role::User, - content: format!("Name: {display_name}\nBio:\n{description}"), - }, - ], - ..Default::default() - }; - - let response = ai.chat(chat_input).await?; - - // TODO: Error handling? - return Ok(response.choices[0].message.content.clone()); -} diff --git a/src/database.rs b/src/database.rs deleted file mode 100644 index f417095..0000000 --- a/src/database.rs +++ /dev/null @@ -1,109 +0,0 @@ -use anyhow::Result; -use chrono::{DateTime, Utc}; - -use scooby::postgres::{insert_into, select, update, Parameters}; -use sqlx::Row; - -use sqlx::postgres::{PgPool, PgPoolOptions, PgRow}; -use sqlx::query; - -pub type ConnectionPool = PgPool; - -pub struct Post { - indexed_at: DateTime, - author_did: String, - cid: String, - uri: String, -} - -pub struct Profile { - first_seen_at: DateTime, - did: String, - has_been_processed: bool, - likely_country_of_living: Option, -} - -pub struct SubscriptionState { - service: String, - cursor: i64, -} - -pub async fn make_connection_pool() -> Result { - // TODO: get options from env vars - Ok(PgPoolOptions::new() - .max_connections(5) - .connect("postgres://postgres:password@localhost/nederlandskie") - .await?) -} - -pub async fn insert_post( - db: &ConnectionPool, - author_did: &str, - cid: &str, - uri: &str, -) -> Result<()> { - let mut params = Parameters::new(); - - Ok(query( - &insert_into("Post") - .columns(("author_did", "cid", "uri")) - .values([params.next_array()]) - .to_string(), - ) - .bind(author_did) - .bind(cid) - .bind(uri) - .execute(db) - .await - .map(|_| ())?) -} - -pub async fn insert_profile_if_it_doesnt_exist(db: &ConnectionPool, did: &str) -> Result { - let mut params = Parameters::new(); - - Ok(query( - &insert_into("Profile") - .columns(("did",)) - .values([params.next()]) - .on_conflict() - .do_nothing() - .to_string(), - ) - .bind(did) - .execute(db) - .await - .map(|result| result.rows_affected() > 0)?) -} - -pub async fn fetch_unprocessed_profile_dids(db: &ConnectionPool) -> Result> { - Ok(query( - &select("did") - .from("Profile") - .where_("has_been_processed = FALSE") - .to_string(), - ) - .map(|r: PgRow| r.get(0)) - .fetch_all(db) - .await?) -} - -pub async fn store_profile_details( - db: &ConnectionPool, - did: &str, - likely_country_of_living: &str, -) -> Result { - let mut params = Parameters::new(); - - Ok(query( - &update("Profile") - .set("has_been_processed", "TRUE") - .set("likely_country_of_living", params.next()) - .where_(format!("did = {}", params.next())) - .to_string(), - ) - .bind(likely_country_of_living) - .bind(did) - .execute(db) - .await - .map(|result| result.rows_affected() > 0)?) -} diff --git a/src/main.rs b/src/main.rs index 4c903a2..55b9ba6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,79 +1,26 @@ -mod ai; -mod database; -mod frames; -mod profile_classifying; -mod streaming; +mod processes; +mod services; -use crate::profile_classifying::classify_unclassified_profiles; -use ai::make_ai_client; use anyhow::Result; -use async_trait::async_trait; -use crate::database::{ - insert_post, insert_profile_if_it_doesnt_exist, make_connection_pool, ConnectionPool, -}; -use crate::streaming::{start_processing_operations_with, Operation, OperationProcessor}; +use crate::processes::post_saver::PostSaver; +use crate::processes::profile_classifier::ProfileClassifier; +use crate::services::ai::AI; +use crate::services::bluesky::Bluesky; +use crate::services::database::Database; #[tokio::main] async fn main() -> Result<()> { - let db_connection_pool = make_connection_pool().await?; - let ai_client = make_ai_client(); + // TODO: Use env vars + let ai = AI::new("fake-api-key", "https://api.openai.com"); + let bluesky = Bluesky::new("https://bsky.social"); + let database = + Database::connect("postgres://postgres:password@localhost/nederlandskie").await?; - // FIXME: This struct shouldn't really exist, but I couldn't find a way to replace - // this whole nonsense with a closure, which is what this whole thing should be in - // first place. - let post_saver = PostSaver { - db_connection_pool: db_connection_pool.clone(), - }; + let post_saver = PostSaver::new(&database, &bluesky); + let profile_classifier = ProfileClassifier::new(&database, &ai, &bluesky); - tokio::try_join!( - start_processing_operations_with(post_saver), - classify_unclassified_profiles(db_connection_pool.clone(), ai_client) - )?; + tokio::try_join!(post_saver.start(), profile_classifier.start())?; Ok(()) } - -struct PostSaver { - db_connection_pool: ConnectionPool, -} - -#[async_trait] -impl OperationProcessor for PostSaver { - async fn process_operation(&self, operation: &Operation) -> Result<()> { - match operation { - Operation::CreatePost { - author_did, - cid, - uri, - languages, - text, - } => { - // TODO: Configure this via env vars - if !languages.contains("ru") { - return Ok(()); - } - - // BlueSky gets confused a lot about Russian vs Ukrainian, so skip posts - // that may be in Ukrainian regardless of whether Russian is in the list - // TODO: Configure this via env vars - if languages.contains("uk") { - return Ok(()); - } - - println!("received insertable post from {author_did}: {text}"); - - insert_profile_if_it_doesnt_exist(&self.db_connection_pool, &author_did).await?; - insert_post(&self.db_connection_pool, &author_did, &cid, &uri).await?; - } - Operation::DeletePost { uri } => { - println!("received a post do delete: {uri}"); - - // TODO: Delete posts from db - // delete_post(&self.db_connection_pool, &uri).await?; - } - }; - - Ok(()) - } -} diff --git a/src/processes/mod.rs b/src/processes/mod.rs new file mode 100644 index 0000000..a73ea87 --- /dev/null +++ b/src/processes/mod.rs @@ -0,0 +1,2 @@ +pub mod post_saver; +pub mod profile_classifier; diff --git a/src/processes/post_saver.rs b/src/processes/post_saver.rs new file mode 100644 index 0000000..955a007 --- /dev/null +++ b/src/processes/post_saver.rs @@ -0,0 +1,64 @@ +use anyhow::Result; +use async_trait::async_trait; + +use crate::services::bluesky::{Bluesky, Operation, OperationProcessor}; +use crate::services::database::Database; + +pub struct PostSaver<'a, 'b> { + database: &'a Database, + bluesky: &'b Bluesky, +} + +impl<'a, 'b> PostSaver<'a, 'b> { + pub fn new(database: &'a Database, bluesky: &'b Bluesky) -> Self { + Self { database, bluesky } + } +} + +impl<'a, 'b> PostSaver<'a, 'b> { + pub async fn start(&self) -> Result<()> { + Ok(self.bluesky.subscribe_to_operations(self).await?) + } +} + +#[async_trait] +impl<'a, 'b> OperationProcessor for PostSaver<'a, 'b> { + async fn process_operation(&self, operation: &Operation) -> Result<()> { + match operation { + Operation::CreatePost { + author_did, + cid, + uri, + languages, + text, + } => { + // TODO: Configure this via env vars + if !languages.contains("ru") { + return Ok(()); + } + + // BlueSky gets confused a lot about Russian vs Ukrainian, so skip posts + // that may be in Ukrainian regardless of whether Russian is in the list + // TODO: Configure this via env vars + if languages.contains("uk") { + return Ok(()); + } + + println!("received insertable post from {author_did}: {text}"); + + self.database + .insert_profile_if_it_doesnt_exist(&author_did) + .await?; + self.database.insert_post(&author_did, &cid, &uri).await?; + } + Operation::DeletePost { uri } => { + println!("received a post do delete: {uri}"); + + // TODO: Delete posts from db + // self.database.delete_post(&self.db_connection_pool, &uri).await?; + } + }; + + Ok(()) + } +} diff --git a/src/processes/profile_classifier.rs b/src/processes/profile_classifier.rs new file mode 100644 index 0000000..ec95766 --- /dev/null +++ b/src/processes/profile_classifier.rs @@ -0,0 +1,57 @@ +use std::time::Duration; + +use anyhow::Result; + +use crate::services::ai::AI; +use crate::services::bluesky::Bluesky; +use crate::services::database::Database; + +pub struct ProfileClassifier<'a, 'b, 'c> { + database: &'a Database, + ai: &'b AI, + bluesky: &'c Bluesky, +} + +impl<'a, 'b, 'c> ProfileClassifier<'a, 'b, 'c> { + pub fn new(database: &'a Database, ai: &'b AI, bluesky: &'c Bluesky) -> Self { + Self { + database, + ai, + bluesky, + } + } + + pub async fn start(&self) -> Result<()> { + loop { + // TODO: Don't just exit this function when an error happens, just wait a minute or so? + self.classify_unclassified_profiles().await?; + } + } + + async fn classify_unclassified_profiles(&self) -> Result<()> { + // TODO: Maybe streamify this so that each thing is processed in parallel + + let dids = self.database.fetch_unprocessed_profile_dids().await?; + if dids.is_empty() { + println!("No profiles to process: waiting 10 seconds"); + tokio::time::sleep(Duration::from_secs(10)).await; + } else { + for did in &dids { + self.fill_in_profile_details(did).await?; + } + } + + Ok(()) + } + + async fn fill_in_profile_details(&self, did: &str) -> Result<()> { + let details = self.bluesky.fetch_profile_details(did).await?; + let country = self + .ai + .infer_country_of_living(&details.display_name, &details.description) + .await?; + self.database.store_profile_details(did, &country).await?; + println!("Stored inferred country of living for {did}: {country}"); + Ok(()) + } +} diff --git a/src/profile_classifying.rs b/src/profile_classifying.rs deleted file mode 100644 index 5285bec..0000000 --- a/src/profile_classifying.rs +++ /dev/null @@ -1,66 +0,0 @@ -use anyhow::anyhow; -use std::time::Duration; - -use anyhow::Result; -use atrium_api::client::AtpServiceClient; -use atrium_api::xrpc::client::reqwest::ReqwestClient; - -use crate::ai::{infer_country_of_living, AI}; -use crate::database::{fetch_unprocessed_profile_dids, store_profile_details, ConnectionPool}; - -#[derive(Debug)] -struct ProfileDetails { - display_name: String, - description: String, -} - -pub async fn classify_unclassified_profiles(db: ConnectionPool, ai: AI) -> Result<()> { - loop { - // TODO: Maybe streamify this so that each thing is processed in parallel - // TODO: Also don't just exit this function when an error happens, just wait a minute or so? - let dids = fetch_unprocessed_profile_dids(&db).await?; - if dids.is_empty() { - println!("No profiles to process: waiting 10 seconds"); - tokio::time::sleep(Duration::from_secs(10)).await; - } else { - for did in &dids { - fill_in_profile_details(&db, &ai, did).await?; - } - } - } -} - -async fn fill_in_profile_details(db: &ConnectionPool, ai: &AI, did: &str) -> Result<()> { - let details = fetch_profile_details(did).await?; - let country = infer_country_of_living(ai, &details.display_name, &details.description).await?; - store_profile_details(db, did, &country).await?; - println!("Stored inferred country of living for {did}: {country}"); - Ok(()) -} - -async fn fetch_profile_details(did: &str) -> Result { - let client = AtpServiceClient::new(ReqwestClient::new("https://bsky.social".into())); - - let result = client - .service - .com - .atproto - .repo - .get_record(atrium_api::com::atproto::repo::get_record::Parameters { - collection: "app.bsky.actor.profile".to_owned(), - cid: None, - repo: did.to_owned(), - rkey: "self".to_owned(), - }) - .await?; - - let profile = match result.value { - atrium_api::records::Record::AppBskyActorProfile(profile) => profile, - _ => return Err(anyhow!("Big bad, no such profile")), - }; - - Ok(ProfileDetails { - display_name: profile.display_name.unwrap_or_else(String::new), - description: profile.description.unwrap_or_else(String::new), - }) -} diff --git a/src/services/ai.rs b/src/services/ai.rs new file mode 100644 index 0000000..9809b0b --- /dev/null +++ b/src/services/ai.rs @@ -0,0 +1,41 @@ +use anyhow::Result; +use chat_gpt_lib_rs::{ChatGPTClient, ChatInput, Message, Model, Role}; + +pub struct AI { + chat_gpt_client: ChatGPTClient, +} + +impl AI { + pub fn new(api_key: &str, base_url: &str) -> Self { + Self { + chat_gpt_client: ChatGPTClient::new(api_key, base_url), + } + } + + pub async fn infer_country_of_living( + &self, + display_name: &str, + description: &str, + ) -> Result { + let chat_input = ChatInput { + model: Model::Gpt3_5Turbo, + messages: vec![ + Message { + role: Role::System, + // TODO: Lol, prompt injection much? + content: "You are a tool that attempts to guess where a person is likely to be from based on their name and short bio. Please respond with two-letter country code only. Use lowercase letters.".to_string(), + }, + Message { + role: Role::User, + content: format!("Name: {display_name}\nBio:\n{description}"), + }, + ], + ..Default::default() + }; + + let response = self.chat_gpt_client.chat(chat_input).await?; + + // TODO: Error handling? + return Ok(response.choices[0].message.content.clone()); + } +} diff --git a/src/services/bluesky/client.rs b/src/services/bluesky/client.rs new file mode 100644 index 0000000..490578d --- /dev/null +++ b/src/services/bluesky/client.rs @@ -0,0 +1,68 @@ +use anyhow::{anyhow, Result}; +use atrium_api::client::AtpServiceClient; +use atrium_api::client::AtpServiceWrapper; +use atrium_xrpc::client::reqwest::ReqwestClient; +use futures::StreamExt; +use tokio_tungstenite::{connect_async, tungstenite}; + +use super::streaming::{handle_message, OperationProcessor}; + +#[derive(Debug)] +pub struct ProfileDetails { + pub display_name: String, + pub description: String, +} + +pub struct Bluesky { + client: AtpServiceClient>, +} + +impl Bluesky { + pub fn new(host: &str) -> Self { + Self { + client: AtpServiceClient::new(ReqwestClient::new(host.to_owned())), + } + } + + pub async fn fetch_profile_details(&self, did: &str) -> Result { + let result = self + .client + .service + .com + .atproto + .repo + .get_record(atrium_api::com::atproto::repo::get_record::Parameters { + collection: "app.bsky.actor.profile".to_owned(), + cid: None, + repo: did.to_owned(), + rkey: "self".to_owned(), + }) + .await?; + + let profile = match result.value { + atrium_api::records::Record::AppBskyActorProfile(profile) => profile, + _ => return Err(anyhow!("Big bad, no such profile")), + }; + + Ok(ProfileDetails { + display_name: profile.display_name.unwrap_or_else(String::new), + description: profile.description.unwrap_or_else(String::new), + }) + } + + pub async fn subscribe_to_operations( + &self, + processor: &P, + ) -> Result<()> { + let (mut stream, _) = + connect_async("wss://bsky.social/xrpc/com.atproto.sync.subscribeRepos").await?; + + while let Some(Ok(tungstenite::Message::Binary(message))) = stream.next().await { + if let Err(e) = handle_message(&message, processor).await { + println!("Error handling a message: {:?}", e); + } + } + + Ok(()) + } +} diff --git a/src/services/bluesky/mod.rs b/src/services/bluesky/mod.rs new file mode 100644 index 0000000..91a0f04 --- /dev/null +++ b/src/services/bluesky/mod.rs @@ -0,0 +1,6 @@ +mod client; +mod proto; +mod streaming; + +pub use client::Bluesky; +pub use streaming::{Operation, OperationProcessor}; diff --git a/src/frames.rs b/src/services/bluesky/proto.rs similarity index 100% rename from src/frames.rs rename to src/services/bluesky/proto.rs diff --git a/src/streaming.rs b/src/services/bluesky/streaming.rs similarity index 81% rename from src/streaming.rs rename to src/services/bluesky/streaming.rs index 4601c18..10829ab 100644 --- a/src/streaming.rs +++ b/src/services/bluesky/streaming.rs @@ -3,13 +3,11 @@ use std::collections::HashSet; use anyhow::Result; use async_trait::async_trait; -use crate::frames::Frame; +use super::proto::Frame; use anyhow::anyhow; use atrium_api::app::bsky::feed::post::Record; use atrium_api::com::atproto::sync::subscribe_repos::Commit; use atrium_api::com::atproto::sync::subscribe_repos::Message; -use futures::StreamExt; -use tokio_tungstenite::{connect_async, tungstenite}; #[async_trait] pub trait OperationProcessor { @@ -30,20 +28,7 @@ pub enum Operation { }, } -pub async fn start_processing_operations_with(processor: P) -> Result<()> { - let (mut stream, _) = - connect_async("wss://bsky.social/xrpc/com.atproto.sync.subscribeRepos").await?; - - while let Some(Ok(tungstenite::Message::Binary(message))) = stream.next().await { - if let Err(e) = handle_message(&message, &processor).await { - println!("Error handling a message: {:?}", e); - } - } - - Ok(()) -} - -async fn handle_message(message: &[u8], processor: &P) -> Result<()> { +pub async fn handle_message(message: &[u8], processor: &P) -> Result<()> { let commit = match parse_commit_from_message(&message)? { Some(commit) => commit, None => return Ok(()), diff --git a/src/services/database.rs b/src/services/database.rs new file mode 100644 index 0000000..e9efe1f --- /dev/null +++ b/src/services/database.rs @@ -0,0 +1,104 @@ +use anyhow::Result; +use chrono::{DateTime, Utc}; +use scooby::postgres::{insert_into, select, update, Parameters}; +use sqlx::postgres::{PgPool, PgPoolOptions, PgRow}; +use sqlx::query; +use sqlx::Row; + +pub struct Post { + indexed_at: DateTime, + author_did: String, + cid: String, + uri: String, +} + +pub struct Profile { + first_seen_at: DateTime, + did: String, + has_been_processed: bool, + likely_country_of_living: Option, +} + +pub struct SubscriptionState { + service: String, + cursor: i64, +} + +pub struct Database { + connection_pool: PgPool, +} + +impl Database { + pub async fn connect(url: &str) -> Result { + Ok(Self { + connection_pool: PgPoolOptions::new().max_connections(5).connect(url).await?, + }) + } + + pub async fn insert_post(&self, author_did: &str, cid: &str, uri: &str) -> Result<()> { + let mut params = Parameters::new(); + + Ok(query( + &insert_into("Post") + .columns(("author_did", "cid", "uri")) + .values([params.next_array()]) + .to_string(), + ) + .bind(author_did) + .bind(cid) + .bind(uri) + .execute(&self.connection_pool) + .await + .map(|_| ())?) + } + + pub async fn insert_profile_if_it_doesnt_exist(&self, did: &str) -> Result { + let mut params = Parameters::new(); + + Ok(query( + &insert_into("Profile") + .columns(("did",)) + .values([params.next()]) + .on_conflict() + .do_nothing() + .to_string(), + ) + .bind(did) + .execute(&self.connection_pool) + .await + .map(|result| result.rows_affected() > 0)?) + } + + pub async fn fetch_unprocessed_profile_dids(&self) -> Result> { + Ok(query( + &select("did") + .from("Profile") + .where_("has_been_processed = FALSE") + .to_string(), + ) + .map(|r: PgRow| r.get(0)) + .fetch_all(&self.connection_pool) + .await?) + } + + pub async fn store_profile_details( + &self, + did: &str, + likely_country_of_living: &str, + ) -> Result { + let mut params = Parameters::new(); + + Ok(query( + &update("Profile") + .set("has_been_processed", "TRUE") + .set("likely_country_of_living", params.next()) + .where_(format!("did = {}", params.next())) + .to_string(), + ) + .bind(likely_country_of_living) + .bind(did) + .execute(&self.connection_pool) + .await + .map(|result| result.rows_affected() > 0)?) + } +} diff --git a/src/services/mod.rs b/src/services/mod.rs new file mode 100644 index 0000000..468ea61 --- /dev/null +++ b/src/services/mod.rs @@ -0,0 +1,3 @@ +pub mod ai; +pub mod bluesky; +pub mod database;