use std::{convert::Infallible, error::Error, sync::Arc, time::Duration}; use async_std::io::stdin; use futures_timer::Delay; use mammut::{ apps::{AppBuilder, Scopes}, Mastodon, Registration, }; use rand::Rng; use telegram_bot::{Api, ChatId, ChatRef, ToChatRef}; use crate::{ publish::MastodonPublisher, selection::{telegram::get_chat_ref, SelectorExt, TelegramSelector}, }; use futures::{channel::mpsc::channel, sink::unfold, SinkExt, StreamExt, TryStreamExt}; use model::SampleModelExt; mod config; mod model; mod publish; mod selection; const CONFIG_PATH: &str = "bot_config.json"; #[tokio::main] async fn main() -> Result<(), Box> { let mut cfg = match config::Config::from(CONFIG_PATH) { Ok(cfg) => Ok(cfg), Err(_) => { println!( "Failed reading config at [{}], writing default", CONFIG_PATH ); config::Config::default().save(CONFIG_PATH)?; Err(anyhow::anyhow!("Invalid configuration")) } }?; let app = AppBuilder { client_name: "izzilis", redirect_uris: "urn:ietf:wg:oauth:2.0:oob", scopes: Scopes::Write, website: None, }; let mut registration = Registration::new(cfg.fediverse_base_url.clone()); registration.register(app)?; let mastodon = if let Some(data) = &cfg.fediverse_token { Mastodon::from_data(data.clone()) } else { let url = registration.authorise()?; println!("{}", url); let mut buffer = String::new(); stdin().read_line(&mut buffer).await?; let fedi = registration.create_access_token(buffer)?; cfg.fediverse_token = Some(fedi.data.clone()); cfg.save(CONFIG_PATH)?; fedi }; let publisher = MastodonPublisher::new(mastodon); let api = Arc::new(Api::new(cfg.bot_token.clone())); let chat = get_chat_ref( api.clone(), if ChatId::new(0) == cfg.chat_ref { None } else { Some(cfg.chat_ref.clone().to_chat_ref()) }, ) .await .expect("could not get chat ref"); let chat_ref = chat.lock().await.clone(); if let ChatRef::Id(id) = chat_ref { cfg.chat_ref = id; cfg.save(CONFIG_PATH)?; } let (sender, receiver) = channel(cfg.post_buffer as usize); let cfg_clone = cfg.clone(); let mut model = TelegramSelector::new( api, chat, Box::pin(unfold((), move |_, chat_ref| { let mut cfg_clone = cfg_clone.clone(); async move { if let ChatRef::Id(id) = &chat_ref { cfg_clone.chat_ref = id.clone(); let _ = cfg_clone.save(CONFIG_PATH); } Ok::<_, Infallible>(()) } })), ) .filter( model::GPTSampleModel::new( cfg.python_path.clone(), cfg.gpt_code_path.clone(), vec![ "generate_unconditional_samples.py".to_string(), "--model_name".to_string(), cfg.model_name.clone(), "--temperature".to_string(), cfg.temperature.clone(), "--top_k".to_string(), cfg.top_k.clone(), "--nsamples".to_string(), "1".to_string(), ], ) .into_stream() .try_filter(|message| { let not_empty = !message.is_empty(); async move { not_empty } }), ) .map_err(|e| Box::new(e) as Box); tokio::spawn(async move { sender .sink_map_err(|e| Box::new(e) as Box) .send_all(&mut model) .await .expect("Broken buffer"); }); publisher .sink_map_err(|e| Box::new(e) as Box) .send_all( &mut receiver .then(|item| { let interval_seconds = cfg.interval_seconds.clone(); Box::pin(async move { Delay::new(Duration::from_secs( rand::thread_rng() .gen_range(interval_seconds.min..=interval_seconds.max), )) .await; item }) }) .map(Ok), ) .await?; return Ok(()); }