Compare commits
No commits in common. "any-selector" and "master" have entirely different histories.
any-select
...
master
79
src/main.rs
79
src/main.rs
|
@ -13,12 +13,10 @@ use crate::{
|
||||||
config::{FediverseConfig, Publisher},
|
config::{FediverseConfig, Publisher},
|
||||||
publish::MastodonPublisher,
|
publish::MastodonPublisher,
|
||||||
publish::MisskeyPublisher,
|
publish::MisskeyPublisher,
|
||||||
selection::{telegram::get_chat_ref, AnySelector, SelectorExt, TelegramSelector},
|
selection::{telegram::get_chat_ref, SelectorExt, TelegramSelector},
|
||||||
};
|
};
|
||||||
|
|
||||||
use futures::{
|
use futures::{SinkExt, StreamExt, TryStreamExt, channel::mpsc::channel, future::Either, sink::unfold};
|
||||||
channel::mpsc::channel, future::Either, sink::unfold, SinkExt, StreamExt, TryStreamExt,
|
|
||||||
};
|
|
||||||
use model::SampleModelExt;
|
use model::SampleModelExt;
|
||||||
|
|
||||||
mod config;
|
mod config;
|
||||||
|
@ -42,7 +40,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
||||||
}
|
}
|
||||||
}?;
|
}?;
|
||||||
|
|
||||||
let publisher = resolve_publisher(&mut cfg).await?;
|
let publisher = resolve_publisher(&mut cfg).await?;
|
||||||
|
|
||||||
let api = Arc::new(Api::new(cfg.bot_token.clone()));
|
let api = Arc::new(Api::new(cfg.bot_token.clone()));
|
||||||
|
|
||||||
|
@ -68,30 +66,43 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
||||||
|
|
||||||
let cfg_clone = cfg.clone();
|
let cfg_clone = cfg.clone();
|
||||||
|
|
||||||
let mut model = AnySelector::new()
|
let mut model = TelegramSelector::new(
|
||||||
.filter(
|
api,
|
||||||
model::GPTSampleModel::new(
|
chat,
|
||||||
cfg.python_path.clone(),
|
Box::pin(unfold((), move |_, chat_ref| {
|
||||||
cfg.gpt_code_path.clone(),
|
let mut cfg_clone = cfg_clone.clone();
|
||||||
vec![
|
async move {
|
||||||
"generate_unconditional_samples.py".to_string(),
|
if let ChatRef::Id(id) = &chat_ref {
|
||||||
"--model_name".to_string(),
|
cfg_clone.chat_ref = id.clone();
|
||||||
cfg.model_name.clone(),
|
let _ = cfg_clone.save(CONFIG_PATH);
|
||||||
"--temperature".to_string(),
|
}
|
||||||
cfg.temperature.clone(),
|
Ok::<_, Infallible>(())
|
||||||
"--top_k".to_string(),
|
}
|
||||||
cfg.top_k.clone(),
|
})),
|
||||||
"--nsamples".to_string(),
|
)
|
||||||
"1".to_string(),
|
.filter(
|
||||||
],
|
model::GPTSampleModel::new(
|
||||||
)
|
cfg.python_path.clone(),
|
||||||
.into_stream()
|
cfg.gpt_code_path.clone(),
|
||||||
.try_filter(|message| {
|
vec![
|
||||||
let criteria = !message.is_empty() && message.chars().count() <= 4096;
|
"generate_unconditional_samples.py".to_string(),
|
||||||
async move { criteria }
|
"--model_name".to_string(),
|
||||||
}),
|
cfg.model_name.clone(),
|
||||||
|
"--temperature".to_string(),
|
||||||
|
cfg.temperature.clone(),
|
||||||
|
"--top_k".to_string(),
|
||||||
|
cfg.top_k.clone(),
|
||||||
|
"--nsamples".to_string(),
|
||||||
|
"1".to_string(),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
.map_err(|e| Box::new(e) as Box<dyn Error>);
|
.into_stream()
|
||||||
|
.try_filter(|message| {
|
||||||
|
let criteria = !message.is_empty() && message.chars().count() <= 4096;
|
||||||
|
async move { criteria }
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.map_err(|e| Box::new(e) as Box<dyn Error>);
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
loop {
|
loop {
|
||||||
|
@ -134,11 +145,9 @@ async fn resolve_publisher(
|
||||||
config: &mut config::Config,
|
config: &mut config::Config,
|
||||||
) -> Result<Either<MisskeyPublisher, MastodonPublisher>, Box<dyn Error>> {
|
) -> Result<Either<MisskeyPublisher, MastodonPublisher>, Box<dyn Error>> {
|
||||||
let publisher = match &config.publisher {
|
let publisher = match &config.publisher {
|
||||||
config::Publisher::Misskey(cfg) => Either::Left(MisskeyPublisher::new(
|
config::Publisher::Misskey(cfg) => {
|
||||||
&cfg.base_url,
|
Either::Left(MisskeyPublisher::new(&cfg.base_url, cfg.token.clone(), cfg.visibility)?)
|
||||||
cfg.token.clone(),
|
}
|
||||||
cfg.visibility,
|
|
||||||
)?),
|
|
||||||
config::Publisher::Mastodon(cfg) => {
|
config::Publisher::Mastodon(cfg) => {
|
||||||
let app = AppBuilder {
|
let app = AppBuilder {
|
||||||
client_name: "izzilis",
|
client_name: "izzilis",
|
||||||
|
@ -149,7 +158,7 @@ async fn resolve_publisher(
|
||||||
|
|
||||||
let mut registration = Registration::new(cfg.base_url.clone());
|
let mut registration = Registration::new(cfg.base_url.clone());
|
||||||
registration.register(app)?;
|
registration.register(app)?;
|
||||||
let vis = cfg.visibility;
|
let vis = cfg.visibility;
|
||||||
|
|
||||||
let mastodon = if let Some(data) = cfg.token.clone() {
|
let mastodon = if let Some(data) = cfg.token.clone() {
|
||||||
Mastodon::from_data(data.clone())
|
Mastodon::from_data(data.clone())
|
||||||
|
@ -165,7 +174,7 @@ async fn resolve_publisher(
|
||||||
config.publisher = Publisher::Mastodon(FediverseConfig {
|
config.publisher = Publisher::Mastodon(FediverseConfig {
|
||||||
base_url: cfg.base_url.clone(),
|
base_url: cfg.base_url.clone(),
|
||||||
token: Some(fedi.data.clone()),
|
token: Some(fedi.data.clone()),
|
||||||
visibility: vis.clone(),
|
visibility: vis.clone(),
|
||||||
});
|
});
|
||||||
config.save(CONFIG_PATH)?;
|
config.save(CONFIG_PATH)?;
|
||||||
|
|
||||||
|
|
|
@ -1,23 +0,0 @@
|
||||||
use std::convert::Infallible;
|
|
||||||
|
|
||||||
use futures::future::BoxFuture;
|
|
||||||
|
|
||||||
use super::Selector;
|
|
||||||
|
|
||||||
#[derive(Debug, Copy, Clone)]
|
|
||||||
pub struct AnySelector;
|
|
||||||
|
|
||||||
impl Selector for AnySelector {
|
|
||||||
type Error = Box<Infallible>;
|
|
||||||
type Response = BoxFuture<'static, Result<bool, Self::Error>>;
|
|
||||||
|
|
||||||
fn review(&self, message: String) -> Self::Response {
|
|
||||||
Box::pin(async move { Ok(message.len() != 0) })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AnySelector {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self {}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
use std::error::Error;
|
||||||
|
|
||||||
|
use async_std::io::stdin;
|
||||||
|
use futures::future::BoxFuture;
|
||||||
|
|
||||||
|
use super::Selector;
|
||||||
|
|
||||||
|
#[derive(Debug, Copy, Clone)]
|
||||||
|
pub struct ConsoleSelector;
|
||||||
|
|
||||||
|
impl Selector for ConsoleSelector {
|
||||||
|
type Error = Box<dyn Error>;
|
||||||
|
type Response = BoxFuture<'static, Result<bool, Self::Error>>;
|
||||||
|
|
||||||
|
fn review(&self, message: String) -> Self::Response {
|
||||||
|
println!("{} (y/N) ", message);
|
||||||
|
let stdin = stdin();
|
||||||
|
Box::pin(async move {
|
||||||
|
let mut buffer = String::new();
|
||||||
|
stdin.read_line(&mut buffer).await?;
|
||||||
|
Ok(
|
||||||
|
match buffer.chars().next().unwrap_or('n').to_ascii_lowercase() {
|
||||||
|
'y' => true,
|
||||||
|
_ => false,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -2,9 +2,9 @@ use futures::{stream::BoxStream, Future, Stream, TryStreamExt};
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
mod any;
|
mod console;
|
||||||
pub mod telegram;
|
pub mod telegram;
|
||||||
pub use any::AnySelector;
|
pub use console::ConsoleSelector;
|
||||||
pub use telegram::TelegramSelector;
|
pub use telegram::TelegramSelector;
|
||||||
|
|
||||||
pub trait Selector {
|
pub trait Selector {
|
||||||
|
|
Loading…
Reference in New Issue