use std::{convert::Infallible, error::Error, sync::Arc}; use telegram_bot::{Api, ChatRef}; use crate::selection::{telegram::get_chat_ref, SelectorExt, TelegramSelector}; use futures::{sink::unfold, StreamExt}; use model::{EmptyModel, SampleModelExt}; mod config; mod model; mod selection; const CONFIG_PATH: &str = "bot_config.json"; #[tokio::main] async fn main() -> Result<(), Box> { let mut cfg = match config::Config::from(CONFIG_PATH) { Ok(cfg) => Ok(cfg), Err(_) => { println!( "Failed reading config at [{}], writing default", CONFIG_PATH ); config::Config::default().save(CONFIG_PATH)?; Err(anyhow::anyhow!("Invalid configuration")) } }?; // let mut gpt_model = ConsoleSelector.filter( // model::GPTSampleModel::new( // cfg.python_path(), // cfg.gpt_code_path(), // vec![ // "generate_unconditional_samples.py".to_string(), // "--model_name".to_string(), // cfg.model_name(), // "--temperature".to_string(), // cfg.temperature(), // "--top_k".to_string(), // cfg.top_k(), // "--nsamples".to_string(), // "1".to_string(), // ], // ) // .into_stream() // .take(10), // ); let api = Arc::new(Api::new( std::env::var("TELEGRAM_BOT_KEY").expect("bot key not present"), )); let chat = get_chat_ref(api.clone(), None) .await .expect("could not get chat ref"); let chat_ref = chat.lock().await.clone(); if let ChatRef::Id(id) = chat_ref { cfg.chat_ref = id; cfg.save(CONFIG_PATH)?; } let cfg_clone = cfg.clone(); let mut gpt_model = TelegramSelector::new( api, chat, Box::pin(unfold((), move |_, chat_ref| { let mut cfg_clone = cfg_clone.clone(); async move { if let ChatRef::Id(id) = &chat_ref { cfg_clone.chat_ref = id.clone(); let _ = cfg_clone.save(CONFIG_PATH); } Ok::<_, Infallible>(()) } })), ) .filter(EmptyModel.into_stream()) .take(5); while let Some(Ok(sample)) = gpt_model.next().await { println!("{}", sample); } return Ok(()); }