diff --git a/.env.example b/.env.example index 4be004e..c7a72c6 100644 --- a/.env.example +++ b/.env.example @@ -1,2 +1,3 @@ CHAT_GPT_API_KEY="fake-chat-gpt-key" DATABASE_URL="postgres://postgres:password@localhost/nederlandskie" +HOSTNAME="..." \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 62288e5..33ac4e9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -132,6 +132,55 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "axum" +version = "0.6.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf" +dependencies = [ + "async-trait", + "axum-core", + "bitflags 1.3.2", + "bytes", + "futures-util", + "http", + "http-body", + "hyper", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "759fa577a247914fd3f7f76d62972792636412fbfd634cd452f6a385a74d2d2c" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http", + "http-body", + "mime", + "rustversion", + "tower-layer", + "tower-service", +] + [[package]] name = "backtrace" version = "0.3.69" @@ -1125,6 +1174,12 @@ version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" +[[package]] +name = "matchit" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed1202b2a6f884ae56f04cff409ab315c5ce26b5e58d7412e484f01fd52f52ef" + [[package]] name = "md-5" version = "0.10.5" @@ -1242,6 +1297,7 @@ dependencies = [ "async-trait", "atrium-api", "atrium-xrpc", + "axum", "chat-gpt-lib-rs", "chrono", "ciborium", @@ -1250,6 +1306,7 @@ dependencies = [ "libipld-core", "rs-car", "scooby", + "serde", "serde_ipld_dagcbor", "sqlx", "tokio", @@ -1702,6 +1759,12 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "rustversion" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" + [[package]] name = "ryu" version = "1.0.15" @@ -1813,6 +1876,16 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4beec8bce849d58d06238cb50db2e1c417cfeafa4c63f692b15c82b7c80f8335" +dependencies = [ + "itoa", + "serde", +] + [[package]] name = "serde_qs" version = "0.12.0" @@ -2196,6 +2269,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sync_wrapper" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" + [[package]] name = "synstructure" version = "0.12.6" @@ -2364,6 +2443,28 @@ dependencies = [ "serde", ] +[[package]] +name = "tower" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +dependencies = [ + "futures-core", + "futures-util", + "pin-project", + "pin-project-lite", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-layer" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" + [[package]] name = "tower-service" version = "0.3.2" diff --git a/Cargo.toml b/Cargo.toml index 79ca759..0e26593 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ anyhow = "1.0.75" async-trait = "0.1.73" atrium-api = "0.6.0" atrium-xrpc = "0.4.0" +axum = "0.6.20" chat-gpt-lib-rs = "0.2.1" chrono = "0.4.29" ciborium = "0.2.1" @@ -18,6 +19,7 @@ futures = "0.3.28" libipld-core = { version = "0.16.0", features = ["serde-codec"] } rs-car = "0.4.1" scooby = "0.5.0" +serde = "1.0.188" serde_ipld_dagcbor = "0.4.1" sqlx = { version = "0.7.1", default-features = false, features = ["postgres", "runtime-tokio-native-tls", "chrono"] } tokio = { version = "1.32.0", features = ["full"] } diff --git a/README.md b/README.md index e150250..7ae26f6 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ Copy `.env.example` into `.env` and set up the environment variables within: - `CHAT_GPT_API_KEY` for your ChatGPT key - `DATABASE_URL` for PostgreSQL credentials +- `HOSTNAME` to the hostname of where you intend to host the feed ## Running diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..f66d103 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,26 @@ +use anyhow::Result; +use dotenv::dotenv; +use std::env; + +#[derive(Clone)] +pub struct Config { + pub chat_gpt_api_key: String, + pub database_url: String, + pub service_did: String, + pub publisher_did: String, + pub hostname: String, +} + +impl Config { + pub fn load() -> Result { + dotenv()?; + + Ok(Self { + chat_gpt_api_key: env::var("CHAT_GPT_API_KEY")?, + database_url: env::var("DATABASE_URL")?, + hostname: env::var("HOSTNAME")?, + service_did: format!("did:web:{}", env::var("HOSTNAME")?), + publisher_did: "".to_owned(), // TODO + }) + } +} diff --git a/src/main.rs b/src/main.rs index 4cf6b64..6e6b12b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,33 +1,17 @@ +mod config; mod processes; mod services; -use std::env; - use anyhow::Result; -use dotenv::dotenv; +use crate::config::Config; +use crate::processes::feed_server::FeedServer; use crate::processes::post_saver::PostSaver; use crate::processes::profile_classifier::ProfileClassifier; use crate::services::ai::AI; use crate::services::bluesky::Bluesky; use crate::services::database::Database; -struct Config { - chat_gpt_api_key: String, - database_url: String, -} - -impl Config { - fn load() -> Result { - dotenv()?; - - Ok(Self { - chat_gpt_api_key: env::var("CHAT_GPT_API_KEY")?, - database_url: env::var("DATABASE_URL")?, - }) - } -} - #[tokio::main] async fn main() -> Result<()> { let config = Config::load()?; @@ -38,8 +22,13 @@ async fn main() -> Result<()> { let post_saver = PostSaver::new(&database, &bluesky); let profile_classifier = ProfileClassifier::new(&database, &ai, &bluesky); + let feed_server = FeedServer::new(&database, &config); - tokio::try_join!(post_saver.start(), profile_classifier.start())?; + tokio::try_join!( + post_saver.start(), + profile_classifier.start(), + feed_server.serve(), + )?; Ok(()) } diff --git a/src/processes/feed_server.rs b/src/processes/feed_server.rs new file mode 100644 index 0000000..86187ee --- /dev/null +++ b/src/processes/feed_server.rs @@ -0,0 +1,5 @@ +mod endpoints; +mod server; +mod state; + +pub use server::FeedServer; diff --git a/src/processes/feed_server/endpoints.rs b/src/processes/feed_server/endpoints.rs new file mode 100644 index 0000000..e738ffe --- /dev/null +++ b/src/processes/feed_server/endpoints.rs @@ -0,0 +1,9 @@ +mod describe_feed_generator; +mod did_json; +mod get_feed_skeleton; +mod root; + +pub use describe_feed_generator::describe_feed_generator; +pub use did_json::did_json; +pub use get_feed_skeleton::get_feed_skeleton; +pub use root::root; diff --git a/src/processes/feed_server/endpoints/describe_feed_generator.rs b/src/processes/feed_server/endpoints/describe_feed_generator.rs new file mode 100644 index 0000000..582eb69 --- /dev/null +++ b/src/processes/feed_server/endpoints/describe_feed_generator.rs @@ -0,0 +1,18 @@ +use atrium_api::app::bsky::feed::describe_feed_generator::{ + Feed, Output as FeedGeneratorDescription, +}; +use axum::{extract::State, Json}; + +use crate::processes::feed_server::state::FeedServerState; + +pub async fn describe_feed_generator( + State(state): State, +) -> Json { + Json(FeedGeneratorDescription { + did: state.config.service_did.clone(), + feeds: vec![Feed { + uri: format!("at://{}/app.bsky.feed.generator/{}", state.config.publisher_did, "nederlandskie"), + }], + links: None, + }) +} diff --git a/src/processes/feed_server/endpoints/did_json.rs b/src/processes/feed_server/endpoints/did_json.rs new file mode 100644 index 0000000..f8eb5fa --- /dev/null +++ b/src/processes/feed_server/endpoints/did_json.rs @@ -0,0 +1,32 @@ +use axum::{extract::State, Json}; +use serde::Serialize; + +use crate::processes::feed_server::state::FeedServerState; + +#[derive(Serialize)] +pub struct Did { + #[serde(rename = "@context")] + context: Vec, + id: String, + service: Vec, +} + +#[derive(Serialize)] +pub struct Service { + id: String, + #[serde(rename = "type")] + type_: String, + service_endpoint: String, +} + +pub async fn did_json(State(state): State) -> Json { + Json(Did { + context: vec!["https://www.w3.org/ns/did/v1".to_owned()], + id: state.config.service_did.clone(), + service: vec![Service { + id: "#bsky_fg".to_owned(), + type_: "BskyFeedGenerator".to_owned(), + service_endpoint: format!("https://{}", state.config.hostname), + }], + }) +} diff --git a/src/processes/feed_server/endpoints/get_feed_skeleton.rs b/src/processes/feed_server/endpoints/get_feed_skeleton.rs new file mode 100644 index 0000000..8acaa25 --- /dev/null +++ b/src/processes/feed_server/endpoints/get_feed_skeleton.rs @@ -0,0 +1,61 @@ +use anyhow::{anyhow, Result}; +use atrium_api::app::bsky::feed::defs::SkeletonFeedPost; +use atrium_api::app::bsky::feed::get_feed_skeleton::{ + Output as FeedSkeleton, Parameters as FeedSkeletonQuery, +}; +use axum::extract::{Query, State}; +use axum::Json; +use chrono::{DateTime, TimeZone, Utc}; + +use crate::processes::feed_server::state::FeedServerState; + +pub async fn get_feed_skeleton( + State(state): State, + query: Query, +) -> Json { + let limit = query.limit.unwrap_or(20) as usize; + let earlier_than = query + .cursor + .as_deref() + .map(parse_cursor) + .transpose() + .unwrap(); // TODO: handle error + + let posts = state + .database + .fetch_posts_by_authors_country("ru", limit, earlier_than) + .await + .unwrap(); + + let feed = posts + .iter() + .map(|p| SkeletonFeedPost { + post: p.uri.clone(), + reason: None, + }) + .collect(); + + let cursor = posts.last().map(|p| make_cursor(&p.indexed_at, &p.cid)); + + Json(FeedSkeleton { cursor, feed }) +} + +fn make_cursor(date: &DateTime, cid: &str) -> String { + format!("{}::{}", date.timestamp() * 1000, cid) +} + +fn parse_cursor(cursor: &str) -> Result<(DateTime, &str)> { + let mut parts = cursor.split("::"); + + let indexed_at = parts.next().ok_or_else(|| anyhow!("Malformed cursor"))?; + let cid = parts.next().ok_or_else(|| anyhow!("Malformed cursor"))?; + + if parts.next().is_some() { + return Err(anyhow!("Malformed cursor")); + } + + let indexed_at: i64 = indexed_at.parse()?; + let indexed_at = Utc.timestamp_opt(indexed_at / 1000, 0).unwrap(); // TODO: handle error + + Ok((indexed_at, cid)) +} diff --git a/src/processes/feed_server/endpoints/root.rs b/src/processes/feed_server/endpoints/root.rs new file mode 100644 index 0000000..8f5f0d9 --- /dev/null +++ b/src/processes/feed_server/endpoints/root.rs @@ -0,0 +1,3 @@ +pub async fn root() -> &'static str { + "Hello, World!" +} diff --git a/src/processes/feed_server/server.rs b/src/processes/feed_server/server.rs new file mode 100644 index 0000000..ca0f85d --- /dev/null +++ b/src/processes/feed_server/server.rs @@ -0,0 +1,44 @@ +use std::net::SocketAddr; + +use anyhow::Result; +use axum::routing::get; +use axum::{Router, Server}; + +use crate::config::Config; +use crate::services::database::Database; + +use super::endpoints::{describe_feed_generator, did_json, get_feed_skeleton, root}; +use super::state::FeedServerState; + +pub struct FeedServer<'a> { + database: &'a Database, + config: &'a Config, +} + +impl<'a> FeedServer<'a> { + pub fn new(database: &'a Database, config: &'a Config) -> Self { + Self { database, config } + } + + pub async fn serve(self) -> Result<()> { + let app = Router::new() + .route("/", get(root)) + .route("/.well-known/did.json", get(did_json)) + .route( + "/xrpc/app.bsky.feed.describeFeedGenerator", + get(describe_feed_generator), + ) + .route( + "/xrpc/app.bsky.feed.getFeedSkeleton", + get(get_feed_skeleton), + ) + .with_state(FeedServerState { + database: self.database.clone(), + config: self.config.clone(), + }); + + let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); + Server::bind(&addr).serve(app.into_make_service()).await?; + Ok(()) + } +} diff --git a/src/processes/feed_server/state.rs b/src/processes/feed_server/state.rs new file mode 100644 index 0000000..0100262 --- /dev/null +++ b/src/processes/feed_server/state.rs @@ -0,0 +1,8 @@ +use crate::config::Config; +use crate::services::database::Database; + +#[derive(Clone)] +pub struct FeedServerState { + pub database: Database, + pub config: Config, +} diff --git a/src/processes/mod.rs b/src/processes/mod.rs index a73ea87..87e77d3 100644 --- a/src/processes/mod.rs +++ b/src/processes/mod.rs @@ -1,2 +1,3 @@ +pub mod feed_server; pub mod post_saver; pub mod profile_classifier; diff --git a/src/services/ai.rs b/src/services/ai.rs index 9809b0b..2c23955 100644 --- a/src/services/ai.rs +++ b/src/services/ai.rs @@ -23,7 +23,7 @@ impl AI { Message { role: Role::System, // TODO: Lol, prompt injection much? - content: "You are a tool that attempts to guess where a person is likely to be from based on their name and short bio. Please respond with two-letter country code only. Use lowercase letters.".to_string(), + content: "You are a tool that attempts to guess where a person is likely to be from based on their name and short bio. Please respond with two-letter country code only. If unable to determine, say xx.".to_string(), }, Message { role: Role::User, @@ -36,6 +36,6 @@ impl AI { let response = self.chat_gpt_client.chat(chat_input).await?; // TODO: Error handling? - return Ok(response.choices[0].message.content.clone()); + return Ok(response.choices[0].message.content.to_lowercase()); } } diff --git a/src/services/database.rs b/src/services/database.rs index e9efe1f..cf9de42 100644 --- a/src/services/database.rs +++ b/src/services/database.rs @@ -1,15 +1,15 @@ use anyhow::Result; use chrono::{DateTime, Utc}; -use scooby::postgres::{insert_into, select, update, Parameters}; +use scooby::postgres::{insert_into, select, update, Joinable, Orderable, Parameters, Aliasable}; use sqlx::postgres::{PgPool, PgPoolOptions, PgRow}; use sqlx::query; use sqlx::Row; pub struct Post { - indexed_at: DateTime, - author_did: String, - cid: String, - uri: String, + pub indexed_at: DateTime, + pub author_did: String, + pub cid: String, + pub uri: String, } pub struct Profile { @@ -24,6 +24,7 @@ pub struct SubscriptionState { cursor: i64, } +#[derive(Clone)] pub struct Database { connection_pool: PgPool, } @@ -52,6 +53,49 @@ impl Database { .map(|_| ())?) } + pub async fn fetch_posts_by_authors_country( + &self, + author_country: &str, + limit: usize, + earlier_than: Option<(DateTime, &str)>, + ) -> Result> { + let mut params = Parameters::new(); + let mut sql_builder = select(("p.indexed_at", "p.author_did", "p.cid", "p.uri")) + .from( + "Post".as_("p") + .inner_join("Profile".as_("pr")) + .on("pr.did = p.author_did"), + ) + .where_(format!("pr.likely_country_of_living = {}", params.next())) + .order_by(("p.indexed_at".desc(), "p.cid".desc())) + .limit(limit); + + if earlier_than.is_some() { + sql_builder = sql_builder + .where_(format!("p.indexed_at <= {}", params.next())) + .where_(format!("p.cid < {}", params.next())); + } + + let sql_string = sql_builder.to_string(); + + let mut query_object = query(&sql_string) + .bind(author_country); + + if let Some((last_indexed_at, last_cid)) = earlier_than { + query_object = query_object.bind(last_indexed_at).bind(last_cid); + } + + Ok(query_object + .map(|r: PgRow| Post { + indexed_at: r.get("indexed_at"), + author_did: r.get("author_did"), + cid: r.get("cid"), + uri: r.get("uri"), + }) + .fetch_all(&self.connection_pool) + .await?) + } + pub async fn insert_profile_if_it_doesnt_exist(&self, did: &str) -> Result { let mut params = Parameters::new();