Use Arcs to pass stuff around to avoid dealing with lifetimes

And also implement proper language detection through lingua-rs,
because Bluesky's detection is really bad
This commit is contained in:
Aleksei Voronov 2023-09-21 10:36:47 +02:00
parent 9a2a88dc6b
commit f4ee482ce7
13 changed files with 1200 additions and 62 deletions

1070
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -17,6 +17,7 @@ ciborium = "0.2.1"
dotenv = "0.15.0" dotenv = "0.15.0"
futures = "0.3.28" futures = "0.3.28"
libipld-core = { version = "0.16.0", features = ["serde-codec"] } libipld-core = { version = "0.16.0", features = ["serde-codec"] }
lingua = "1.5.0"
once_cell = "1.18.0" once_cell = "1.18.0"
rs-car = "0.4.1" rs-car = "0.4.1"
scooby = "0.5.0" scooby = "0.5.0"

View File

@ -5,11 +5,10 @@ use std::collections::{HashMap, HashSet};
use anyhow::Result; use anyhow::Result;
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use once_cell::sync::Lazy;
use crate::services::database::{Database, Post}; use crate::services::database::{Database, Post};
use self::nederlandskie::Nederlandskie; pub use self::nederlandskie::Nederlandskie;
#[async_trait] #[async_trait]
pub trait Algo { pub trait Algo {
@ -24,22 +23,43 @@ pub trait Algo {
} }
pub type AnyAlgo = Box<dyn Algo + Sync + Send>; pub type AnyAlgo = Box<dyn Algo + Sync + Send>;
pub type AlgosMap = HashMap<&'static str, AnyAlgo>; type AlgosMap = HashMap<String, AnyAlgo>;
static ALL_ALGOS: Lazy<AlgosMap> = Lazy::new(|| { pub struct Algos {
let mut m = AlgosMap::new(); algos: AlgosMap,
m.insert("nederlandskie", Box::new(Nederlandskie));
m
});
pub fn iter_names() -> impl Iterator<Item = &'static str> {
ALL_ALGOS.keys().map(|s| *s)
} }
pub fn iter_all() -> impl Iterator<Item = &'static AnyAlgo> { impl Algos {
ALL_ALGOS.values() pub fn iter_names(&self) -> impl Iterator<Item = &str> {
self.algos.keys().map(String::as_str)
}
pub fn iter_all(&self) -> impl Iterator<Item = &AnyAlgo> {
self.algos.values()
}
pub fn get_by_name(&self, name: &str) -> Option<&AnyAlgo> {
self.algos.get(name)
}
} }
pub fn get_by_name(name: &str) -> Option<&'static AnyAlgo> { pub struct AlgosBuilder {
ALL_ALGOS.get(name) algos: AlgosMap,
}
impl AlgosBuilder {
pub fn new() -> Self {
Self {
algos: AlgosMap::new(),
}
}
pub fn add<T: Algo + Send + Sync + 'static>(mut self, name: &str, algo: T) -> Self {
self.algos.insert(name.to_owned(), Box::new(algo));
self
}
pub fn build(self) -> Algos {
Algos { algos: self.algos }
}
} }

View File

