Serve the feed, according to all the Atproto endpoints

This commit is contained in:
Aleksei Voronov 2023-09-16 17:13:57 +02:00
parent c2899951f6
commit b4250e12cd
17 changed files with 372 additions and 27 deletions

View File

@ -1,2 +1,3 @@
CHAT_GPT_API_KEY="fake-chat-gpt-key"
DATABASE_URL="postgres://postgres:password@localhost/nederlandskie"
HOSTNAME="..."

101
Cargo.lock generated
View File

@ -132,6 +132,55 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "axum"
version = "0.6.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf"
dependencies = [
"async-trait",
"axum-core",
"bitflags 1.3.2",
"bytes",
"futures-util",
"http",
"http-body",
"hyper",
"itoa",
"matchit",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"rustversion",
"serde",
"serde_json",
"serde_path_to_error",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tower",
"tower-layer",
"tower-service",
]
[[package]]
name = "axum-core"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "759fa577a247914fd3f7f76d62972792636412fbfd634cd452f6a385a74d2d2c"
dependencies = [
"async-trait",
"bytes",
"futures-util",
"http",
"http-body",
"mime",
"rustversion",
"tower-layer",
"tower-service",
]
[[package]]
name = "backtrace"
version = "0.3.69"
@ -1125,6 +1174,12 @@ version = "0.4.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f"
[[package]]
name = "matchit"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed1202b2a6f884ae56f04cff409ab315c5ce26b5e58d7412e484f01fd52f52ef"
[[package]]
name = "md-5"
version = "0.10.5"
@ -1242,6 +1297,7 @@ dependencies = [
"async-trait",
"atrium-api",
"atrium-xrpc",
"axum",
"chat-gpt-lib-rs",
"chrono",
"ciborium",
@ -1250,6 +1306,7 @@ dependencies = [
"libipld-core",
"rs-car",
"scooby",
"serde",
"serde_ipld_dagcbor",
"sqlx",
"tokio",
@ -1702,6 +1759,12 @@ dependencies = [
"windows-sys",
]
[[package]]
name = "rustversion"
version = "1.0.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4"
[[package]]
name = "ryu"
version = "1.0.15"
@ -1813,6 +1876,16 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_path_to_error"
version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4beec8bce849d58d06238cb50db2e1c417cfeafa4c63f692b15c82b7c80f8335"
dependencies = [
"itoa",
"serde",
]
[[package]]
name = "serde_qs"
version = "0.12.0"
@ -2196,6 +2269,12 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "sync_wrapper"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
[[package]]
name = "synstructure"
version = "0.12.6"
@ -2364,6 +2443,28 @@ dependencies = [
"serde",
]
[[package]]
name = "tower"
version = "0.4.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c"
dependencies = [
"futures-core",
"futures-util",
"pin-project",
"pin-project-lite",
"tokio",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "tower-layer"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0"
[[package]]
name = "tower-service"
version = "0.3.2"

View File

@ -10,6 +10,7 @@ anyhow = "1.0.75"
async-trait = "0.1.73"
atrium-api = "0.6.0"
atrium-xrpc = "0.4.0"
axum = "0.6.20"
chat-gpt-lib-rs = "0.2.1"
chrono = "0.4.29"
ciborium = "0.2.1"
@ -18,6 +19,7 @@ futures = "0.3.28"
libipld-core = { version = "0.16.0", features = ["serde-codec"] }
rs-car = "0.4.1"
scooby = "0.5.0"
serde = "1.0.188"
serde_ipld_dagcbor = "0.4.1"
sqlx = { version = "0.7.1", default-features = false, features = ["postgres", "runtime-tokio-native-tls", "chrono"] }
tokio = { version = "1.32.0", features = ["full"] }

View File

@ -21,6 +21,7 @@ Copy `.env.example` into `.env` and set up the environment variables within:
- `CHAT_GPT_API_KEY` for your ChatGPT key
- `DATABASE_URL` for PostgreSQL credentials
- `HOSTNAME` for the hostname where you intend to host the feed
## Running

26
src/config.rs Normal file
View File

@ -0,0 +1,26 @@
use anyhow::Result;
use dotenv::dotenv;
use std::env;
#[derive(Clone)]
pub struct Config {
pub chat_gpt_api_key: String,
pub database_url: String,
pub service_did: String,
pub publisher_did: String,
pub hostname: String,
}
impl Config {
pub fn load() -> Result<Self> {
dotenv()?;
Ok(Self {
chat_gpt_api_key: env::var("CHAT_GPT_API_KEY")?,
database_url: env::var("DATABASE_URL")?,
hostname: env::var("HOSTNAME")?,
service_did: format!("did:web:{}", env::var("HOSTNAME")?),
publisher_did: "".to_owned(), // TODO
})
}
}

View File

@ -1,33 +1,17 @@
mod config;
mod processes;
mod services;
use std::env;
use anyhow::Result;
use dotenv::dotenv;
use crate::config::Config;
use crate::processes::feed_server::FeedServer;
use crate::processes::post_saver::PostSaver;
use crate::processes::profile_classifier::ProfileClassifier;
use crate::services::ai::AI;
use crate::services::bluesky::Bluesky;
use crate::services::database::Database;
// Previous, file-private configuration holder (the same commit also adds a
// public `Config` in `src/config.rs`).
struct Config {
// Read from the `CHAT_GPT_API_KEY` environment variable.
chat_gpt_api_key: String,
// Read from the `DATABASE_URL` environment variable.
database_url: String,
}
impl Config {
// Calls `dotenv()` and then reads both variables from the environment;
// any failure along the way is propagated to the caller via `?`.
fn load() -> Result<Self> {
dotenv()?;
Ok(Self {
chat_gpt_api_key: env::var("CHAT_GPT_API_KEY")?,
database_url: env::var("DATABASE_URL")?,
})
}
}
#[tokio::main]
async fn main() -> Result<()> {
let config = Config::load()?;
@ -38,8 +22,13 @@ async fn main() -> Result<()> {
let post_saver = PostSaver::new(&database, &bluesky);
let profile_classifier = ProfileClassifier::new(&database, &ai, &bluesky);
let feed_server = FeedServer::new(&database, &config);
tokio::try_join!(post_saver.start(), profile_classifier.start())?;
tokio::try_join!(
post_saver.start(),
profile_classifier.start(),
feed_server.serve(),
)?;
Ok(())
}

View File

@ -0,0 +1,5 @@
mod endpoints;
mod server;
mod state;
pub use server::FeedServer;

View File

@ -0,0 +1,9 @@
mod describe_feed_generator;
mod did_json;
mod get_feed_skeleton;
mod root;
pub use describe_feed_generator::describe_feed_generator;
pub use did_json::did_json;
pub use get_feed_skeleton::get_feed_skeleton;
pub use root::root;

View File

@ -0,0 +1,18 @@
use atrium_api::app::bsky::feed::describe_feed_generator::{
Feed, Output as FeedGeneratorDescription,
};
use axum::{extract::State, Json};
use crate::processes::feed_server::state::FeedServerState;
/// Handler for `app.bsky.feed.describeFeedGenerator`.
///
/// Reports the service DID from the configuration together with the single
/// feed this server exposes, whose AT URI is built from the publisher DID and
/// the fixed record key `nederlandskie`.
pub async fn describe_feed_generator(
    State(state): State<FeedServerState>,
) -> Json<FeedGeneratorDescription> {
    let config = &state.config;

    let feed_uri = format!(
        "at://{}/app.bsky.feed.generator/{}",
        config.publisher_did, "nederlandskie"
    );

    let description = FeedGeneratorDescription {
        did: config.service_did.clone(),
        feeds: vec![Feed { uri: feed_uri }],
        links: None,
    };

    Json(description)
}

View File

@ -0,0 +1,32 @@
use axum::{extract::State, Json};
use serde::Serialize;
use crate::processes::feed_server::state::FeedServerState;
/// JSON body returned by the `did_json` handler.
#[derive(Serialize)]
pub struct Did {
    /// Serialized under the key `@context`.
    #[serde(rename = "@context")]
    context: Vec<String>,
    /// The DID this document describes.
    id: String,
    /// Services advertised for this DID.
    service: Vec<Service>,
}
/// One service entry inside a [`Did`] document.
#[derive(Serialize)]
pub struct Service {
    /// Fragment identifier of the service (e.g. `#bsky_fg`).
    id: String,
    /// Service type; serialized under the key `type`.
    #[serde(rename = "type")]
    type_: String,
    /// URL of the service. DID documents use the camelCase key
    /// `serviceEndpoint`; without the rename serde would emit
    /// `service_endpoint`, which consumers of the document do not look for.
    #[serde(rename = "serviceEndpoint")]
    service_endpoint: String,
}
/// Handler for `/.well-known/did.json`.
///
/// Builds the DID document for this service: its id is the configured service
/// DID, and it advertises one `BskyFeedGenerator` service reachable over HTTPS
/// at the configured hostname.
pub async fn did_json(State(state): State<FeedServerState>) -> Json<Did> {
    let config = &state.config;

    let feed_generator = Service {
        id: "#bsky_fg".to_owned(),
        type_: "BskyFeedGenerator".to_owned(),
        service_endpoint: format!("https://{}", config.hostname),
    };

    let document = Did {
        context: vec!["https://www.w3.org/ns/did/v1".to_owned()],
        id: config.service_did.clone(),
        service: vec![feed_generator],
    };

    Json(document)
}

View File

@ -0,0 +1,61 @@
use anyhow::{anyhow, Result};
use atrium_api::app::bsky::feed::defs::SkeletonFeedPost;
use atrium_api::app::bsky::feed::get_feed_skeleton::{
Output as FeedSkeleton, Parameters as FeedSkeletonQuery,
};
use axum::extract::{Query, State};
use axum::Json;
use chrono::{DateTime, TimeZone, Utc};
use crate::processes::feed_server::state::FeedServerState;
pub async fn get_feed_skeleton(
State(state): State<FeedServerState>,
query: Query<FeedSkeletonQuery>,
) -> Json<FeedSkeleton> {
let limit = query.limit.unwrap_or(20) as usize;
let earlier_than = query
.cursor
.as_deref()
.map(parse_cursor)
.transpose()
.unwrap(); // TODO: handle error
let posts = state
.database
.fetch_posts_by_authors_country("ru", limit, earlier_than)
.await
.unwrap();
let feed = posts
.iter()
.map(|p| SkeletonFeedPost {
post: p.uri.clone(),
reason: None,
})
.collect();
let cursor = posts.last().map(|p| make_cursor(&p.indexed_at, &p.cid));
Json(FeedSkeleton { cursor, feed })
}
/// Encodes a pagination cursor as `<unix time in milliseconds>::<cid>`.
///
/// The timestamp keeps its millisecond component. Rounding down to whole
/// seconds (as `timestamp() * 1000` did) makes the cursor point earlier than
/// the post it was built from, so posts indexed within the same second could
/// be skipped on the next page.
fn make_cursor(date: &DateTime<Utc>, cid: &str) -> String {
    format!("{}::{}", date.timestamp_millis(), cid)
}
/// Decodes a cursor of the form `<unix time in milliseconds>::<cid>`.
///
/// # Errors
///
/// Returns an error when the cursor does not consist of exactly two
/// `::`-separated parts, when the first part is not an integer, or when the
/// integer is outside the range of representable timestamps. The last case
/// used to panic, which let any client crash the handler with a crafted cursor.
fn parse_cursor(cursor: &str) -> Result<(DateTime<Utc>, &str)> {
    let (millis, cid) = cursor
        .split_once("::")
        .ok_or_else(|| anyhow!("Malformed cursor"))?;

    // A second separator means there were more than two parts.
    if cid.contains("::") {
        return Err(anyhow!("Malformed cursor"));
    }

    let millis: i64 = millis.parse()?;

    let indexed_at = Utc
        .timestamp_millis_opt(millis)
        .single()
        .ok_or_else(|| anyhow!("Malformed cursor: timestamp out of range"))?;

    Ok((indexed_at, cid))
}

View File

@ -0,0 +1,3 @@
/// Handler for `/`: answers with a fixed plain-text greeting.
pub async fn root() -> &'static str {
    const GREETING: &str = "Hello, World!";
    GREETING
}

View File

@ -0,0 +1,44 @@
use std::net::SocketAddr;
use anyhow::Result;
use axum::routing::get;
use axum::{Router, Server};
use crate::config::Config;
use crate::services::database::Database;
use super::endpoints::{describe_feed_generator, did_json, get_feed_skeleton, root};
use super::state::FeedServerState;
/// HTTP server exposing the feed generator endpoints.
///
/// Borrows the database handle and configuration; both are cloned into the
/// router state when [`FeedServer::serve`] starts.
pub struct FeedServer<'a> {
    database: &'a Database,
    config: &'a Config,
}

impl<'a> FeedServer<'a> {
    /// Creates a server that will serve data from `database` using `config`.
    pub fn new(database: &'a Database, config: &'a Config) -> Self {
        Self { database, config }
    }

    /// Binds to `127.0.0.1:3000` and serves requests until the server stops.
    ///
    /// # Errors
    ///
    /// Returns an error if the address cannot be bound (for example because
    /// the port is already in use) or if the server itself fails. Binding goes
    /// through `try_bind` so that such a failure is reported through the
    /// `Result` rather than panicking, which `Server::bind` would do.
    pub async fn serve(self) -> Result<()> {
        let state = FeedServerState {
            database: self.database.clone(),
            config: self.config.clone(),
        };

        let app = Router::new()
            .route("/", get(root))
            .route("/.well-known/did.json", get(did_json))
            .route(
                "/xrpc/app.bsky.feed.describeFeedGenerator",
                get(describe_feed_generator),
            )
            .route(
                "/xrpc/app.bsky.feed.getFeedSkeleton",
                get(get_feed_skeleton),
            )
            .with_state(state);

        let addr = SocketAddr::from(([127, 0, 0, 1], 3000));

        Server::try_bind(&addr)?
            .serve(app.into_make_service())
            .await?;

        Ok(())
    }
}

View File

@ -0,0 +1,8 @@
use crate::config::Config;
use crate::services::database::Database;
/// State made available to the feed server's request handlers.
///
/// Holds owned copies of the database handle and the configuration.
#[derive(Clone)]
pub struct FeedServerState {
    /// Database handle used by handlers to fetch posts.
    pub database: Database,
    /// Application configuration (service DID, publisher DID, hostname, ...).
    pub config: Config,
}

View File

@ -1,2 +1,3 @@
pub mod feed_server;
pub mod post_saver;
pub mod profile_classifier;

View File

@ -23,7 +23,7 @@ impl AI {
Message {
role: Role::System,
// TODO: Lol, prompt injection much?
content: "You are a tool that attempts to guess where a person is likely to be from based on their name and short bio. Please respond with two-letter country code only. Use lowercase letters.".to_string(),
content: "You are a tool that attempts to guess where a person is likely to be from based on their name and short bio. Please respond with two-letter country code only. If unable to determine, say xx.".to_string(),
},
Message {
role: Role::User,
@ -36,6 +36,6 @@ impl AI {
let response = self.chat_gpt_client.chat(chat_input).await?;
// TODO: Error handling?
return Ok(response.choices[0].message.content.clone());
return Ok(response.choices[0].message.content.to_lowercase());
}
}

View File

@ -1,15 +1,15 @@
use anyhow::Result;
use chrono::{DateTime, Utc};
use scooby::postgres::{insert_into, select, update, Parameters};
use scooby::postgres::{insert_into, select, update, Joinable, Orderable, Parameters, Aliasable};
use sqlx::postgres::{PgPool, PgPoolOptions, PgRow};
use sqlx::query;
use sqlx::Row;
pub struct Post {
indexed_at: DateTime<Utc>,
author_did: String,
cid: String,
uri: String,
pub indexed_at: DateTime<Utc>,
pub author_did: String,
pub cid: String,
pub uri: String,
}
pub struct Profile {
@ -24,6 +24,7 @@ pub struct SubscriptionState {
cursor: i64,
}
/// Handle to the application's PostgreSQL database.
///
/// Derives `Clone` so that a copy can be stored in the feed server's state.
#[derive(Clone)]
pub struct Database {
    /// Connection pool that queries are executed against.
    connection_pool: PgPool,
}
@ -52,6 +53,49 @@ impl Database {
.map(|_| ())?)
}
/// Fetches up to `limit` posts whose author's profile has
/// `likely_country_of_living` equal to `author_country`, ordered by
/// `(indexed_at, cid)` descending.
///
/// When `earlier_than` is `Some((indexed_at, cid))`, only posts that sort
/// strictly before that pair in the same ordering are returned, which makes
/// the pair usable as a pagination cursor.
///
/// # Errors
///
/// Returns an error if the query fails.
pub async fn fetch_posts_by_authors_country(
    &self,
    author_country: &str,
    limit: usize,
    earlier_than: Option<(DateTime<Utc>, &str)>,
) -> Result<Vec<Post>> {
    let mut params = Parameters::new();

    let mut sql_builder = select(("p.indexed_at", "p.author_did", "p.cid", "p.uri"))
        .from(
            "Post"
                .as_("p")
                .inner_join("Profile".as_("pr"))
                .on("pr.did = p.author_did"),
        )
        .where_(format!("pr.likely_country_of_living = {}", params.next()))
        .order_by(("p.indexed_at".desc(), "p.cid".desc()))
        .limit(limit);

    if earlier_than.is_some() {
        // Row-wise comparison matches the ORDER BY above. Two independent
        // conditions (`indexed_at <= $2 AND cid < $3`) would wrongly drop every
        // older post whose cid happens not to be smaller than the cursor's cid.
        let indexed_at_param = params.next();
        let cid_param = params.next();
        sql_builder = sql_builder.where_(format!(
            "(p.indexed_at, p.cid) < ({}, {})",
            indexed_at_param, cid_param
        ));
    }

    let sql_string = sql_builder.to_string();

    let mut query_object = query(&sql_string).bind(author_country);

    // Bind in the same order in which the placeholders were allocated.
    if let Some((last_indexed_at, last_cid)) = earlier_than {
        query_object = query_object.bind(last_indexed_at).bind(last_cid);
    }

    Ok(query_object
        .map(|r: PgRow| Post {
            indexed_at: r.get("indexed_at"),
            author_did: r.get("author_did"),
            cid: r.get("cid"),
            uri: r.get("uri"),
        })
        .fetch_all(&self.connection_pool)
        .await?)
}
pub async fn insert_profile_if_it_doesnt_exist(&self, did: &str) -> Result<bool> {
let mut params = Parameters::new();