Changed to hardcoded any selector
This commit is contained in:
parent
4fcb8a9771
commit
4b8243ed58
79
src/main.rs
79
src/main.rs
|
@ -13,10 +13,12 @@ use crate::{
|
||||||
config::{FediverseConfig, Publisher},
|
config::{FediverseConfig, Publisher},
|
||||||
publish::MastodonPublisher,
|
publish::MastodonPublisher,
|
||||||
publish::MisskeyPublisher,
|
publish::MisskeyPublisher,
|
||||||
selection::{telegram::get_chat_ref, SelectorExt, TelegramSelector},
|
selection::{telegram::get_chat_ref, AnySelector, SelectorExt, TelegramSelector},
|
||||||
};
|
};
|
||||||
|
|
||||||
use futures::{SinkExt, StreamExt, TryStreamExt, channel::mpsc::channel, future::Either, sink::unfold};
|
use futures::{
|
||||||
|
channel::mpsc::channel, future::Either, sink::unfold, SinkExt, StreamExt, TryStreamExt,
|
||||||
|
};
|
||||||
use model::SampleModelExt;
|
use model::SampleModelExt;
|
||||||
|
|
||||||
mod config;
|
mod config;
|
||||||
|
@ -40,7 +42,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
||||||
}
|
}
|
||||||
}?;
|
}?;
|
||||||
|
|
||||||
let publisher = resolve_publisher(&mut cfg).await?;
|
let publisher = resolve_publisher(&mut cfg).await?;
|
||||||
|
|
||||||
let api = Arc::new(Api::new(cfg.bot_token.clone()));
|
let api = Arc::new(Api::new(cfg.bot_token.clone()));
|
||||||
|
|
||||||
|
@ -66,43 +68,30 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
||||||
|
|
||||||
let cfg_clone = cfg.clone();
|
let cfg_clone = cfg.clone();
|
||||||
|
|
||||||
let mut model = TelegramSelector::new(
|
let mut model = AnySelector::new()
|
||||||
api,
|
.filter(
|
||||||
chat,
|
model::GPTSampleModel::new(
|
||||||
Box::pin(unfold((), move |_, chat_ref| {
|
cfg.python_path.clone(),
|
||||||
let mut cfg_clone = cfg_clone.clone();
|
cfg.gpt_code_path.clone(),
|
||||||
async move {
|
vec![
|
||||||
if let ChatRef::Id(id) = &chat_ref {
|
"generate_unconditional_samples.py".to_string(),
|
||||||
cfg_clone.chat_ref = id.clone();
|
"--model_name".to_string(),
|
||||||
let _ = cfg_clone.save(CONFIG_PATH);
|
cfg.model_name.clone(),
|
||||||
}
|
"--temperature".to_string(),
|
||||||
Ok::<_, Infallible>(())
|
cfg.temperature.clone(),
|
||||||
}
|
"--top_k".to_string(),
|
||||||
})),
|
cfg.top_k.clone(),
|
||||||
)
|
"--nsamples".to_string(),
|
||||||
.filter(
|
"1".to_string(),
|
||||||
model::GPTSampleModel::new(
|
],
|
||||||
cfg.python_path.clone(),
|
)
|
||||||
cfg.gpt_code_path.clone(),
|
.into_stream()
|
||||||
vec![
|
.try_filter(|message| {
|
||||||
"generate_unconditional_samples.py".to_string(),
|
let criteria = !message.is_empty() && message.chars().count() <= 4096;
|
||||||
"--model_name".to_string(),
|
async move { criteria }
|
||||||
cfg.model_name.clone(),
|
}),
|
||||||
"--temperature".to_string(),
|
|
||||||
cfg.temperature.clone(),
|
|
||||||
"--top_k".to_string(),
|
|
||||||
cfg.top_k.clone(),
|
|
||||||
"--nsamples".to_string(),
|
|
||||||
"1".to_string(),
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
.into_stream()
|
.map_err(|e| Box::new(e) as Box<dyn Error>);
|
||||||
.try_filter(|message| {
|
|
||||||
let criteria = !message.is_empty() && message.chars().count() <= 4096;
|
|
||||||
async move { criteria }
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
.map_err(|e| Box::new(e) as Box<dyn Error>);
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
loop {
|
loop {
|
||||||
|
@ -145,9 +134,11 @@ async fn resolve_publisher(
|
||||||
config: &mut config::Config,
|
config: &mut config::Config,
|
||||||
) -> Result<Either<MisskeyPublisher, MastodonPublisher>, Box<dyn Error>> {
|
) -> Result<Either<MisskeyPublisher, MastodonPublisher>, Box<dyn Error>> {
|
||||||
let publisher = match &config.publisher {
|
let publisher = match &config.publisher {
|
||||||
config::Publisher::Misskey(cfg) => {
|
config::Publisher::Misskey(cfg) => Either::Left(MisskeyPublisher::new(
|
||||||
Either::Left(MisskeyPublisher::new(&cfg.base_url, cfg.token.clone(), cfg.visibility)?)
|
&cfg.base_url,
|
||||||
}
|
cfg.token.clone(),
|
||||||
|
cfg.visibility,
|
||||||
|
)?),
|
||||||
config::Publisher::Mastodon(cfg) => {
|
config::Publisher::Mastodon(cfg) => {
|
||||||
let app = AppBuilder {
|
let app = AppBuilder {
|
||||||
client_name: "izzilis",
|
client_name: "izzilis",
|
||||||
|
@ -158,7 +149,7 @@ async fn resolve_publisher(
|
||||||
|
|
||||||
let mut registration = Registration::new(cfg.base_url.clone());
|
let mut registration = Registration::new(cfg.base_url.clone());
|
||||||
registration.register(app)?;
|
registration.register(app)?;
|
||||||
let vis = cfg.visibility;
|
let vis = cfg.visibility;
|
||||||
|
|
||||||
let mastodon = if let Some(data) = cfg.token.clone() {
|
let mastodon = if let Some(data) = cfg.token.clone() {
|
||||||
Mastodon::from_data(data.clone())
|
Mastodon::from_data(data.clone())
|
||||||
|
@ -174,7 +165,7 @@ async fn resolve_publisher(
|
||||||
config.publisher = Publisher::Mastodon(FediverseConfig {
|
config.publisher = Publisher::Mastodon(FediverseConfig {
|
||||||
base_url: cfg.base_url.clone(),
|
base_url: cfg.base_url.clone(),
|
||||||
token: Some(fedi.data.clone()),
|
token: Some(fedi.data.clone()),
|
||||||
visibility: vis.clone(),
|
visibility: vis.clone(),
|
||||||
});
|
});
|
||||||
config.save(CONFIG_PATH)?;
|
config.save(CONFIG_PATH)?;
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,23 @@
|
||||||
|
use std::convert::Infallible;
|
||||||
|
|
||||||
|
use futures::future::BoxFuture;
|
||||||
|
|
||||||
|
use super::Selector;
|
||||||
|
|
||||||
|
#[derive(Debug, Copy, Clone)]
|
||||||
|
pub struct AnySelector;
|
||||||
|
|
||||||
|
impl Selector for AnySelector {
|
||||||
|
type Error = Box<Infallible>;
|
||||||
|
type Response = BoxFuture<'static, Result<bool, Self::Error>>;
|
||||||
|
|
||||||
|
fn review(&self, message: String) -> Self::Response {
|
||||||
|
Box::pin(async move { Ok(message.len() != 0) })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AnySelector {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,29 +0,0 @@
|
||||||
use std::error::Error;
|
|
||||||
|
|
||||||
use async_std::io::stdin;
|
|
||||||
use futures::future::BoxFuture;
|
|
||||||
|
|
||||||
use super::Selector;
|
|
||||||
|
|
||||||
#[derive(Debug, Copy, Clone)]
|
|
||||||
pub struct ConsoleSelector;
|
|
||||||
|
|
||||||
impl Selector for ConsoleSelector {
|
|
||||||
type Error = Box<dyn Error>;
|
|
||||||
type Response = BoxFuture<'static, Result<bool, Self::Error>>;
|
|
||||||
|
|
||||||
fn review(&self, message: String) -> Self::Response {
|
|
||||||
println!("{} (y/N) ", message);
|
|
||||||
let stdin = stdin();
|
|
||||||
Box::pin(async move {
|
|
||||||
let mut buffer = String::new();
|
|
||||||
stdin.read_line(&mut buffer).await?;
|
|
||||||
Ok(
|
|
||||||
match buffer.chars().next().unwrap_or('n').to_ascii_lowercase() {
|
|
||||||
'y' => true,
|
|
||||||
_ => false,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -2,9 +2,9 @@ use futures::{stream::BoxStream, Future, Stream, TryStreamExt};
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
mod console;
|
mod any;
|
||||||
pub mod telegram;
|
pub mod telegram;
|
||||||
pub use console::ConsoleSelector;
|
pub use any::AnySelector;
|
||||||
pub use telegram::TelegramSelector;
|
pub use telegram::TelegramSelector;
|
||||||
|
|
||||||
pub trait Selector {
|
pub trait Selector {
|
||||||
|
|
Loading…
Reference in New Issue