Compare commits
	
		
			1 Commits
		
	
	
		
			master
			...
			any-select
		
	
	| Author | SHA1 | Date | 
|---|---|---|
| 
							
							
								 | 
						4b8243ed58 | 
							
								
								
									
										79
									
								
								src/main.rs
								
								
								
								
							
							
						
						
									
										79
									
								
								src/main.rs
								
								
								
								
							| 
						 | 
				
			
			@ -13,10 +13,12 @@ use crate::{
 | 
			
		|||
    config::{FediverseConfig, Publisher},
 | 
			
		||||
    publish::MastodonPublisher,
 | 
			
		||||
    publish::MisskeyPublisher,
 | 
			
		||||
    selection::{telegram::get_chat_ref, SelectorExt, TelegramSelector},
 | 
			
		||||
    selection::{telegram::get_chat_ref, AnySelector, SelectorExt, TelegramSelector},
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use futures::{SinkExt, StreamExt, TryStreamExt, channel::mpsc::channel, future::Either, sink::unfold};
 | 
			
		||||
use futures::{
 | 
			
		||||
    channel::mpsc::channel, future::Either, sink::unfold, SinkExt, StreamExt, TryStreamExt,
 | 
			
		||||
};
 | 
			
		||||
use model::SampleModelExt;
 | 
			
		||||
 | 
			
		||||
mod config;
 | 
			
		||||
| 
						 | 
				
			
			@ -40,7 +42,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
 | 
			
		|||
        }
 | 
			
		||||
    }?;
 | 
			
		||||
 | 
			
		||||
	let publisher = resolve_publisher(&mut cfg).await?;
 | 
			
		||||
    let publisher = resolve_publisher(&mut cfg).await?;
 | 
			
		||||
 | 
			
		||||
    let api = Arc::new(Api::new(cfg.bot_token.clone()));
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -66,43 +68,30 @@ async fn main() -> Result<(), Box<dyn Error>> {
 | 
			
		|||
 | 
			
		||||
    let cfg_clone = cfg.clone();
 | 
			
		||||
 | 
			
		||||
    let mut model = TelegramSelector::new(
 | 
			
		||||
        api,
 | 
			
		||||
        chat,
 | 
			
		||||
        Box::pin(unfold((), move |_, chat_ref| {
 | 
			
		||||
            let mut cfg_clone = cfg_clone.clone();
 | 
			
		||||
            async move {
 | 
			
		||||
                if let ChatRef::Id(id) = &chat_ref {
 | 
			
		||||
                    cfg_clone.chat_ref = id.clone();
 | 
			
		||||
                    let _ = cfg_clone.save(CONFIG_PATH);
 | 
			
		||||
                }
 | 
			
		||||
                Ok::<_, Infallible>(())
 | 
			
		||||
            }
 | 
			
		||||
        })),
 | 
			
		||||
    )
 | 
			
		||||
    .filter(
 | 
			
		||||
        model::GPTSampleModel::new(
 | 
			
		||||
            cfg.python_path.clone(),
 | 
			
		||||
            cfg.gpt_code_path.clone(),
 | 
			
		||||
            vec![
 | 
			
		||||
                "generate_unconditional_samples.py".to_string(),
 | 
			
		||||
                "--model_name".to_string(),
 | 
			
		||||
                cfg.model_name.clone(),
 | 
			
		||||
                "--temperature".to_string(),
 | 
			
		||||
                cfg.temperature.clone(),
 | 
			
		||||
                "--top_k".to_string(),
 | 
			
		||||
                cfg.top_k.clone(),
 | 
			
		||||
                "--nsamples".to_string(),
 | 
			
		||||
                "1".to_string(),
 | 
			
		||||
            ],
 | 
			
		||||
    let mut model = AnySelector::new()
 | 
			
		||||
        .filter(
 | 
			
		||||
            model::GPTSampleModel::new(
 | 
			
		||||
                cfg.python_path.clone(),
 | 
			
		||||
                cfg.gpt_code_path.clone(),
 | 
			
		||||
                vec![
 | 
			
		||||
                    "generate_unconditional_samples.py".to_string(),
 | 
			
		||||
                    "--model_name".to_string(),
 | 
			
		||||
                    cfg.model_name.clone(),
 | 
			
		||||
                    "--temperature".to_string(),
 | 
			
		||||
                    cfg.temperature.clone(),
 | 
			
		||||
                    "--top_k".to_string(),
 | 
			
		||||
                    cfg.top_k.clone(),
 | 
			
		||||
                    "--nsamples".to_string(),
 | 
			
		||||
                    "1".to_string(),
 | 
			
		||||
                ],
 | 
			
		||||
            )
 | 
			
		||||
            .into_stream()
 | 
			
		||||
            .try_filter(|message| {
 | 
			
		||||
                let criteria = !message.is_empty() && message.chars().count() <= 4096;
 | 
			
		||||
                async move { criteria }
 | 
			
		||||
            }),
 | 
			
		||||
        )
 | 
			
		||||
        .into_stream()
 | 
			
		||||
        .try_filter(|message| {
 | 
			
		||||
            let criteria = !message.is_empty() && message.chars().count() <= 4096;
 | 
			
		||||
            async move { criteria }
 | 
			
		||||
        }),
 | 
			
		||||
    )
 | 
			
		||||
    .map_err(|e| Box::new(e) as Box<dyn Error>);
 | 
			
		||||
        .map_err(|e| Box::new(e) as Box<dyn Error>);
 | 
			
		||||
 | 
			
		||||
    tokio::spawn(async move {
 | 
			
		||||
        loop {
 | 
			
		||||
| 
						 | 
				
			
			@ -145,9 +134,11 @@ async fn resolve_publisher(
 | 
			
		|||
    config: &mut config::Config,
 | 
			
		||||
) -> Result<Either<MisskeyPublisher, MastodonPublisher>, Box<dyn Error>> {
 | 
			
		||||
    let publisher = match &config.publisher {
 | 
			
		||||
        config::Publisher::Misskey(cfg) => {
 | 
			
		||||
            Either::Left(MisskeyPublisher::new(&cfg.base_url, cfg.token.clone(), cfg.visibility)?)
 | 
			
		||||
        }
 | 
			
		||||
        config::Publisher::Misskey(cfg) => Either::Left(MisskeyPublisher::new(
 | 
			
		||||
            &cfg.base_url,
 | 
			
		||||
            cfg.token.clone(),
 | 
			
		||||
            cfg.visibility,
 | 
			
		||||
        )?),
 | 
			
		||||
        config::Publisher::Mastodon(cfg) => {
 | 
			
		||||
            let app = AppBuilder {
 | 
			
		||||
                client_name: "izzilis",
 | 
			
		||||
| 
						 | 
				
			
			@ -158,7 +149,7 @@ async fn resolve_publisher(
 | 
			
		|||
 | 
			
		||||
            let mut registration = Registration::new(cfg.base_url.clone());
 | 
			
		||||
            registration.register(app)?;
 | 
			
		||||
			let vis = cfg.visibility;
 | 
			
		||||
            let vis = cfg.visibility;
 | 
			
		||||
 | 
			
		||||
            let mastodon = if let Some(data) = cfg.token.clone() {
 | 
			
		||||
                Mastodon::from_data(data.clone())
 | 
			
		||||
| 
						 | 
				
			
			@ -174,7 +165,7 @@ async fn resolve_publisher(
 | 
			
		|||
                config.publisher = Publisher::Mastodon(FediverseConfig {
 | 
			
		||||
                    base_url: cfg.base_url.clone(),
 | 
			
		||||
                    token: Some(fedi.data.clone()),
 | 
			
		||||
					visibility: vis.clone(),
 | 
			
		||||
                    visibility: vis.clone(),
 | 
			
		||||
                });
 | 
			
		||||
                config.save(CONFIG_PATH)?;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,23 @@
 | 
			
		|||
use std::convert::Infallible;
 | 
			
		||||
 | 
			
		||||
use futures::future::BoxFuture;
 | 
			
		||||
 | 
			
		||||
use super::Selector;
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Copy, Clone)]
 | 
			
		||||
pub struct AnySelector;
 | 
			
		||||
 | 
			
		||||
impl Selector for AnySelector {
 | 
			
		||||
    type Error = Box<Infallible>;
 | 
			
		||||
    type Response = BoxFuture<'static, Result<bool, Self::Error>>;
 | 
			
		||||
 | 
			
		||||
    fn review(&self, message: String) -> Self::Response {
 | 
			
		||||
        Box::pin(async move { Ok(message.len() != 0) })
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl AnySelector {
 | 
			
		||||
    pub fn new() -> Self {
 | 
			
		||||
        Self {}
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -1,29 +0,0 @@
 | 
			
		|||
use std::error::Error;
 | 
			
		||||
 | 
			
		||||
use async_std::io::stdin;
 | 
			
		||||
use futures::future::BoxFuture;
 | 
			
		||||
 | 
			
		||||
use super::Selector;
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Copy, Clone)]
 | 
			
		||||
pub struct ConsoleSelector;
 | 
			
		||||
 | 
			
		||||
impl Selector for ConsoleSelector {
 | 
			
		||||
    type Error = Box<dyn Error>;
 | 
			
		||||
    type Response = BoxFuture<'static, Result<bool, Self::Error>>;
 | 
			
		||||
 | 
			
		||||
    fn review(&self, message: String) -> Self::Response {
 | 
			
		||||
        println!("{} (y/N) ", message);
 | 
			
		||||
        let stdin = stdin();
 | 
			
		||||
        Box::pin(async move {
 | 
			
		||||
            let mut buffer = String::new();
 | 
			
		||||
            stdin.read_line(&mut buffer).await?;
 | 
			
		||||
            Ok(
 | 
			
		||||
                match buffer.chars().next().unwrap_or('n').to_ascii_lowercase() {
 | 
			
		||||
                    'y' => true,
 | 
			
		||||
                    _ => false,
 | 
			
		||||
                },
 | 
			
		||||
            )
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -2,9 +2,9 @@ use futures::{stream::BoxStream, Future, Stream, TryStreamExt};
 | 
			
		|||
use std::fmt::Debug;
 | 
			
		||||
use thiserror::Error;
 | 
			
		||||
 | 
			
		||||
mod console;
 | 
			
		||||
mod any;
 | 
			
		||||
pub mod telegram;
 | 
			
		||||
pub use console::ConsoleSelector;
 | 
			
		||||
pub use any::AnySelector;
 | 
			
		||||
pub use telegram::TelegramSelector;
 | 
			
		||||
 | 
			
		||||
pub trait Selector {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue