diff --git a/src/main.rs b/src/main.rs index 6107038..7f4d21b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,10 +13,12 @@ use crate::{ config::{FediverseConfig, Publisher}, publish::MastodonPublisher, publish::MisskeyPublisher, - selection::{telegram::get_chat_ref, SelectorExt, TelegramSelector}, + selection::{telegram::get_chat_ref, AnySelector, SelectorExt, TelegramSelector}, }; -use futures::{SinkExt, StreamExt, TryStreamExt, channel::mpsc::channel, future::Either, sink::unfold}; +use futures::{ + channel::mpsc::channel, future::Either, sink::unfold, SinkExt, StreamExt, TryStreamExt, +}; use model::SampleModelExt; mod config; @@ -40,7 +42,7 @@ async fn main() -> Result<(), Box> { } }?; - let publisher = resolve_publisher(&mut cfg).await?; + let publisher = resolve_publisher(&mut cfg).await?; let api = Arc::new(Api::new(cfg.bot_token.clone())); @@ -66,43 +68,30 @@ async fn main() -> Result<(), Box> { let cfg_clone = cfg.clone(); - let mut model = TelegramSelector::new( - api, - chat, - Box::pin(unfold((), move |_, chat_ref| { - let mut cfg_clone = cfg_clone.clone(); - async move { - if let ChatRef::Id(id) = &chat_ref { - cfg_clone.chat_ref = id.clone(); - let _ = cfg_clone.save(CONFIG_PATH); - } - Ok::<_, Infallible>(()) - } - })), - ) - .filter( - model::GPTSampleModel::new( - cfg.python_path.clone(), - cfg.gpt_code_path.clone(), - vec![ - "generate_unconditional_samples.py".to_string(), - "--model_name".to_string(), - cfg.model_name.clone(), - "--temperature".to_string(), - cfg.temperature.clone(), - "--top_k".to_string(), - cfg.top_k.clone(), - "--nsamples".to_string(), - "1".to_string(), - ], + let mut model = AnySelector::new() + .filter( + model::GPTSampleModel::new( + cfg.python_path.clone(), + cfg.gpt_code_path.clone(), + vec![ + "generate_unconditional_samples.py".to_string(), + "--model_name".to_string(), + cfg.model_name.clone(), + "--temperature".to_string(), + cfg.temperature.clone(), + "--top_k".to_string(), + cfg.top_k.clone(), + "--nsamples".to_string(), + "1".to_string(), + ], + ) + .into_stream() + .try_filter(|message| { + let criteria = !message.is_empty() && message.chars().count() <= 4096; + async move { criteria } + }), ) - .into_stream() - .try_filter(|message| { - let criteria = !message.is_empty() && message.chars().count() <= 4096; - async move { criteria } - }), - ) - .map_err(|e| Box::new(e) as Box); + .map_err(|e| Box::new(e) as Box); tokio::spawn(async move { loop { @@ -145,9 +134,11 @@ async fn resolve_publisher( config: &mut config::Config, ) -> Result, Box> { let publisher = match &config.publisher { - config::Publisher::Misskey(cfg) => { - Either::Left(MisskeyPublisher::new(&cfg.base_url, cfg.token.clone(), cfg.visibility)?) - } + config::Publisher::Misskey(cfg) => Either::Left(MisskeyPublisher::new( + &cfg.base_url, + cfg.token.clone(), + cfg.visibility, + )?), config::Publisher::Mastodon(cfg) => { let app = AppBuilder { client_name: "izzilis", @@ -158,7 +149,7 @@ async fn resolve_publisher( let mut registration = Registration::new(cfg.base_url.clone()); registration.register(app)?; - let vis = cfg.visibility; + let vis = cfg.visibility; let mastodon = if let Some(data) = cfg.token.clone() { Mastodon::from_data(data.clone()) @@ -174,7 +165,7 @@ async fn resolve_publisher( config.publisher = Publisher::Mastodon(FediverseConfig { base_url: cfg.base_url.clone(), token: Some(fedi.data.clone()), - visibility: vis.clone(), + visibility: vis.clone(), }); config.save(CONFIG_PATH)?; diff --git a/src/selection/any.rs b/src/selection/any.rs new file mode 100644 index 0000000..6bf9002 --- /dev/null +++ b/src/selection/any.rs @@ -0,0 +1,23 @@ +use std::convert::Infallible; + +use futures::future::BoxFuture; + +use super::Selector; + +#[derive(Debug, Copy, Clone)] +pub struct AnySelector; + +impl Selector for AnySelector { + type Error = Box; + type Response = BoxFuture<'static, Result>; + + fn review(&self, message: String) -> Self::Response { + Box::pin(async move { Ok(message.len() != 0) }) + } +} + +impl AnySelector { + pub fn new() -> Self { + Self {} + } +} diff --git a/src/selection/console.rs b/src/selection/console.rs deleted file mode 100644 index 522c609..0000000 --- a/src/selection/console.rs +++ /dev/null @@ -1,29 +0,0 @@ -use std::error::Error; - -use async_std::io::stdin; -use futures::future::BoxFuture; - -use super::Selector; - -#[derive(Debug, Copy, Clone)] -pub struct ConsoleSelector; - -impl Selector for ConsoleSelector { - type Error = Box; - type Response = BoxFuture<'static, Result>; - - fn review(&self, message: String) -> Self::Response { - println!("{} (y/N) ", message); - let stdin = stdin(); - Box::pin(async move { - let mut buffer = String::new(); - stdin.read_line(&mut buffer).await?; - Ok( - match buffer.chars().next().unwrap_or('n').to_ascii_lowercase() { - 'y' => true, - _ => false, - }, - ) - }) - } -} diff --git a/src/selection/mod.rs b/src/selection/mod.rs index ef81e09..78a5a42 100644 --- a/src/selection/mod.rs +++ b/src/selection/mod.rs @@ -2,9 +2,9 @@ use futures::{stream::BoxStream, Future, Stream, TryStreamExt}; use std::fmt::Debug; use thiserror::Error; -mod console; +mod any; pub mod telegram; -pub use console::ConsoleSelector; +pub use any::AnySelector; pub use telegram::TelegramSelector; pub trait Selector {