Use `Arc`s to pass shared services around to avoid dealing with lifetimes

Also implement proper language detection through lingua-rs,
because Bluesky's own language detection is really bad.
This commit is contained in:
Aleksei Voronov 2023-09-21 10:36:47 +02:00
parent 9a2a88dc6b
commit f4ee482ce7
13 changed files with 1200 additions and 62 deletions

1070
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -17,6 +17,7 @@ ciborium = "0.2.1"
dotenv = "0.15.0"
futures = "0.3.28"
libipld-core = { version = "0.16.0", features = ["serde-codec"] }
lingua = "1.5.0"
once_cell = "1.18.0"
rs-car = "0.4.1"
scooby = "0.5.0"

View File

@ -5,11 +5,10 @@ use std::collections::{HashMap, HashSet};
use anyhow::Result;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use once_cell::sync::Lazy;
use crate::services::database::{Database, Post};
use self::nederlandskie::Nederlandskie;
pub use self::nederlandskie::Nederlandskie;
#[async_trait]
pub trait Algo {
@ -24,22 +23,43 @@ pub trait Algo {
}
pub type AnyAlgo = Box<dyn Algo + Sync + Send>;
pub type AlgosMap = HashMap<&'static str, AnyAlgo>;
type AlgosMap = HashMap<String, AnyAlgo>;
static ALL_ALGOS: Lazy<AlgosMap> = Lazy::new(|| {
let mut m = AlgosMap::new();
m.insert("nederlandskie", Box::new(Nederlandskie));
m
});
pub fn iter_names() -> impl Iterator<Item = &'static str> {
ALL_ALGOS.keys().map(|s| *s)
/// Registry of algorithms that can be looked up or iterated by name.
pub struct Algos {
/// Algorithms keyed by name (`AlgosMap` is `HashMap<String, AnyAlgo>`).
algos: AlgosMap,
}
pub fn iter_all() -> impl Iterator<Item = &'static AnyAlgo> {
ALL_ALGOS.values()
impl Algos {
pub fn iter_names(&self) -> impl Iterator<Item = &str> {
self.algos.keys().map(String::as_str)
}
pub fn get_by_name(name: &str) -> Option<&'static AnyAlgo> {
ALL_ALGOS.get(name)
pub fn iter_all(&self) -> impl Iterator<Item = &AnyAlgo> {
self.algos.values()
}
pub fn get_by_name(&self, name: &str) -> Option<&AnyAlgo> {
self.algos.get(name)
}
}
/// Builder that collects named algorithms and turns them into an [`Algos`] registry.
///
/// Each call to [`AlgosBuilder::add`] consumes and returns the builder, so
/// registrations can be chained before the final [`AlgosBuilder::build`].
pub struct AlgosBuilder {
    /// Algorithms registered so far, keyed by name.
    algos: AlgosMap,
}

impl Default for AlgosBuilder {
    /// Equivalent to [`AlgosBuilder::new`]: a builder with no algorithms registered.
    fn default() -> Self {
        Self::new()
    }
}

impl AlgosBuilder {
    /// Creates a builder with no algorithms registered.
    #[must_use]
    pub fn new() -> Self {
        Self {
            algos: AlgosMap::new(),
        }
    }

    /// Registers `algo` under `name` and returns the builder for chaining.
    ///
    /// The algorithm is boxed and stored in a map keyed by an owned copy of
    /// `name`; adding a second algorithm under the same name replaces the
    /// earlier one.
    #[must_use]
    pub fn add<T: Algo + Send + Sync + 'static>(mut self, name: &str, algo: T) -> Self {
        self.algos.insert(name.to_owned(), Box::new(algo));
        self
    }

    /// Consumes the builder and returns the finished [`Algos`] registry.
    #[must_use]
    pub fn build(self) -> Algos {
        Algos { algos: self.algos }
    }
}

View File

