diff --git a/Cargo.lock b/Cargo.lock index 0cbdf72..92457d7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -29,6 +29,15 @@ dependencies = [ "version_check", ] +[[package]] +name = "aho-corasick" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c378d78423fdad8089616f827526ee33c19f2fddbd5de1629152c9593ba4783" +dependencies = [ + "memchr", +] + [[package]] name = "allocator-api2" version = "0.2.16" @@ -257,6 +266,20 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chat-gpt-lib-rs" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae8651a0f3f7222ff1e22fd036f8e8cfffa7d6409dd495ddd83d55dc3a3777bf" +dependencies = [ + "env_logger", + "log", + "reqwest", + "serde", + "serde_json", + "tokio", +] + [[package]] name = "chrono" version = "0.4.28" @@ -482,6 +505,19 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "env_logger" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85cdab6a89accf66733ad5a1693a4dcced6aeff64602b634530dd73c1f3ee9f0" +dependencies = [ + "humantime", + "is-terminal", + "log", + "regex", + "termcolor", +] + [[package]] name = "equivalent" version = "1.0.1" @@ -833,6 +869,12 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + [[package]] name = "hyper" version = "0.14.27" @@ -929,6 +971,17 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6" +[[package]] +name = 
"is-terminal" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" +dependencies = [ + "hermit-abi", + "rustix", + "windows-sys", +] + [[package]] name = "itertools" version = "0.10.5" @@ -1183,6 +1236,7 @@ dependencies = [ "async-trait", "atrium-api", "atrium-xrpc", + "chat-gpt-lib-rs", "chrono", "ciborium", "futures", @@ -1522,6 +1576,35 @@ dependencies = [ "bitflags 1.3.2", ] +[[package]] +name = "regex" +version = "1.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "697061221ea1b4a94a624f67d0ae2bfe4e22b8a17b6a192afb11046542cc8c47" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2f401f4955220693b56f8ec66ee9c78abffd8d1c4f23dc41a23839eb88f0795" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" + [[package]] name = "reqwest" version = "0.11.20" @@ -2131,6 +2214,15 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "termcolor" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be55cf8942feac5c765c2c993422806843c9a9a45d4d5c407ad6dd2ea95eb9b6" +dependencies = [ + "winapi-util", +] + [[package]] name = "thiserror" version = "1.0.47" @@ -2529,6 +2621,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +[[package]] +name = "winapi-util" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178" +dependencies = [ + "winapi", +] + [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" diff --git a/Cargo.toml b/Cargo.toml index 981bee4..d40ab45 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ anyhow = "1.0.75" async-trait = "0.1.73" atrium-api = "0.6.0" atrium-xrpc = "0.4.0" +chat-gpt-lib-rs = "0.2.1" chrono = "0.4.26" ciborium = "0.2.1" futures = "0.3.28" diff --git a/README.md b/README.md index 0d32ac0..646d9a8 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ Heavily WIP. Doesn't work yet at all, but does read the stream of posts as they - [x] Read stream of posts from Bluesky - [x] Store posts in the database - [x] Store user profiles in the database -- [ ] Detect the country of residence from profile information +- [x] Detect the country of residence from profile information - [ ] Keep subscription state to not lose messages - [ ] Serve the feed - [ ] Publish the feed diff --git a/sql/01_create_tables.sql b/sql/01_create_tables.sql index d626fec..83f720a 100644 --- a/sql/01_create_tables.sql +++ b/sql/01_create_tables.sql @@ -2,7 +2,7 @@ CREATE TABLE IF NOT EXISTS Profile { id INT GENERATED ALWAYS AS IDENTITY, first_seen_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), did TEXT UNIQUE, - handle TEXT NULL DEFAULT NULL, + has_been_processed BOOLEAN DEFAULT FALSE, likely_country_of_living varchar(2) NULL DEFAULT NULL }
diff --git a/src/ai.rs b/src/ai.rs
new file mode 100644
index 0000000..baea564
--- /dev/null
+++ b/src/ai.rs
@@ -0,0 +1,74 @@
+use anyhow::{anyhow, Result};
+use chat_gpt_lib_rs::{ChatGPTClient, ChatInput, Message, Model, Role};
+
+pub type AI = ChatGPTClient;
+
+/// Environment variable the OpenAI API key is read from.
+const API_KEY_ENV_VAR: &str = "OPENAI_API_KEY";
+
+/// Creates the ChatGPT client used to classify profiles.
+///
+/// The API key is read from the `OPENAI_API_KEY` environment variable. When it cannot be
+/// read, a warning is printed and the previous placeholder key is used, which keeps this
+/// constructor infallible for existing callers.
+pub fn make_ai_client() -> AI {
+    let api_key = std::env::var(API_KEY_ENV_VAR).unwrap_or_else(|_| {
+        eprintln!("{API_KEY_ENV_VAR} is not set or not valid Unicode: using a placeholder key");
+        "fake-api-key".to_owned()
+    });
+    let base_url = "https://api.openai.com";
+
+    ChatGPTClient::new(api_key.as_str(), base_url)
+}
+
+/// Asks the AI for the country a person most likely lives in, given their profile.
+///
+/// Returns the answer as a lowercase two-letter code.
+///
+/// # Errors
+///
+/// Fails when the chat request fails, when the response has no choices, or when the
+/// reply is not a two-letter code.
+pub async fn infer_country_of_living(
+    ai: &AI,
+    display_name: &str,
+    description: &str,
+) -> Result<String> {
+    let chat_input = ChatInput {
+        model: Model::Gpt3_5Turbo,
+        messages: vec![
+            Message {
+                role: Role::System,
+                // TODO: Lol, prompt injection much?
+                content: "You are a tool that attempts to guess where a person is likely to be from based on their name and short bio. Please respond with two-letter country code only. Use lowercase letters.".to_string(),
+            },
+            Message {
+                role: Role::User,
+                content: format!("Name: {display_name}\nBio:\n{description}"),
+            },
+        ],
+        ..Default::default()
+    };
+
+    let response = ai.chat(chat_input).await?;
+
+    // Index-free access: an empty `choices` list becomes an error instead of a panic.
+    let reply = response
+        .choices
+        .first()
+        .map(|choice| choice.message.content.as_str())
+        .ok_or_else(|| anyhow!("AI response contains no choices"))?;
+
+    normalize_country_code(reply)
+        .ok_or_else(|| anyhow!("AI reply is not a two-letter country code: {reply:?}"))
+}
+
+/// Trims and lowercases `raw`, returning it only if it is exactly two ASCII letters.
+///
+/// The `Profile.likely_country_of_living` column is declared as `varchar(2)`, so a
+/// longer value would not fit it.
+fn normalize_country_code(raw: &str) -> Option<String> {
+    let code = raw.trim();
+    let is_valid = code.len() == 2 && code.bytes().all(|byte| byte.is_ascii_alphabetic());
+    is_valid.then(|| code.to_ascii_lowercase())
+}
diff --git a/src/database.rs b/src/database.rs index 55d173f..f417095 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,8 +1,10 @@ use anyhow::Result; use chrono::{DateTime, Utc}; -use scooby::postgres::{insert_into, Parameters}; -use sqlx::postgres::{PgPool, PgPoolOptions}; +use scooby::postgres::{insert_into, select, update, Parameters}; +use sqlx::Row; + +use sqlx::postgres::{PgPool, PgPoolOptions, PgRow}; use sqlx::query; pub type ConnectionPool = PgPool; @@ -17,7 +19,7 @@ pub struct Post { pub struct Profile { first_seen_at: DateTime<Utc>, did: String, - handle: Option<String>, + has_been_processed: bool, likely_country_of_living: Option<String>, } @@ -72,3 +74,36 @@ pub async fn insert_profile_if_it_doesnt_exist(db: &ConnectionPool, did: &str) - .await .map(|result| result.rows_affected() > 0)?) } + +pub async fn fetch_unprocessed_profile_dids(db: &ConnectionPool) -> Result<Vec<String>> { + Ok(query( + &select("did") + .from("Profile") + .where_("has_been_processed = FALSE") + .to_string(), + ) + .map(|r: PgRow| r.get(0)) + .fetch_all(db) + .await?) 
+}
+
+pub async fn store_profile_details(
+    db: &ConnectionPool,
+    did: &str,
+    likely_country_of_living: &str,
+) -> Result<bool> { // `true` when the UPDATE affected at least one row
+    let mut params = Parameters::new();
+
+    Ok(query(
+        &update("Profile")
+            .set("has_been_processed", "TRUE")
+            .set("likely_country_of_living", params.next())
+            .where_(format!("did = {}", params.next()))
+            .to_string(),
+    )
+    .bind(likely_country_of_living) // binds in the order the placeholders were generated
+    .bind(did)
+    .execute(db)
+    .await
+    .map(|result| result.rows_affected() > 0)?)
+}
diff --git a/src/main.rs b/src/main.rs index d37341e..4c903a2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,11 @@ +mod ai; mod database; mod frames; +mod profile_classifying; mod streaming; +use crate::profile_classifying::classify_unclassified_profiles; +use ai::make_ai_client; use anyhow::Result; use async_trait::async_trait; @@ -13,13 +17,19 @@ use crate::streaming::{start_processing_operations_with, Operation, OperationPro #[tokio::main] async fn main() -> Result<()> { let db_connection_pool = make_connection_pool().await?; + let ai_client = make_ai_client(); // FIXME: This struct shouldn't really exist, but I couldn't find a way to replace // this whole nonsense with a closure, which is what this whole thing should be in // first place. 
- let post_saver = PostSaver { db_connection_pool }; + let post_saver = PostSaver { + db_connection_pool: db_connection_pool.clone(), + }; - start_processing_operations_with(post_saver).await?; + tokio::try_join!( + start_processing_operations_with(post_saver), + classify_unclassified_profiles(db_connection_pool.clone(), ai_client) + )?; Ok(()) }
diff --git a/src/profile_classifying.rs b/src/profile_classifying.rs
new file mode 100644
index 0000000..5285bec
--- /dev/null
+++ b/src/profile_classifying.rs
@@ -0,0 +1,93 @@
+use anyhow::anyhow;
+use std::time::Duration;
+
+use anyhow::Result;
+use atrium_api::client::AtpServiceClient;
+use atrium_api::xrpc::client::reqwest::ReqwestClient;
+
+use crate::ai::{infer_country_of_living, AI};
+use crate::database::{fetch_unprocessed_profile_dids, store_profile_details, ConnectionPool};
+
+/// How long to pause when there is nothing to process or when a pass had failures.
+const RETRY_DELAY: Duration = Duration::from_secs(10);
+
+/// The profile fields that are handed to the AI for classification.
+#[derive(Debug)]
+struct ProfileDetails {
+    display_name: String,
+    description: String,
+}
+
+/// Endlessly polls the database for unprocessed profiles and classifies each of them.
+///
+/// A failure to classify a single profile is logged and does not stop the loop; the
+/// profile stays unprocessed and is picked up again on a later pass.
+///
+/// # Errors
+///
+/// Returns an error only when the list of unprocessed profiles cannot be fetched.
+pub async fn classify_unclassified_profiles(db: ConnectionPool, ai: AI) -> Result<()> {
+    loop {
+        // TODO: Maybe streamify this so that each thing is processed in parallel
+        let dids = fetch_unprocessed_profile_dids(&db).await?;
+        if dids.is_empty() {
+            println!("No profiles to process: waiting 10 seconds");
+            tokio::time::sleep(RETRY_DELAY).await;
+            continue;
+        }
+
+        let mut failed = 0_usize;
+        for did in &dids {
+            if let Err(error) = fill_in_profile_details(&db, &ai, did).await {
+                eprintln!("Failed to classify profile {did}: {error:#}");
+                failed += 1;
+            }
+        }
+
+        if failed > 0 {
+            // Failed profiles are still unprocessed, so the next fetch returns them again
+            // right away; pause instead of hammering the upstream services in a tight loop.
+            println!("{failed} profile(s) failed: waiting 10 seconds before retrying");
+            tokio::time::sleep(RETRY_DELAY).await;
+        }
+    }
+}
+
+/// Fetches the profile of `did`, infers its country of living and stores the result.
+async fn fill_in_profile_details(db: &ConnectionPool, ai: &AI, did: &str) -> Result<()> {
+    let details = fetch_profile_details(did).await?;
+    let country = infer_country_of_living(ai, &details.display_name, &details.description).await?;
+    store_profile_details(db, did, &country).await?;
+    println!("Stored inferred country of living for {did}: {country}");
+    Ok(())
+}
+
+/// Reads the `app.bsky.actor.profile` record of `did` from `https://bsky.social`.
+///
+/// A missing display name or description is returned as an empty string.
+async fn fetch_profile_details(did: &str) -> Result<ProfileDetails> {
+    let client = AtpServiceClient::new(ReqwestClient::new("https://bsky.social".into()));
+
+    let result = client
+        .service
+        .com
+        .atproto
+        .repo
+        .get_record(atrium_api::com::atproto::repo::get_record::Parameters {
+            collection: "app.bsky.actor.profile".to_owned(),
+            cid: None,
+            repo: did.to_owned(),
+            rkey: "self".to_owned(),
+        })
+        .await?;
+
+    let profile = match result.value {
+        atrium_api::records::Record::AppBskyActorProfile(profile) => profile,
+        _ => return Err(anyhow!("Record returned for {did} is not a profile")),
+    };
+
+    Ok(ProfileDetails {
+        display_name: profile.display_name.unwrap_or_default(),
+        description: profile.description.unwrap_or_default(),
+    })
+}