@ -1,14 +1,25 @@
use std::collections::HashSet; use std::collections::HashSet;
use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use lingua::Language::Russian;
use lingua::LanguageDetector;
use super::Algo; use super::Algo;
use crate::services::{database::Post, Database}; use crate::services::{database::Post, Database};
pub struct Nederlandskie; pub struct Nederlandskie {
language_detector: Arc<LanguageDetector>,
}
impl Nederlandskie {
pub fn new(language_detector: Arc<LanguageDetector>) -> Self {
Self { language_detector }
}
}
/// An algorithm that serves posts written in Russian by people living in Netherlands /// An algorithm that serves posts written in Russian by people living in Netherlands
#[async_trait] #[async_trait]
@ -16,12 +27,10 @@ impl Algo for Nederlandskie {
fn should_index_post( fn should_index_post(
&self, &self,
_author_did: &str, _author_did: &str,
languages: &HashSet<String>, _languages: &HashSet<String>,
_text: &str, text: &str,
) -> bool { ) -> bool {
// BlueSky gets confused a lot about Russian vs Ukrainian, so skip posts self.language_detector.detect_language_of(text) == Some(Russian)
// that may be in Ukrainian regardless of whether Russian is in the list
languages.contains("ru") && !languages.contains("uk")
} }
async fn fetch_posts( async fn fetch_posts(

View File

@ -2,7 +2,6 @@ use anyhow::Result;
use dotenv::dotenv; use dotenv::dotenv;
use std::env; use std::env;
#[derive(Clone)]
pub struct Config { pub struct Config {
pub chat_gpt_api_key: String, pub chat_gpt_api_key: String,
pub database_url: String, pub database_url: String,

View File

@ -3,8 +3,13 @@ mod config;
mod processes; mod processes;
mod services; mod services;
use anyhow::Result; use std::sync::Arc;
use anyhow::Result;
use lingua::LanguageDetectorBuilder;
use crate::algos::AlgosBuilder;
use crate::algos::Nederlandskie;
use crate::config::Config; use crate::config::Config;
use crate::processes::FeedServer; use crate::processes::FeedServer;
use crate::processes::PostIndexer; use crate::processes::PostIndexer;
@ -15,15 +20,26 @@ use crate::services::AI;
#[tokio::main] #[tokio::main]
async fn main() -> Result<()> { async fn main() -> Result<()> {
let config = Config::load()?; let config = Arc::new(Config::load()?);
let ai = AI::new(&config.chat_gpt_api_key, "https://api.openai.com"); let ai = Arc::new(AI::new(&config.chat_gpt_api_key, "https://api.openai.com"));
let bluesky = Bluesky::new("https://bsky.social"); let bluesky = Arc::new(Bluesky::new("https://bsky.social"));
let database = Database::connect(&config.database_url).await?; let database = Arc::new(Database::connect(&config.database_url).await?);
let language_detector = Arc::new(
LanguageDetectorBuilder::from_all_languages()
.with_preloaded_language_models()
.build(),
);
let post_indexer = PostIndexer::new(&database, &bluesky); let algos = Arc::new(
let profile_classifier = ProfileClassifier::new(&database, &ai, &bluesky); AlgosBuilder::new()
let feed_server = FeedServer::new(&database, &config); .add("nederlandskie", Nederlandskie::new(language_detector))
.build(),
);
let post_indexer = PostIndexer::new(database.clone(), bluesky.clone(), algos.clone());
let profile_classifier = ProfileClassifier::new(database.clone(), ai.clone(), bluesky.clone());
let feed_server = FeedServer::new(database.clone(), config.clone(), algos.clone());
tokio::try_join!( tokio::try_join!(
post_indexer.start(), post_indexer.start(),

View File

@ -3,7 +3,6 @@ use atrium_api::app::bsky::feed::describe_feed_generator::{
}; };
use axum::{extract::State, Json}; use axum::{extract::State, Json};
use crate::algos;
use crate::processes::feed_server::state::FeedServerState; use crate::processes::feed_server::state::FeedServerState;
pub async fn describe_feed_generator( pub async fn describe_feed_generator(
@ -11,7 +10,9 @@ pub async fn describe_feed_generator(
) -> Json<FeedGeneratorDescription> { ) -> Json<FeedGeneratorDescription> {
Json(FeedGeneratorDescription { Json(FeedGeneratorDescription {
did: state.config.service_did.clone(), did: state.config.service_did.clone(),
feeds: algos::iter_names() feeds: state
.algos
.iter_names()
.map(|name| Feed { .map(|name| Feed {
uri: format!( uri: format!(
"at://{}/app.bsky.feed.generator/{}", "at://{}/app.bsky.feed.generator/{}",

View File

@ -7,14 +7,15 @@ use axum::extract::{Query, State};
use axum::Json; use axum::Json;
use chrono::{DateTime, TimeZone, Utc}; use chrono::{DateTime, TimeZone, Utc};
use crate::algos;
use crate::processes::feed_server::state::FeedServerState; use crate::processes::feed_server::state::FeedServerState;
pub async fn get_feed_skeleton( pub async fn get_feed_skeleton(
State(state): State<FeedServerState>, State(state): State<FeedServerState>,
query: Query<FeedSkeletonQuery>, query: Query<FeedSkeletonQuery>,
) -> Json<FeedSkeleton> { ) -> Json<FeedSkeleton> {
let algo = algos::get_by_name(&query.feed) let algo = state
.algos
.get_by_name(&query.feed)
.ok_or_else(|| anyhow!("Feed {} not found", query.feed)) .ok_or_else(|| anyhow!("Feed {} not found", query.feed))
.unwrap(); // TODO: handle error .unwrap(); // TODO: handle error

View File

@ -1,23 +1,30 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use axum::routing::get; use axum::routing::get;
use axum::{Router, Server}; use axum::{Router, Server};
use crate::algos::Algos;
use crate::config::Config; use crate::config::Config;
use crate::services::Database; use crate::services::Database;
use super::endpoints::{describe_feed_generator, did_json, get_feed_skeleton, root}; use super::endpoints::{describe_feed_generator, did_json, get_feed_skeleton, root};
use super::state::FeedServerState; use super::state::FeedServerState;
pub struct FeedServer<'a> { pub struct FeedServer {
database: &'a Database, database: Arc<Database>,
config: &'a Config, config: Arc<Config>,
algos: Arc<Algos>,
} }
impl<'a> FeedServer<'a> { impl FeedServer {
pub fn new(database: &'a Database, config: &'a Config) -> Self { pub fn new(database: Arc<Database>, config: Arc<Config>, algos: Arc<Algos>) -> Self {
Self { database, config } Self {
database,
config,
algos,
}
} }
pub async fn serve(self) -> Result<()> { pub async fn serve(self) -> Result<()> {
@ -33,8 +40,9 @@ impl<'a> FeedServer<'a> {
get(get_feed_skeleton), get(get_feed_skeleton),
) )
.with_state(FeedServerState { .with_state(FeedServerState {
database: self.database.clone(), database: self.database,
config: self.config.clone(), config: self.config,
algos: self.algos,
}); });
let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); let addr = SocketAddr::from(([127, 0, 0, 1], 3000));

View File

@ -1,8 +1,12 @@
use std::sync::Arc;
use crate::algos::Algos;
use crate::config::Config; use crate::config::Config;
use crate::services::Database; use crate::services::Database;
#[derive(Clone)] #[derive(Clone)]
pub struct FeedServerState { pub struct FeedServerState {
pub database: Database, pub database: Arc<Database>,
pub config: Config, pub config: Arc<Config>,
pub algos: Arc<Algos>,
} }

View File

@ -1,29 +1,36 @@
use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use async_trait::async_trait; use async_trait::async_trait;
use crate::algos; use crate::algos::Algos;
use crate::services::bluesky::{Bluesky, Operation, OperationProcessor}; use crate::services::bluesky::{Bluesky, Operation, OperationProcessor};
use crate::services::Database; use crate::services::Database;
pub struct PostIndexer<'a> { pub struct PostIndexer {
database: &'a Database, database: Arc<Database>,
bluesky: &'a Bluesky, bluesky: Arc<Bluesky>,
algos: Arc<Algos>,
} }
impl<'a> PostIndexer<'a> { impl PostIndexer {
pub fn new(database: &'a Database, bluesky: &'a Bluesky) -> Self { pub fn new(database: Arc<Database>, bluesky: Arc<Bluesky>, algos: Arc<Algos>) -> Self {
Self { database, bluesky } Self {
database,
bluesky,
algos,
}
} }
} }
impl<'a> PostIndexer<'a> { impl PostIndexer {
pub async fn start(&self) -> Result<()> { pub async fn start(&self) -> Result<()> {
Ok(self.bluesky.subscribe_to_operations(self).await?) Ok(self.bluesky.subscribe_to_operations(self).await?)
} }
} }
#[async_trait] #[async_trait]
impl<'a> OperationProcessor for PostIndexer<'a> { impl OperationProcessor for PostIndexer {
async fn process_operation(&self, operation: &Operation) -> Result<()> { async fn process_operation(&self, operation: &Operation) -> Result<()> {
match operation { match operation {
Operation::CreatePost { Operation::CreatePost {
@ -33,7 +40,11 @@ impl<'a> OperationProcessor for PostIndexer<'a> {
languages, languages,
text, text,
} => { } => {
if algos::iter_all().any(|a| a.should_index_post(author_did, languages, text)) { if self
.algos
.iter_all()
.any(|a| a.should_index_post(author_did, languages, text))
{
println!("received insertable post from {author_did}: {text}"); println!("received insertable post from {author_did}: {text}");
self.database self.database

View File

@ -1,3 +1,4 @@
use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use anyhow::Result; use anyhow::Result;
@ -6,14 +7,14 @@ use crate::services::Bluesky;
use crate::services::Database; use crate::services::Database;
use crate::services::AI; use crate::services::AI;
pub struct ProfileClassifier<'a> { pub struct ProfileClassifier {
database: &'a Database, database: Arc<Database>,
ai: &'a AI, ai: Arc<AI>,
bluesky: &'a Bluesky, bluesky: Arc<Bluesky>,
} }
impl<'a> ProfileClassifier<'a> { impl ProfileClassifier {
pub fn new(database: &'a Database, ai: &'a AI, bluesky: &'a Bluesky) -> Self { pub fn new(database: Arc<Database>, ai: Arc<AI>, bluesky: Arc<Bluesky>) -> Self {
Self { Self {
database, database,
ai, ai,

View File

@ -24,7 +24,6 @@ pub struct SubscriptionState {
cursor: i64, cursor: i64,
} }
#[derive(Clone)]
pub struct Database { pub struct Database {
connection_pool: PgPool, connection_pool: PgPool,
} }