Changed to hardcoded any selector
This commit is contained in:
parent
4fcb8a9771
commit
4b8243ed58
79
src/main.rs
79
src/main.rs
|
@ -13,10 +13,12 @@ use crate::{
|
|||
config::{FediverseConfig, Publisher},
|
||||
publish::MastodonPublisher,
|
||||
publish::MisskeyPublisher,
|
||||
selection::{telegram::get_chat_ref, SelectorExt, TelegramSelector},
|
||||
selection::{telegram::get_chat_ref, AnySelector, SelectorExt, TelegramSelector},
|
||||
};
|
||||
|
||||
use futures::{SinkExt, StreamExt, TryStreamExt, channel::mpsc::channel, future::Either, sink::unfold};
|
||||
use futures::{
|
||||
channel::mpsc::channel, future::Either, sink::unfold, SinkExt, StreamExt, TryStreamExt,
|
||||
};
|
||||
use model::SampleModelExt;
|
||||
|
||||
mod config;
|
||||
|
@ -40,7 +42,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
|||
}
|
||||
}?;
|
||||
|
||||
let publisher = resolve_publisher(&mut cfg).await?;
|
||||
let publisher = resolve_publisher(&mut cfg).await?;
|
||||
|
||||
let api = Arc::new(Api::new(cfg.bot_token.clone()));
|
||||
|
||||
|
@ -66,43 +68,30 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
|||
|
||||
let cfg_clone = cfg.clone();
|
||||
|
||||
let mut model = TelegramSelector::new(
|
||||
api,
|
||||
chat,
|
||||
Box::pin(unfold((), move |_, chat_ref| {
|
||||
let mut cfg_clone = cfg_clone.clone();
|
||||
async move {
|
||||
if let ChatRef::Id(id) = &chat_ref {
|
||||
cfg_clone.chat_ref = id.clone();
|
||||
let _ = cfg_clone.save(CONFIG_PATH);
|
||||
}
|
||||
Ok::<_, Infallible>(())
|
||||
}
|
||||
})),
|
||||
)
|
||||
.filter(
|
||||
model::GPTSampleModel::new(
|
||||
cfg.python_path.clone(),
|
||||
cfg.gpt_code_path.clone(),
|
||||
vec![
|
||||
"generate_unconditional_samples.py".to_string(),
|
||||
"--model_name".to_string(),
|
||||
cfg.model_name.clone(),
|
||||
"--temperature".to_string(),
|
||||
cfg.temperature.clone(),
|
||||
"--top_k".to_string(),
|
||||
cfg.top_k.clone(),
|
||||
"--nsamples".to_string(),
|
||||
"1".to_string(),
|
||||
],
|
||||
let mut model = AnySelector::new()
|
||||
.filter(
|
||||
model::GPTSampleModel::new(
|
||||
cfg.python_path.clone(),
|
||||
cfg.gpt_code_path.clone(),
|
||||
vec![
|
||||
"generate_unconditional_samples.py".to_string(),
|
||||
"--model_name".to_string(),
|
||||
cfg.model_name.clone(),
|
||||
"--temperature".to_string(),
|
||||
cfg.temperature.clone(),
|
||||
"--top_k".to_string(),
|
||||
cfg.top_k.clone(),
|
||||
"--nsamples".to_string(),
|
||||
"1".to_string(),
|
||||
],
|
||||
)
|
||||
.into_stream()
|
||||
.try_filter(|message| {
|
||||
let criteria = !message.is_empty() && message.chars().count() <= 4096;
|
||||
async move { criteria }
|
||||
}),
|
||||
)
|
||||
.into_stream()
|
||||
.try_filter(|message| {
|
||||
let criteria = !message.is_empty() && message.chars().count() <= 4096;
|
||||
async move { criteria }
|
||||
}),
|
||||
)
|
||||
.map_err(|e| Box::new(e) as Box<dyn Error>);
|
||||
.map_err(|e| Box::new(e) as Box<dyn Error>);
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
|
@ -145,9 +134,11 @@ async fn resolve_publisher(
|
|||
config: &mut config::Config,
|
||||
) -> Result<Either<MisskeyPublisher, MastodonPublisher>, Box<dyn Error>> {
|
||||
let publisher = match &config.publisher {
|
||||
config::Publisher::Misskey(cfg) => {
|
||||
Either::Left(MisskeyPublisher::new(&cfg.base_url, cfg.token.clone(), cfg.visibility)?)
|
||||
}
|
||||
config::Publisher::Misskey(cfg) => Either::Left(MisskeyPublisher::new(
|
||||
&cfg.base_url,
|
||||
cfg.token.clone(),
|
||||
cfg.visibility,
|
||||
)?),
|
||||
config::Publisher::Mastodon(cfg) => {
|
||||
let app = AppBuilder {
|
||||
client_name: "izzilis",
|
||||
|
@ -158,7 +149,7 @@ async fn resolve_publisher(
|
|||
|
||||
let mut registration = Registration::new(cfg.base_url.clone());
|
||||
registration.register(app)?;
|
||||
let vis = cfg.visibility;
|
||||
let vis = cfg.visibility;
|
||||
|
||||
let mastodon = if let Some(data) = cfg.token.clone() {
|
||||
Mastodon::from_data(data.clone())
|
||||
|
@ -174,7 +165,7 @@ async fn resolve_publisher(
|
|||
config.publisher = Publisher::Mastodon(FediverseConfig {
|
||||
base_url: cfg.base_url.clone(),
|
||||
token: Some(fedi.data.clone()),
|
||||
visibility: vis.clone(),
|
||||
visibility: vis.clone(),
|
||||
});
|
||||
config.save(CONFIG_PATH)?;
|
||||
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
use std::convert::Infallible;
|
||||
|
||||
use futures::future::BoxFuture;
|
||||
|
||||
use super::Selector;
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct AnySelector;
|
||||
|
||||
impl Selector for AnySelector {
|
||||
type Error = Box<Infallible>;
|
||||
type Response = BoxFuture<'static, Result<bool, Self::Error>>;
|
||||
|
||||
fn review(&self, message: String) -> Self::Response {
|
||||
Box::pin(async move { Ok(message.len() != 0) })
|
||||
}
|
||||
}
|
||||
|
||||
impl AnySelector {
|
||||
pub fn new() -> Self {
|
||||
Self {}
|
||||
}
|
||||
}
|
|
@ -1,29 +0,0 @@
|
|||
use std::error::Error;
|
||||
|
||||
use async_std::io::stdin;
|
||||
use futures::future::BoxFuture;
|
||||
|
||||
use super::Selector;
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct ConsoleSelector;
|
||||
|
||||
impl Selector for ConsoleSelector {
|
||||
type Error = Box<dyn Error>;
|
||||
type Response = BoxFuture<'static, Result<bool, Self::Error>>;
|
||||
|
||||
fn review(&self, message: String) -> Self::Response {
|
||||
println!("{} (y/N) ", message);
|
||||
let stdin = stdin();
|
||||
Box::pin(async move {
|
||||
let mut buffer = String::new();
|
||||
stdin.read_line(&mut buffer).await?;
|
||||
Ok(
|
||||
match buffer.chars().next().unwrap_or('n').to_ascii_lowercase() {
|
||||
'y' => true,
|
||||
_ => false,
|
||||
},
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -2,9 +2,9 @@ use futures::{stream::BoxStream, Future, Stream, TryStreamExt};
|
|||
use std::fmt::Debug;
|
||||
use thiserror::Error;
|
||||
|
||||
mod console;
|
||||
mod any;
|
||||
pub mod telegram;
|
||||
pub use console::ConsoleSelector;
|
||||
pub use any::AnySelector;
|
||||
pub use telegram::TelegramSelector;
|
||||
|
||||
pub trait Selector {
|
||||
|
|
Loading…
Reference in New Issue