@ -1,14 +1,25 @@
use std::collections::HashSet;
use std::sync::Arc;
use anyhow::Result;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use lingua::Language::Russian;
use lingua::LanguageDetector;
use super::Algo;
use crate::services::{database::Post, Database};
pub struct Nederlandskie;
/// Algorithm state: a shared lingua `LanguageDetector` used to decide which posts to index.
pub struct Nederlandskie {
/// Shared detector; `should_index_post` accepts a post only when the detected
/// language of its text is `Russian`.
language_detector: Arc<LanguageDetector>,
}
impl Nederlandskie {
/// Creates the algorithm around an already-built, shared language detector.
pub fn new(language_detector: Arc<LanguageDetector>) -> Self {
Self { language_detector }
}
}
/// An algorithm that serves posts written in Russian by people living in Netherlands
#[async_trait]
@ -16,12 +27,10 @@ impl Algo for Nederlandskie {
fn should_index_post(
&self,
_author_did: &str,
languages: &HashSet<String>,
_text: &str,
_languages: &HashSet<String>,
text: &str,
) -> bool {
// BlueSky gets confused a lot about Russian vs Ukrainian, so skip posts
// that may be in Ukrainian regardless of whether Russian is in the list
languages.contains("ru") && !languages.contains("uk")
self.language_detector.detect_language_of(text) == Some(Russian)
}
async fn fetch_posts(

View File

@ -2,7 +2,6 @@ use anyhow::Result;
use dotenv::dotenv;
use std::env;
#[derive(Clone)]
pub struct Config {
pub chat_gpt_api_key: String,
pub database_url: String,

View File

@ -3,8 +3,13 @@ mod config;
mod processes;
mod services;
use anyhow::Result;
use std::sync::Arc;
use anyhow::Result;
use lingua::LanguageDetectorBuilder;
use crate::algos::AlgosBuilder;
use crate::algos::Nederlandskie;
use crate::config::Config;
use crate::processes::FeedServer;
use crate::processes::PostIndexer;
@ -15,15 +20,26 @@ use crate::services::AI;
#[tokio::main]
async fn main() -> Result<()> {
let config = Config::load()?;
let config = Arc::new(Config::load()?);
let ai = AI::new(&config.chat_gpt_api_key, "https://api.openai.com");
let bluesky = Bluesky::new("https://bsky.social");
let database = Database::connect(&config.database_url).await?;
let ai = Arc::new(AI::new(&config.chat_gpt_api_key, "https://api.openai.com"));
let bluesky = Arc::new(Bluesky::new("https://bsky.social"));
let database = Arc::new(Database::connect(&config.database_url).await?);
let language_detector = Arc::new(
LanguageDetectorBuilder::from_all_languages()
.with_preloaded_language_models()
.build(),
);
let post_indexer = PostIndexer::new(&database, &bluesky);
let profile_classifier = ProfileClassifier::new(&database, &ai, &bluesky);
let feed_server = FeedServer::new(&database, &config);
let algos = Arc::new(
AlgosBuilder::new()
.add("nederlandskie", Nederlandskie::new(language_detector))
.build(),
);
let post_indexer = PostIndexer::new(database.clone(), bluesky.clone(), algos.clone());
let profile_classifier = ProfileClassifier::new(database.clone(), ai.clone(), bluesky.clone());
let feed_server = FeedServer::new(database.clone(), config.clone(), algos.clone());
tokio::try_join!(
post_indexer.start(),

View File

@ -3,7 +3,6 @@ use atrium_api::app::bsky::feed::describe_feed_generator::{
};
use axum::{extract::State, Json};
use crate::algos;
use crate::processes::feed_server::state::FeedServerState;
pub async fn describe_feed_generator(
@ -11,7 +10,9 @@ pub async fn describe_feed_generator(
) -> Json<FeedGeneratorDescription> {
Json(FeedGeneratorDescription {
did: state.config.service_did.clone(),
feeds: algos::iter_names()
feeds: state
.algos
.iter_names()
.map(|name| Feed {
uri: format!(
"at://{}/app.bsky.feed.generator/{}",

View File

@ -7,14 +7,15 @@ use axum::extract::{Query, State};
use axum::Json;
use chrono::{DateTime, TimeZone, Utc};
use crate::algos;
use crate::processes::feed_server::state::FeedServerState;
pub async fn get_feed_skeleton(
State(state): State<FeedServerState>,
query: Query<FeedSkeletonQuery>,
) -> Json<FeedSkeleton> {
let algo = algos::get_by_name(&query.feed)
let algo = state
.algos
.get_by_name(&query.feed)
.ok_or_else(|| anyhow!("Feed {} not found", query.feed))
.unwrap(); // TODO: handle error

View File

@ -1,23 +1,30 @@
use std::net::SocketAddr;
use std::sync::Arc;
use anyhow::Result;
use axum::routing::get;
use axum::{Router, Server};
use crate::algos::Algos;
use crate::config::Config;
use crate::services::Database;
use super::endpoints::{describe_feed_generator, did_json, get_feed_skeleton, root};
use super::state::FeedServerState;
pub struct FeedServer<'a> {
database: &'a Database,
config: &'a Config,
/// Feed server process; owns the shared handles that `serve` hands to the
/// request handlers as `FeedServerState`.
pub struct FeedServer {
/// Shared database handle.
database: Arc<Database>,
/// Shared application configuration.
config: Arc<Config>,
/// Shared registry of algorithms.
algos: Arc<Algos>,
}
impl<'a> FeedServer<'a> {
pub fn new(database: &'a Database, config: &'a Config) -> Self {
Self { database, config }
impl FeedServer {
pub fn new(database: Arc<Database>, config: Arc<Config>, algos: Arc<Algos>) -> Self {
Self {
database,
config,
algos,
}
}
pub async fn serve(self) -> Result<()> {
@ -33,8 +40,9 @@ impl<'a> FeedServer<'a> {
get(get_feed_skeleton),
)
.with_state(FeedServerState {
database: self.database.clone(),
config: self.config.clone(),
database: self.database,
config: self.config,
algos: self.algos,
});
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));

View File

@ -1,8 +1,12 @@
use std::sync::Arc;
use crate::algos::Algos;
use crate::config::Config;
use crate::services::Database;
#[derive(Clone)]
pub struct FeedServerState {
pub database: Database,
pub config: Config,
pub database: Arc<Database>,
pub config: Arc<Config>,
pub algos: Arc<Algos>,
}

View File

@ -1,29 +1,36 @@
use std::sync::Arc;
use anyhow::Result;
use async_trait::async_trait;
use crate::algos;
use crate::algos::Algos;
use crate::services::bluesky::{Bluesky, Operation, OperationProcessor};
use crate::services::Database;
pub struct PostIndexer<'a> {
database: &'a Database,
bluesky: &'a Bluesky,
/// Process that receives Bluesky operations and checks created posts against
/// the registered algorithms.
pub struct PostIndexer {
/// Shared database handle.
database: Arc<Database>,
/// Shared Bluesky client, used to subscribe to operations.
bluesky: Arc<Bluesky>,
/// Shared registry of algorithms consulted via `should_index_post`.
algos: Arc<Algos>,
}
impl<'a> PostIndexer<'a> {
pub fn new(database: &'a Database, bluesky: &'a Bluesky) -> Self {
Self { database, bluesky }
impl PostIndexer {
/// Creates an indexer from shared handles to the database, the Bluesky
/// client and the algorithm registry.
pub fn new(database: Arc<Database>, bluesky: Arc<Bluesky>, algos: Arc<Algos>) -> Self {
Self {
database,
bluesky,
algos,
}
}
}
impl<'a> PostIndexer<'a> {
impl PostIndexer {
/// Subscribes to Bluesky operations, passing this indexer as the
/// `OperationProcessor`, and awaits the subscription.
///
/// Any error from the subscription is propagated to the caller via `?`.
pub async fn start(&self) -> Result<()> {
Ok(self.bluesky.subscribe_to_operations(self).await?)
}
}
#[async_trait]
impl<'a> OperationProcessor for PostIndexer<'a> {
impl OperationProcessor for PostIndexer {
async fn process_operation(&self, operation: &Operation) -> Result<()> {
match operation {
Operation::CreatePost {
@ -33,7 +40,11 @@ impl<'a> OperationProcessor for PostIndexer<'a> {
languages,
text,
} => {
if algos::iter_all().any(|a| a.should_index_post(author_did, languages, text)) {
if self
.algos
.iter_all()
.any(|a| a.should_index_post(author_did, languages, text))
{
println!("received insertable post from {author_did}: {text}");
self.database

View File

@ -1,3 +1,4 @@
use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
@ -6,14 +7,14 @@ use crate::services::Bluesky;
use crate::services::Database;
use crate::services::AI;
pub struct ProfileClassifier<'a> {
database: &'a Database,
ai: &'a AI,
bluesky: &'a Bluesky,
/// Profile classifier process; holds shared handles to the services it works with.
// NOTE(review): the classification logic itself is outside this view.
pub struct ProfileClassifier {
/// Shared database handle.
database: Arc<Database>,
/// Shared AI service client.
ai: Arc<AI>,
/// Shared Bluesky client.
bluesky: Arc<Bluesky>,
}
impl<'a> ProfileClassifier<'a> {
pub fn new(database: &'a Database, ai: &'a AI, bluesky: &'a Bluesky) -> Self {
impl ProfileClassifier {
pub fn new(database: Arc<Database>, ai: Arc<AI>, bluesky: Arc<Bluesky>) -> Self {
Self {
database,
ai,

View File

@ -24,7 +24,6 @@ pub struct SubscriptionState {
cursor: i64,
}
#[derive(Clone)]
/// Database service backed by a `PgPool` connection pool.
pub struct Database {
/// Connection pool held by this service.
connection_pool: PgPool,
}