izzilis/src/main.rs

154 lines
4.2 KiB
Rust
Raw Normal View History

2021-07-06 03:13:51 +01:00
use std::{convert::Infallible, error::Error, sync::Arc, time::Duration};
2021-06-25 20:48:06 +01:00
2021-07-06 03:13:51 +01:00
use async_std::io::stdin;
use futures_timer::Delay;
use mammut::{
apps::{AppBuilder, Scopes},
2021-07-06 03:41:24 +01:00
Mastodon, Registration,
2021-07-06 03:13:51 +01:00
};
use rand::Rng;
2021-07-06 02:39:09 +01:00
use telegram_bot::{Api, ChatRef};
2021-07-06 03:13:51 +01:00
use crate::{
publish::MastodonPublisher,
selection::{telegram::get_chat_ref, SelectorExt, TelegramSelector},
};
2021-07-06 03:13:51 +01:00
use futures::{channel::mpsc::channel, sink::unfold, SinkExt, StreamExt, TryStreamExt};
use model::SampleModelExt;
mod config;
mod model;
2021-07-06 03:13:51 +01:00
mod publish;
mod selection;
/// Path of the bot's configuration file: it is loaded at startup and
/// rewritten whenever a new fediverse token or Telegram chat id is persisted.
const CONFIG_PATH: &str = "bot_config.json";
2021-07-06 02:39:09 +01:00
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let mut cfg = match config::Config::from(CONFIG_PATH) {
Ok(cfg) => Ok(cfg),
Err(_) => {
println!(
"Failed reading config at [{}], writing default",
CONFIG_PATH
);
config::Config::default().save(CONFIG_PATH)?;
Err(anyhow::anyhow!("Invalid configuration"))
}
}?;
2021-07-06 03:13:51 +01:00
let app = AppBuilder {
client_name: "izzilis",
redirect_uris: "urn:ietf:wg:oauth:2.0:oob",
scopes: Scopes::Write,
website: None,
};
let mut registration = Registration::new(cfg.fediverse_base_url.clone());
registration.register(app)?;
2021-07-06 03:41:24 +01:00
let mastodon = if let Some(data) = &cfg.fediverse_token {
Mastodon::from_data(data.clone())
} else {
2021-07-06 03:13:51 +01:00
let url = registration.authorise()?;
println!("{}", url);
let mut buffer = String::new();
stdin().read_line(&mut buffer).await?;
2021-07-06 03:52:52 +01:00
let fedi = registration.create_access_token(buffer)?;
2021-07-06 03:41:24 +01:00
2021-07-06 03:52:52 +01:00
cfg.fediverse_token = Some(fedi.data.clone());
2021-07-06 03:13:51 +01:00
cfg.save(CONFIG_PATH)?;
2021-07-06 03:52:52 +01:00
fedi
2021-07-06 03:13:51 +01:00
};
let publisher = MastodonPublisher::new(mastodon);
2021-07-06 03:14:37 +01:00
let api = Arc::new(Api::new(cfg.bot_token.clone()));
2021-07-06 02:39:09 +01:00
let chat = get_chat_ref(api.clone(), None)
.await
.expect("could not get chat ref");
let chat_ref = chat.lock().await.clone();
if let ChatRef::Id(id) = chat_ref {
cfg.chat_ref = id;
cfg.save(CONFIG_PATH)?;
}
2021-07-06 03:13:51 +01:00
let (sender, receiver) = channel(cfg.post_buffer as usize);
2021-07-06 02:39:09 +01:00
let cfg_clone = cfg.clone();
2021-07-06 03:13:51 +01:00
let mut model = TelegramSelector::new(
2021-07-06 02:39:09 +01:00
api,
chat,
Box::pin(unfold((), move |_, chat_ref| {
let mut cfg_clone = cfg_clone.clone();
async move {
if let ChatRef::Id(id) = &chat_ref {
cfg_clone.chat_ref = id.clone();
let _ = cfg_clone.save(CONFIG_PATH);
}
Ok::<_, Infallible>(())
}
})),
)
2021-07-06 03:13:51 +01:00
.filter(
model::GPTSampleModel::new(
cfg.python_path.clone(),
cfg.gpt_code_path.clone(),
vec![
"generate_unconditional_samples.py".to_string(),
"--model_name".to_string(),
cfg.model_name.clone(),
"--temperature".to_string(),
cfg.temperature.clone(),
"--top_k".to_string(),
cfg.top_k.clone(),
"--nsamples".to_string(),
"1".to_string(),
],
)
2021-07-06 03:53:24 +01:00
.into_stream()
.try_filter(|message| {
let not_empty = !message.is_empty();
async move { not_empty }
}),
2021-07-06 03:13:51 +01:00
)
.map_err(|e| Box::new(e) as Box<dyn Error>);
2021-07-06 03:13:51 +01:00
tokio::spawn(async move {
sender
.sink_map_err(|e| Box::new(e) as Box<dyn Error>)
.send_all(&mut model)
.await
.expect("Broken buffer");
});
publisher
.sink_map_err(|e| Box::new(e) as Box<dyn Error>)
.send_all(
&mut receiver
.then(|item| {
let interval_seconds = cfg.interval_seconds.clone();
Box::pin(async move {
2021-07-06 03:42:16 +01:00
Delay::new(Duration::from_secs(
2021-07-06 03:13:51 +01:00
rand::thread_rng()
.gen_range(interval_seconds.min..interval_seconds.max),
))
.await;
item
})
})
.map(Ok),
)
.await?;
2021-07-06 02:39:09 +01:00
return Ok(());
}