Compare commits

...

1 Commits

Author SHA1 Message Date
puffaboo 4b8243ed58 Changed to hardcoded any selector 2022-06-01 18:31:14 +01:00
4 changed files with 60 additions and 75 deletions

View File

@ -13,10 +13,12 @@ use crate::{
config::{FediverseConfig, Publisher}, config::{FediverseConfig, Publisher},
publish::MastodonPublisher, publish::MastodonPublisher,
publish::MisskeyPublisher, publish::MisskeyPublisher,
selection::{telegram::get_chat_ref, SelectorExt, TelegramSelector}, selection::{telegram::get_chat_ref, AnySelector, SelectorExt, TelegramSelector},
}; };
use futures::{SinkExt, StreamExt, TryStreamExt, channel::mpsc::channel, future::Either, sink::unfold}; use futures::{
channel::mpsc::channel, future::Either, sink::unfold, SinkExt, StreamExt, TryStreamExt,
};
use model::SampleModelExt; use model::SampleModelExt;
mod config; mod config;
@ -40,7 +42,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
} }
}?; }?;
let publisher = resolve_publisher(&mut cfg).await?; let publisher = resolve_publisher(&mut cfg).await?;
let api = Arc::new(Api::new(cfg.bot_token.clone())); let api = Arc::new(Api::new(cfg.bot_token.clone()));
@ -66,43 +68,30 @@ async fn main() -> Result<(), Box<dyn Error>> {
let cfg_clone = cfg.clone(); let cfg_clone = cfg.clone();
let mut model = TelegramSelector::new( let mut model = AnySelector::new()
api, .filter(
chat, model::GPTSampleModel::new(
Box::pin(unfold((), move |_, chat_ref| { cfg.python_path.clone(),
let mut cfg_clone = cfg_clone.clone(); cfg.gpt_code_path.clone(),
async move { vec![
if let ChatRef::Id(id) = &chat_ref { "generate_unconditional_samples.py".to_string(),
cfg_clone.chat_ref = id.clone(); "--model_name".to_string(),
let _ = cfg_clone.save(CONFIG_PATH); cfg.model_name.clone(),
} "--temperature".to_string(),
Ok::<_, Infallible>(()) cfg.temperature.clone(),
} "--top_k".to_string(),
})), cfg.top_k.clone(),
) "--nsamples".to_string(),
.filter( "1".to_string(),
model::GPTSampleModel::new( ],
cfg.python_path.clone(), )
cfg.gpt_code_path.clone(), .into_stream()
vec![ .try_filter(|message| {
"generate_unconditional_samples.py".to_string(), let criteria = !message.is_empty() && message.chars().count() <= 4096;
"--model_name".to_string(), async move { criteria }
cfg.model_name.clone(), }),
"--temperature".to_string(),
cfg.temperature.clone(),
"--top_k".to_string(),
cfg.top_k.clone(),
"--nsamples".to_string(),
"1".to_string(),
],
) )
.into_stream() .map_err(|e| Box::new(e) as Box<dyn Error>);
.try_filter(|message| {
let criteria = !message.is_empty() && message.chars().count() <= 4096;
async move { criteria }
}),
)
.map_err(|e| Box::new(e) as Box<dyn Error>);
tokio::spawn(async move { tokio::spawn(async move {
loop { loop {
@ -145,9 +134,11 @@ async fn resolve_publisher(
config: &mut config::Config, config: &mut config::Config,
) -> Result<Either<MisskeyPublisher, MastodonPublisher>, Box<dyn Error>> { ) -> Result<Either<MisskeyPublisher, MastodonPublisher>, Box<dyn Error>> {
let publisher = match &config.publisher { let publisher = match &config.publisher {
config::Publisher::Misskey(cfg) => { config::Publisher::Misskey(cfg) => Either::Left(MisskeyPublisher::new(
Either::Left(MisskeyPublisher::new(&cfg.base_url, cfg.token.clone(), cfg.visibility)?) &cfg.base_url,
} cfg.token.clone(),
cfg.visibility,
)?),
config::Publisher::Mastodon(cfg) => { config::Publisher::Mastodon(cfg) => {
let app = AppBuilder { let app = AppBuilder {
client_name: "izzilis", client_name: "izzilis",
@ -158,7 +149,7 @@ async fn resolve_publisher(
let mut registration = Registration::new(cfg.base_url.clone()); let mut registration = Registration::new(cfg.base_url.clone());
registration.register(app)?; registration.register(app)?;
let vis = cfg.visibility; let vis = cfg.visibility;
let mastodon = if let Some(data) = cfg.token.clone() { let mastodon = if let Some(data) = cfg.token.clone() {
Mastodon::from_data(data.clone()) Mastodon::from_data(data.clone())
@ -174,7 +165,7 @@ async fn resolve_publisher(
config.publisher = Publisher::Mastodon(FediverseConfig { config.publisher = Publisher::Mastodon(FediverseConfig {
base_url: cfg.base_url.clone(), base_url: cfg.base_url.clone(),
token: Some(fedi.data.clone()), token: Some(fedi.data.clone()),
visibility: vis.clone(), visibility: vis.clone(),
}); });
config.save(CONFIG_PATH)?; config.save(CONFIG_PATH)?;

23
src/selection/any.rs Normal file
View File

@ -0,0 +1,23 @@
use std::convert::Infallible;
use futures::future::BoxFuture;
use super::Selector;
#[derive(Debug, Copy, Clone)]
pub struct AnySelector;
impl Selector for AnySelector {
type Error = Box<Infallible>;
type Response = BoxFuture<'static, Result<bool, Self::Error>>;
fn review(&self, message: String) -> Self::Response {
Box::pin(async move { Ok(message.len() != 0) })
}
}
impl AnySelector {
pub fn new() -> Self {
Self {}
}
}

View File

@ -1,29 +0,0 @@
use std::error::Error;
use async_std::io::stdin;
use futures::future::BoxFuture;
use super::Selector;
#[derive(Debug, Copy, Clone)]
pub struct ConsoleSelector;
impl Selector for ConsoleSelector {
type Error = Box<dyn Error>;
type Response = BoxFuture<'static, Result<bool, Self::Error>>;
fn review(&self, message: String) -> Self::Response {
println!("{} (y/N) ", message);
let stdin = stdin();
Box::pin(async move {
let mut buffer = String::new();
stdin.read_line(&mut buffer).await?;
Ok(
match buffer.chars().next().unwrap_or('n').to_ascii_lowercase() {
'y' => true,
_ => false,
},
)
})
}
}

View File

@ -2,9 +2,9 @@ use futures::{stream::BoxStream, Future, Stream, TryStreamExt};
use std::fmt::Debug; use std::fmt::Debug;
use thiserror::Error; use thiserror::Error;
mod console; mod any;
pub mod telegram; pub mod telegram;
pub use console::ConsoleSelector; pub use any::AnySelector;
pub use telegram::TelegramSelector; pub use telegram::TelegramSelector;
pub trait Selector { pub trait Selector {