2021-07-06 02:39:09 +01:00
|
|
|
use std::{convert::Infallible, error::Error, sync::Arc};
|
2021-06-25 20:48:06 +01:00
|
|
|
|
2021-07-06 02:39:09 +01:00
|
|
|
use telegram_bot::{Api, ChatRef};
|
2021-06-25 18:59:46 +01:00
|
|
|
|
2021-07-06 02:39:09 +01:00
|
|
|
use crate::selection::{telegram::get_chat_ref, SelectorExt, TelegramSelector};
|
2021-06-28 19:57:06 +01:00
|
|
|
|
2021-07-06 02:39:09 +01:00
|
|
|
use futures::{sink::unfold, StreamExt};
|
2021-07-06 00:54:08 +01:00
|
|
|
use model::{EmptyModel, SampleModelExt};
|
2021-07-04 01:21:54 +01:00
|
|
|
|
2021-06-25 18:59:46 +01:00
|
|
|
mod config;
|
|
|
|
mod model;
|
2021-06-28 19:57:06 +01:00
|
|
|
mod selection;
|
2021-06-25 18:59:46 +01:00
|
|
|
|
|
|
|
/// Path of the bot configuration file.
///
/// `main` reads the configuration from this path and saves it back here
/// whenever a chat id becomes known. It is a relative path, so — assuming
/// `config::Config` treats it as a plain filesystem path (TODO confirm) — it is
/// resolved against the process's current working directory.
const CONFIG_PATH: &str = "bot_config.json";
|
|
|
|
|
2021-07-06 02:39:09 +01:00
|
|
|
#[tokio::main]
|
|
|
|
async fn main() -> Result<(), Box<dyn Error>> {
|
|
|
|
let mut cfg = match config::Config::from(CONFIG_PATH) {
|
|
|
|
Ok(cfg) => Ok(cfg),
|
|
|
|
Err(_) => {
|
|
|
|
println!(
|
|
|
|
"Failed reading config at [{}], writing default",
|
|
|
|
CONFIG_PATH
|
|
|
|
);
|
|
|
|
config::Config::default().save(CONFIG_PATH)?;
|
|
|
|
Err(anyhow::anyhow!("Invalid configuration"))
|
|
|
|
}
|
|
|
|
}?;
|
2021-06-25 18:59:46 +01:00
|
|
|
|
2021-07-06 02:39:09 +01:00
|
|
|
// let mut gpt_model = ConsoleSelector.filter(
|
|
|
|
// model::GPTSampleModel::new(
|
|
|
|
// cfg.python_path(),
|
|
|
|
// cfg.gpt_code_path(),
|
|
|
|
// vec![
|
|
|
|
// "generate_unconditional_samples.py".to_string(),
|
|
|
|
// "--model_name".to_string(),
|
|
|
|
// cfg.model_name(),
|
|
|
|
// "--temperature".to_string(),
|
|
|
|
// cfg.temperature(),
|
|
|
|
// "--top_k".to_string(),
|
|
|
|
// cfg.top_k(),
|
|
|
|
// "--nsamples".to_string(),
|
|
|
|
// "1".to_string(),
|
|
|
|
// ],
|
|
|
|
// )
|
|
|
|
// .into_stream()
|
|
|
|
// .take(10),
|
|
|
|
// );
|
2021-06-25 18:59:46 +01:00
|
|
|
|
2021-07-06 02:39:09 +01:00
|
|
|
let api = Arc::new(Api::new(
|
|
|
|
std::env::var("TELEGRAM_BOT_KEY").expect("bot key not present"),
|
|
|
|
));
|
|
|
|
|
|
|
|
let chat = get_chat_ref(api.clone(), None)
|
|
|
|
.await
|
|
|
|
.expect("could not get chat ref");
|
|
|
|
|
|
|
|
let chat_ref = chat.lock().await.clone();
|
|
|
|
|
|
|
|
if let ChatRef::Id(id) = chat_ref {
|
|
|
|
cfg.chat_ref = id;
|
|
|
|
cfg.save(CONFIG_PATH)?;
|
|
|
|
}
|
|
|
|
|
|
|
|
let cfg_clone = cfg.clone();
|
|
|
|
let mut gpt_model = TelegramSelector::new(
|
|
|
|
api,
|
|
|
|
chat,
|
|
|
|
Box::pin(unfold((), move |_, chat_ref| {
|
|
|
|
let mut cfg_clone = cfg_clone.clone();
|
|
|
|
async move {
|
|
|
|
if let ChatRef::Id(id) = &chat_ref {
|
|
|
|
cfg_clone.chat_ref = id.clone();
|
|
|
|
let _ = cfg_clone.save(CONFIG_PATH);
|
|
|
|
}
|
|
|
|
Ok::<_, Infallible>(())
|
|
|
|
}
|
|
|
|
})),
|
|
|
|
)
|
|
|
|
.filter(EmptyModel.into_stream())
|
|
|
|
.take(5);
|
2021-07-05 01:13:20 +01:00
|
|
|
|
2021-07-06 02:39:09 +01:00
|
|
|
while let Some(Ok(sample)) = gpt_model.next().await {
|
|
|
|
println!("{}", sample);
|
|
|
|
}
|
2021-07-05 01:13:20 +01:00
|
|
|
|
2021-07-06 02:39:09 +01:00
|
|
|
return Ok(());
|
2021-06-25 18:59:46 +01:00
|
|
|
}
|