Finished implementing TelegramSelector
This commit is contained in:
parent
282aee09e0
commit
97f84dd49b
|
@ -4,3 +4,4 @@ gpt/
|
||||||
gpt
|
gpt
|
||||||
bot_config.json
|
bot_config.json
|
||||||
fediverse.toml
|
fediverse.toml
|
||||||
|
secrets
|
|
@ -1,5 +1,7 @@
|
||||||
# This file is automatically @generated by Cargo.
|
# This file is automatically @generated by Cargo.
|
||||||
# It is not intended for manual editing.
|
# It is not intended for manual editing.
|
||||||
|
version = 3
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "addr2line"
|
name = "addr2line"
|
||||||
version = "0.15.2"
|
version = "0.15.2"
|
||||||
|
@ -15,6 +17,12 @@ version = "1.0.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
|
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "anyhow"
|
||||||
|
version = "1.0.41"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "15af2628f6890fe2609a3b91bef4c83450512802e59489f9c1cb1fa5df064a61"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "async-channel"
|
name = "async-channel"
|
||||||
version = "1.6.1"
|
version = "1.6.1"
|
||||||
|
@ -40,17 +48,6 @@ dependencies = [
|
||||||
"slab",
|
"slab",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "async-fs"
|
|
||||||
version = "1.5.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "8b3ca4f8ff117c37c278a2f7415ce9be55560b846b5bc4412aaa5d29c1c3dae2"
|
|
||||||
dependencies = [
|
|
||||||
"async-lock",
|
|
||||||
"blocking",
|
|
||||||
"futures-lite",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "async-global-executor"
|
name = "async-global-executor"
|
||||||
version = "2.0.2"
|
version = "2.0.2"
|
||||||
|
@ -104,17 +101,6 @@ dependencies = [
|
||||||
"event-listener",
|
"event-listener",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "async-net"
|
|
||||||
version = "1.6.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "5373304df79b9b4395068fb080369ec7178608827306ce4d081cba51cac551df"
|
|
||||||
dependencies = [
|
|
||||||
"async-io",
|
|
||||||
"blocking",
|
|
||||||
"futures-lite",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "async-process"
|
name = "async-process"
|
||||||
version = "1.1.0"
|
version = "1.1.0"
|
||||||
|
@ -1169,6 +1155,7 @@ checksum = "dd25036021b0de88a0aff6b850051563c6516d0bf53f8638938edbb9de732736"
|
||||||
name = "izzilis"
|
name = "izzilis"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
"async-std",
|
"async-std",
|
||||||
"chrono",
|
"chrono",
|
||||||
"elefren",
|
"elefren",
|
||||||
|
@ -1177,8 +1164,8 @@ dependencies = [
|
||||||
"rand 0.8.4",
|
"rand 0.8.4",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"smol",
|
|
||||||
"telegram-bot",
|
"telegram-bot",
|
||||||
|
"tokio 0.2.25",
|
||||||
"uuid 0.8.2",
|
"uuid 0.8.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -1343,12 +1330,35 @@ dependencies = [
|
||||||
"kernel32-sys",
|
"kernel32-sys",
|
||||||
"libc",
|
"libc",
|
||||||
"log 0.4.14",
|
"log 0.4.14",
|
||||||
"miow",
|
"miow 0.2.2",
|
||||||
"net2",
|
"net2",
|
||||||
"slab",
|
"slab",
|
||||||
"winapi 0.2.8",
|
"winapi 0.2.8",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "mio-named-pipes"
|
||||||
|
version = "0.1.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0840c1c50fd55e521b247f949c241c9997709f23bd7f023b9762cd561e935656"
|
||||||
|
dependencies = [
|
||||||
|
"log 0.4.14",
|
||||||
|
"mio",
|
||||||
|
"miow 0.3.7",
|
||||||
|
"winapi 0.3.9",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "mio-uds"
|
||||||
|
version = "0.6.8"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "afcb699eb26d4332647cc848492bbc15eafb26f08d0304550d5aa1f612e066f0"
|
||||||
|
dependencies = [
|
||||||
|
"iovec",
|
||||||
|
"libc",
|
||||||
|
"mio",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "miow"
|
name = "miow"
|
||||||
version = "0.2.2"
|
version = "0.2.2"
|
||||||
|
@ -1361,6 +1371,15 @@ dependencies = [
|
||||||
"ws2_32-sys",
|
"ws2_32-sys",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "miow"
|
||||||
|
version = "0.3.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b9f1c5b025cda876f66ef43a113f91ebc9f4ccef34843000e0adf6ebbab84e21"
|
||||||
|
dependencies = [
|
||||||
|
"winapi 0.3.9",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "multipart"
|
name = "multipart"
|
||||||
version = "0.16.1"
|
version = "0.16.1"
|
||||||
|
@ -2208,24 +2227,6 @@ dependencies = [
|
||||||
"maybe-uninit",
|
"maybe-uninit",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "smol"
|
|
||||||
version = "1.2.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "85cf3b5351f3e783c1d79ab5fc604eeed8b8ae9abd36b166e8b87a089efd85e4"
|
|
||||||
dependencies = [
|
|
||||||
"async-channel",
|
|
||||||
"async-executor",
|
|
||||||
"async-fs",
|
|
||||||
"async-io",
|
|
||||||
"async-lock",
|
|
||||||
"async-net",
|
|
||||||
"async-process",
|
|
||||||
"blocking",
|
|
||||||
"futures-lite",
|
|
||||||
"once_cell",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "socket2"
|
name = "socket2"
|
||||||
version = "0.3.19"
|
version = "0.3.19"
|
||||||
|
@ -2384,10 +2385,17 @@ dependencies = [
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"iovec",
|
"iovec",
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
|
"libc",
|
||||||
"memchr",
|
"memchr",
|
||||||
"mio",
|
"mio",
|
||||||
|
"mio-named-pipes",
|
||||||
|
"mio-uds",
|
||||||
|
"num_cpus",
|
||||||
"pin-project-lite 0.1.12",
|
"pin-project-lite 0.1.12",
|
||||||
|
"signal-hook-registry",
|
||||||
"slab",
|
"slab",
|
||||||
|
"tokio-macros",
|
||||||
|
"winapi 0.3.9",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -2432,6 +2440,17 @@ dependencies = [
|
||||||
"log 0.4.14",
|
"log 0.4.14",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tokio-macros"
|
||||||
|
version = "0.2.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e44da00bfc73a25f814cd8d7e57a68a5c31b74b3152a0a1d1f590c97ed06265a"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tokio-reactor"
|
name = "tokio-reactor"
|
||||||
version = "0.1.12"
|
version = "0.1.12"
|
||||||
|
@ -2718,6 +2737,9 @@ name = "uuid"
|
||||||
version = "0.8.2"
|
version = "0.8.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7"
|
checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7"
|
||||||
|
dependencies = [
|
||||||
|
"getrandom 0.2.3",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "value-bag"
|
name = "value-bag"
|
||||||
|
|
|
@ -1,11 +1,9 @@
|
||||||
[package]
|
[package]
|
||||||
name = "izzilis"
|
name = "izzilis"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
authors = ["Emilis <grinding@graduate.org>"]
|
authors = ["Emilis <grinding@graduate.org>", "Izzy Swart <zenerboson@gmail.com>"]
|
||||||
edition = "2018"
|
edition = "2018"
|
||||||
|
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
async-std = { version = "1.9.0", features = ["unstable"] }
|
async-std = { version = "1.9.0", features = ["unstable"] }
|
||||||
chrono = "0.4.19"
|
chrono = "0.4.19"
|
||||||
|
@ -15,6 +13,7 @@ futures-timer = "3.0.2"
|
||||||
rand = "0.8.4"
|
rand = "0.8.4"
|
||||||
serde = "1.0.126"
|
serde = "1.0.126"
|
||||||
serde_json = "1.0.64"
|
serde_json = "1.0.64"
|
||||||
smol = "1.2.5"
|
tokio = { version = "0.2", features = ["full"] }
|
||||||
telegram-bot = "0.8.0"
|
telegram-bot = "0.8.0"
|
||||||
uuid = "0.8.2"
|
uuid = { version = "0.8.2", features = ["v4"] }
|
||||||
|
anyhow = "1.0.41"
|
||||||
|
|
111
src/bot.rs
111
src/bot.rs
|
@ -1,111 +0,0 @@
|
||||||
use rand::Rng;
|
|
||||||
use std::error::Error;
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
model::{self, SampleModel},
|
|
||||||
publish, selection,
|
|
||||||
};
|
|
||||||
|
|
||||||
pub struct IzzilisBot<T: SampleModel, U: publish::Publisher, V: selection::Selector> {
|
|
||||||
model: T,
|
|
||||||
publisher: U,
|
|
||||||
selector: V,
|
|
||||||
loaded_samples: Vec<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T, U, V> IzzilisBot<T, U, V>
|
|
||||||
where
|
|
||||||
T: model::SampleModel,
|
|
||||||
U: publish::Publisher,
|
|
||||||
V: selection::Selector,
|
|
||||||
{
|
|
||||||
pub fn new(model: T, publisher: U, selector: V) -> IzzilisBot<T, U, V> {
|
|
||||||
Self {
|
|
||||||
model,
|
|
||||||
publisher,
|
|
||||||
loaded_samples: Vec::new(),
|
|
||||||
selector: selector,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn generate_samples(&mut self) -> Result<(), Box<dyn Error>> {
|
|
||||||
// let lines = self.model.generate_sample_lines()?;
|
|
||||||
// for line in lines {
|
|
||||||
// self.selector.send_for_review(line)?;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// self.loaded_samples = self.selector.collect_selected_samples(); // wtf happens to the original self.loaded_samples???????
|
|
||||||
|
|
||||||
// Ok(())
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn publish(&mut self) -> Result<(), U::Error> {
|
|
||||||
if self.loaded_samples.len() < 5 {
|
|
||||||
// Refresh samples. Either none have been generated so far,
|
|
||||||
// or generated ones are stale.
|
|
||||||
//
|
|
||||||
// This is a shit solution, but I'm going with it for v1
|
|
||||||
// purely because I don't know the language well enough to be
|
|
||||||
// confident in doing this via threads. Yet.
|
|
||||||
// TODO handle errors here
|
|
||||||
let _ = self.generate_samples();
|
|
||||||
}
|
|
||||||
let sample_index = rand::thread_rng().gen_range(0..self.loaded_samples.len() - 1);
|
|
||||||
let content = self.loaded_samples[sample_index].clone();
|
|
||||||
self.loaded_samples.remove(sample_index);
|
|
||||||
|
|
||||||
self.publisher.publish(content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(tests)]
|
|
||||||
mod tests {
|
|
||||||
use std::io::{self, ErrorKind};
|
|
||||||
|
|
||||||
use crate::{generator, model, publish};
|
|
||||||
|
|
||||||
struct fake_sampler {
|
|
||||||
should_ok: bool,
|
|
||||||
ok_str: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
struct fake_publisher {
|
|
||||||
should_ok: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl model::SampleModel for fake_sampler {
|
|
||||||
fn get_sample(&self) -> Result<String, io::Error> {
|
|
||||||
if self.should_ok {
|
|
||||||
return Ok(self.ok_str.clone());
|
|
||||||
}
|
|
||||||
Err(io::Error::new(ErrorKind::NotFound, "error"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl publish::Publisher for fake_publisher {
|
|
||||||
fn publish(&self, content: String) -> Option<Box<dyn std::error::Error>> {
|
|
||||||
if self.should_ok {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
Some(Box::new(io::Error::new(ErrorKind::NotFound, "error")))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn generate_samples_populates() {
|
|
||||||
let model_ok_string = String::from("model_ok");
|
|
||||||
let model = fake_sampler {
|
|
||||||
should_ok: true,
|
|
||||||
ok_str: model_ok_string,
|
|
||||||
};
|
|
||||||
let gen = generator::Generator::new(model);
|
|
||||||
let publish = fake_publisher { should_ok: true };
|
|
||||||
let mut bot = super::IzzilisBot::new(gen, publish);
|
|
||||||
|
|
||||||
match bot.publish() {
|
|
||||||
Some(_) => panic!("publish failed"),
|
|
||||||
None => (),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
112
src/config.rs
112
src/config.rs
|
@ -1,36 +1,29 @@
|
||||||
use std::error::Error;
|
use std::{error::Error, path::Path};
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use telegram_bot::ChatId;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
python_path: String,
|
pub python_path: String,
|
||||||
model_name: String,
|
pub model_name: String,
|
||||||
temperature: String,
|
pub temperature: String,
|
||||||
top_k: String,
|
pub top_k: String,
|
||||||
gpt_code_path: String,
|
pub gpt_code_path: String,
|
||||||
fediverse_base_url: String,
|
pub fediverse_base_url: String,
|
||||||
interval_seconds: MinMax,
|
pub interval_seconds: MinMax,
|
||||||
|
pub bot_token: String,
|
||||||
|
pub chat_ref: ChatId,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||||
pub struct MinMax {
|
pub struct MinMax {
|
||||||
min: u64,
|
pub min: u64,
|
||||||
max: u64,
|
pub max: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MinMax {
|
impl Default for Config {
|
||||||
pub fn min(&self) -> u64 {
|
fn default() -> Self {
|
||||||
self.min
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn max(&self) -> u64 {
|
|
||||||
self.max
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Config {
|
|
||||||
pub fn default() -> Config {
|
|
||||||
Config {
|
Config {
|
||||||
python_path: String::from("/usr/bin/python3"),
|
python_path: String::from("/usr/bin/python3"),
|
||||||
model_name: String::from("117M"),
|
model_name: String::from("117M"),
|
||||||
|
@ -42,60 +35,23 @@ impl Config {
|
||||||
min: 60 * 30,
|
min: 60 * 30,
|
||||||
max: 60 * 90,
|
max: 60 * 90,
|
||||||
},
|
},
|
||||||
|
bot_token: "".to_owned(),
|
||||||
|
chat_ref: ChatId::new(0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
pub fn from(path: String) -> Result<Config, Box<dyn Error>> {
|
|
||||||
let file_bytes = std::fs::read(path)?;
|
impl Config {
|
||||||
match serde_json::from_slice(&file_bytes) {
|
pub fn from<P: AsRef<Path>>(path: P) -> Result<Config, Box<dyn Error>> {
|
||||||
Ok(res) => Ok(res),
|
let file_bytes = std::fs::read(path)?;
|
||||||
Err(err) => Err(Box::new(err)),
|
|
||||||
}
|
Ok(serde_json::from_slice(&file_bytes)?)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn save(&self, path: String) -> Option<Box<dyn Error>> {
|
pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), Box<dyn Error>> {
|
||||||
let cfg_json = match serde_json::to_vec(self) {
|
let cfg_json = serde_json::to_vec(self)?;
|
||||||
Ok(res) => res,
|
std::fs::write(path, &cfg_json)?;
|
||||||
Err(err) => return Some(Box::new(err)),
|
|
||||||
};
|
Ok(())
|
||||||
|
}
|
||||||
match std::fs::write(path, &cfg_json) {
|
|
||||||
Ok(_) => None,
|
|
||||||
Err(err) => Some(Box::new(err)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a reference to the config's python path.
|
|
||||||
pub fn python_path(&self) -> String {
|
|
||||||
self.python_path.clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a reference to the config's model name.
|
|
||||||
pub fn model_name(&self) -> String {
|
|
||||||
self.model_name.clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a reference to the config's temperature.
|
|
||||||
pub fn temperature(&self) -> String {
|
|
||||||
self.temperature.clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a reference to the config's top k.
|
|
||||||
pub fn top_k(&self) -> String {
|
|
||||||
self.top_k.clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a reference to the config's gpt code path.
|
|
||||||
pub fn gpt_code_path(&self) -> String {
|
|
||||||
self.gpt_code_path.clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn fediverse_base_url(&self) -> String {
|
|
||||||
self.fediverse_base_url.clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a reference to the config's inverval seconds.
|
|
||||||
pub fn interval_seconds(&self) -> &MinMax {
|
|
||||||
&self.interval_seconds
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
144
src/main.rs
144
src/main.rs
|
@ -1,86 +1,88 @@
|
||||||
use std::{error::Error, process, time::Duration};
|
use std::{convert::Infallible, error::Error, sync::Arc};
|
||||||
|
|
||||||
use chrono::Local;
|
use telegram_bot::{Api, ChatRef};
|
||||||
use rand::Rng;
|
|
||||||
|
|
||||||
use crate::{
|
use crate::selection::{telegram::get_chat_ref, SelectorExt, TelegramSelector};
|
||||||
bot::IzzilisBot,
|
|
||||||
publish::FediversePublisher,
|
|
||||||
selection::{ConsoleSelector, SelectorExt},
|
|
||||||
};
|
|
||||||
|
|
||||||
use futures::StreamExt;
|
use futures::{sink::unfold, StreamExt};
|
||||||
use futures_timer::Delay;
|
|
||||||
use model::{EmptyModel, SampleModelExt};
|
use model::{EmptyModel, SampleModelExt};
|
||||||
|
|
||||||
mod bot;
|
|
||||||
mod config;
|
mod config;
|
||||||
mod model;
|
mod model;
|
||||||
mod publish;
|
|
||||||
mod selection;
|
mod selection;
|
||||||
|
|
||||||
const CONFIG_PATH: &str = "bot_config.json";
|
const CONFIG_PATH: &str = "bot_config.json";
|
||||||
|
|
||||||
fn main() -> Result<(), Box<dyn Error>> {
|
#[tokio::main]
|
||||||
smol::block_on(async {
|
async fn main() -> Result<(), Box<dyn Error>> {
|
||||||
let cfg = match config::Config::from(CONFIG_PATH.to_string()) {
|
let mut cfg = match config::Config::from(CONFIG_PATH) {
|
||||||
Ok(cfg) => cfg,
|
Ok(cfg) => Ok(cfg),
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
println!(
|
println!(
|
||||||
"Failed reading config at [{}], writing default",
|
"Failed reading config at [{}], writing default",
|
||||||
CONFIG_PATH
|
CONFIG_PATH
|
||||||
);
|
);
|
||||||
match config::Config::default().save(CONFIG_PATH.to_string()) {
|
config::Config::default().save(CONFIG_PATH)?;
|
||||||
Some(err) => println!("Failed writing file to {}: {}", CONFIG_PATH, err),
|
Err(anyhow::anyhow!("Invalid configuration"))
|
||||||
None => (),
|
|
||||||
}
|
|
||||||
process::exit(1);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// let mut gpt_model = ConsoleSelector.filter(
|
|
||||||
// model::GPTSampleModel::new(
|
|
||||||
// cfg.python_path(),
|
|
||||||
// cfg.gpt_code_path(),
|
|
||||||
// vec![
|
|
||||||
// "generate_unconditional_samples.py".to_string(),
|
|
||||||
// "--model_name".to_string(),
|
|
||||||
// cfg.model_name(),
|
|
||||||
// "--temperature".to_string(),
|
|
||||||
// cfg.temperature(),
|
|
||||||
// "--top_k".to_string(),
|
|
||||||
// cfg.top_k(),
|
|
||||||
// "--nsamples".to_string(),
|
|
||||||
// "1".to_string(),
|
|
||||||
// ],
|
|
||||||
// )
|
|
||||||
// .into_stream()
|
|
||||||
// .take(10),
|
|
||||||
// );
|
|
||||||
let mut gpt_model = ConsoleSelector.filter(EmptyModel.into_stream().take(5));
|
|
||||||
|
|
||||||
while let Some(Ok(sample)) = gpt_model.next().await {
|
|
||||||
println!("{}", sample);
|
|
||||||
}
|
}
|
||||||
|
}?;
|
||||||
|
|
||||||
return Ok(());
|
// let mut gpt_model = ConsoleSelector.filter(
|
||||||
// let publisher = FediversePublisher::new(cfg.fediverse_base_url())?;
|
// model::GPTSampleModel::new(
|
||||||
// // let publisher = ConsolePublisher::new();
|
// cfg.python_path(),
|
||||||
// let console_selector = ConsoleSelector::new();
|
// cfg.gpt_code_path(),
|
||||||
// let mut bot = IzzilisBot::new(gen, publisher, console_selector);
|
// vec![
|
||||||
// bot.generate_samples();
|
// "generate_unconditional_samples.py".to_string(),
|
||||||
|
// "--model_name".to_string(),
|
||||||
|
// cfg.model_name(),
|
||||||
|
// "--temperature".to_string(),
|
||||||
|
// cfg.temperature(),
|
||||||
|
// "--top_k".to_string(),
|
||||||
|
// cfg.top_k(),
|
||||||
|
// "--nsamples".to_string(),
|
||||||
|
// "1".to_string(),
|
||||||
|
// ],
|
||||||
|
// )
|
||||||
|
// .into_stream()
|
||||||
|
// .take(10),
|
||||||
|
// );
|
||||||
|
|
||||||
// let cfg_interval = cfg.interval_seconds();
|
let api = Arc::new(Api::new(
|
||||||
// loop {
|
std::env::var("TELEGRAM_BOT_KEY").expect("bot key not present"),
|
||||||
// let wait_seconds = rand::thread_rng().gen_range(cfg_interval.min()..cfg_interval.max());
|
));
|
||||||
// let wait_time = Duration::from_secs(wait_seconds);
|
|
||||||
// let now = Local::now();
|
let chat = get_chat_ref(api.clone(), None)
|
||||||
// println!("[{}] Next post is in [{}] seconds", now, wait_seconds);
|
.await
|
||||||
// Delay::new(wait_time).await;
|
.expect("could not get chat ref");
|
||||||
// match bot.publish() {
|
|
||||||
// Err(err) => println!("Got error from publish: [{}]; continuing", err),
|
let chat_ref = chat.lock().await.clone();
|
||||||
// Ok(()) => println!("publish() call successful"),
|
|
||||||
// }
|
if let ChatRef::Id(id) = chat_ref {
|
||||||
// }
|
cfg.chat_ref = id;
|
||||||
})
|
cfg.save(CONFIG_PATH)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let cfg_clone = cfg.clone();
|
||||||
|
let mut gpt_model = TelegramSelector::new(
|
||||||
|
api,
|
||||||
|
chat,
|
||||||
|
Box::pin(unfold((), move |_, chat_ref| {
|
||||||
|
let mut cfg_clone = cfg_clone.clone();
|
||||||
|
async move {
|
||||||
|
if let ChatRef::Id(id) = &chat_ref {
|
||||||
|
cfg_clone.chat_ref = id.clone();
|
||||||
|
let _ = cfg_clone.save(CONFIG_PATH);
|
||||||
|
}
|
||||||
|
Ok::<_, Infallible>(())
|
||||||
|
}
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.filter(EmptyModel.into_stream())
|
||||||
|
.take(5);
|
||||||
|
|
||||||
|
while let Some(Ok(sample)) = gpt_model.next().await {
|
||||||
|
println!("{}", sample);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,83 +0,0 @@
|
||||||
use std::{convert::Infallible, error::Error};
|
|
||||||
|
|
||||||
use elefren::{
|
|
||||||
helpers::{cli, toml},
|
|
||||||
scopes::Scopes,
|
|
||||||
status_builder::Visibility,
|
|
||||||
Language, Mastodon, MastodonClient, Registration, StatusBuilder,
|
|
||||||
};
|
|
||||||
|
|
||||||
const FEDIVERSE_TOML_PATH: &str = "fediverse.toml";
|
|
||||||
|
|
||||||
pub trait Publisher {
|
|
||||||
type Error;
|
|
||||||
|
|
||||||
fn publish(&self, content: String) -> Result<(), Self::Error>;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct FediversePublisher {
|
|
||||||
client: Mastodon,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct ConsolePublisher;
|
|
||||||
|
|
||||||
impl Publisher for ConsolePublisher {
|
|
||||||
type Error = Infallible;
|
|
||||||
|
|
||||||
fn publish(&self, content: String) -> Result<(), Self::Error> {
|
|
||||||
println!("Publishing content to stdout: {}", content);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ConsolePublisher {
|
|
||||||
pub fn new() -> ConsolePublisher {
|
|
||||||
ConsolePublisher {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl FediversePublisher {
|
|
||||||
pub fn new(fedi_url: String) -> Result<FediversePublisher, Box<dyn Error>> {
|
|
||||||
Ok(Self {
|
|
||||||
client: toml::from_file(FEDIVERSE_TOML_PATH)
|
|
||||||
.map(|data| Ok(Mastodon::from(data)))
|
|
||||||
.unwrap_or_else(|_| register(fedi_url))?,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Publisher for FediversePublisher {
|
|
||||||
type Error = Box<dyn Error>;
|
|
||||||
|
|
||||||
fn publish(&self, content: String) -> Result<(), Self::Error> {
|
|
||||||
let status = StatusBuilder::new()
|
|
||||||
.status(&content)
|
|
||||||
// .visibility(Visibility::Direct)
|
|
||||||
.visibility(Visibility::Public)
|
|
||||||
.sensitive(false)
|
|
||||||
.language(Language::Eng)
|
|
||||||
.build()
|
|
||||||
.map_err(|e| Box::new(e) as Box<dyn Error>)?;
|
|
||||||
|
|
||||||
println!("Posting status [{}] to fediverse", &content);
|
|
||||||
|
|
||||||
self.client
|
|
||||||
.new_status(status)
|
|
||||||
.map_err(|e| Box::new(e) as Box<dyn Error>)?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn register(fedi_url: String) -> Result<Mastodon, Box<dyn Error>> {
|
|
||||||
let registration = Registration::new(fedi_url)
|
|
||||||
.client_name("izzilis")
|
|
||||||
.scopes(Scopes::write_all())
|
|
||||||
.build()?;
|
|
||||||
let fediverse = cli::authenticate(registration)?;
|
|
||||||
|
|
||||||
// Save app data for using on the next run.
|
|
||||||
toml::to_file(&*fediverse, FEDIVERSE_TOML_PATH.to_string())?;
|
|
||||||
|
|
||||||
Ok(fediverse)
|
|
||||||
}
|
|
|
@ -1,8 +1,9 @@
|
||||||
use futures::{stream::BoxStream, Future, Stream, TryStreamExt};
|
use futures::{stream::BoxStream, Future, Stream, TryStreamExt};
|
||||||
|
|
||||||
mod console;
|
mod console;
|
||||||
mod telegram;
|
pub mod telegram;
|
||||||
pub use console::ConsoleSelector;
|
pub use console::ConsoleSelector;
|
||||||
|
pub use telegram::TelegramSelector;
|
||||||
|
|
||||||
pub trait Selector {
|
pub trait Selector {
|
||||||
type Error;
|
type Error;
|
||||||
|
|
|
@ -1,14 +1,20 @@
|
||||||
use async_std::future::pending;
|
use futures::{
|
||||||
use futures::{channel::oneshot::Sender, future::BoxFuture, lock::Mutex, Sink, StreamExt};
|
channel::oneshot::{self, Sender},
|
||||||
use std::{collections::HashMap, sync::Arc};
|
future::BoxFuture,
|
||||||
|
lock::Mutex,
|
||||||
|
Sink, SinkExt, StreamExt,
|
||||||
|
};
|
||||||
|
use std::{collections::HashMap, fmt::Debug, sync::Arc};
|
||||||
use telegram_bot::{
|
use telegram_bot::{
|
||||||
self, requests, Api, CallbackQuery, ChatRef, InlineKeyboardButton, InlineKeyboardMarkup,
|
self, types::requests::answer_callback_query::CanAnswerCallbackQuery, Api, ChatRef,
|
||||||
SendMessage, Update, UpdateKind,
|
DeleteMessage, EditMessageReplyMarkup, InlineKeyboardButton, InlineKeyboardMarkup, Message,
|
||||||
|
MessageKind, SendMessage, ToMessageId, ToSourceChat, Update, UpdateKind,
|
||||||
};
|
};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use super::Selector;
|
use super::Selector;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct TelegramSelector {
|
pub struct TelegramSelector {
|
||||||
client: Arc<Api>,
|
client: Arc<Api>,
|
||||||
chat_ref: Arc<Mutex<ChatRef>>,
|
chat_ref: Arc<Mutex<ChatRef>>,
|
||||||
|
@ -16,32 +22,86 @@ pub struct TelegramSelector {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TelegramSelector {
|
impl TelegramSelector {
|
||||||
pub fn new(api: Arc<Api>, chat_ref: Arc<Mutex<ChatRef>>) -> Self {
|
pub fn new<S: Sink<ChatRef> + Send + Unpin + 'static>(
|
||||||
|
api: Arc<Api>,
|
||||||
|
chat_ref: Arc<Mutex<ChatRef>>,
|
||||||
|
mut updates: S,
|
||||||
|
) -> Self
|
||||||
|
where
|
||||||
|
S::Error: Debug,
|
||||||
|
{
|
||||||
let pending: Arc<Mutex<HashMap<Uuid, Sender<bool>>>> = Arc::new(Mutex::new(HashMap::new()));
|
let pending: Arc<Mutex<HashMap<Uuid, Sender<bool>>>> = Arc::new(Mutex::new(HashMap::new()));
|
||||||
let stream = api.clone();
|
let api_clone = api.clone();
|
||||||
|
let chat_ref_clone = chat_ref.clone();
|
||||||
let pending_clone = pending.clone();
|
let pending_clone = pending.clone();
|
||||||
smol::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let stream = stream.stream();
|
let mut stream = api_clone.stream();
|
||||||
while let Some(Ok(Update {
|
while let Some(Ok(data)) = stream.next().await {
|
||||||
kind:
|
if let Update {
|
||||||
UpdateKind::CallbackQuery(CallbackQuery {
|
kind: UpdateKind::CallbackQuery(query),
|
||||||
data: Some(query), ..
|
..
|
||||||
}),
|
} = data
|
||||||
..
|
{
|
||||||
})) = stream.next().await
|
if let Some(data) = query.data.clone() {
|
||||||
{
|
let uuid_bytes = data.as_bytes().get(0..32);
|
||||||
let uuid_bytes = query.as_bytes().get(0..32);
|
let bool_byte = data.as_bytes().get(32);
|
||||||
let bool_byte = query.as_bytes().get(32);
|
if let (Some(uuid), Some(keep)) = (uuid_bytes, bool_byte) {
|
||||||
if let (Some(uuid), Some(keep)) = (uuid_bytes, bool_byte) {
|
let uuid = Uuid::parse_str(&String::from_utf8_lossy(uuid));
|
||||||
let uuid = Uuid::parse_str(&String::from_utf8_lossy(uuid));
|
if let Ok(uuid) = uuid {
|
||||||
if let Ok(uuid) = uuid {
|
if let Some(sender) = pending_clone.lock().await.remove(&uuid) {
|
||||||
if let Some(sender) = pending_clone.lock().await.get(&uuid) {
|
let keep = match *keep as char {
|
||||||
sender.send(match *keep as char {
|
't' => true,
|
||||||
't' => true,
|
_ => false,
|
||||||
_ => false,
|
};
|
||||||
});
|
let _ = sender.send(keep);
|
||||||
|
let _ = api_clone
|
||||||
|
.send(query.answer(if keep {
|
||||||
|
"Kept!"
|
||||||
|
} else {
|
||||||
|
"Discarded!"
|
||||||
|
}))
|
||||||
|
.await;
|
||||||
|
|
||||||
|
if let Some(message) = query.message {
|
||||||
|
let _ = api_clone
|
||||||
|
.send(EditMessageReplyMarkup::new(
|
||||||
|
message.to_source_chat(),
|
||||||
|
message.to_message_id(),
|
||||||
|
None::<InlineKeyboardMarkup>,
|
||||||
|
))
|
||||||
|
.await;
|
||||||
|
if !keep {
|
||||||
|
let _ = api_clone
|
||||||
|
.send(DeleteMessage::new(
|
||||||
|
message.to_source_chat(),
|
||||||
|
message.to_message_id(),
|
||||||
|
))
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else if let Update {
|
||||||
|
kind:
|
||||||
|
UpdateKind::Message(Message {
|
||||||
|
chat,
|
||||||
|
kind: MessageKind::Text { data, .. },
|
||||||
|
..
|
||||||
|
}),
|
||||||
|
..
|
||||||
|
} = data
|
||||||
|
{
|
||||||
|
if data.starts_with("/setmain") {
|
||||||
|
let new_chat_ref = ChatRef::from_chat_id(chat.id());
|
||||||
|
if let Err(e) = updates.send(new_chat_ref.clone()).await {
|
||||||
|
println!("failed to send updated chat ref: {:?}", e);
|
||||||
|
}
|
||||||
|
*chat_ref_clone.lock().await = new_chat_ref;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -64,10 +124,58 @@ impl Selector for TelegramSelector {
|
||||||
let pending = self.pending.clone();
|
let pending = self.pending.clone();
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
let chat_ref = chat_ref.lock().await.clone();
|
let chat_ref = chat_ref.lock().await.clone();
|
||||||
let message = SendMessage::new(chat_ref, message).reply_markup(
|
let uuid = Uuid::new_v4();
|
||||||
InlineKeyboardMarkup::new().add_row(vec![InlineKeyboardButton::callback("Keep")]),
|
let generate_callback = |keep| {
|
||||||
);
|
let mut buffer = vec![0u8; 33];
|
||||||
|
uuid.to_simple().encode_lower(&mut buffer);
|
||||||
|
buffer[32] = if keep { 't' } else { 'f' } as u8;
|
||||||
|
String::from_utf8_lossy(&buffer).to_string()
|
||||||
|
};
|
||||||
|
let mut message = SendMessage::new(chat_ref, message);
|
||||||
|
message.reply_markup({
|
||||||
|
let mut kb = InlineKeyboardMarkup::new();
|
||||||
|
kb.add_row(vec![
|
||||||
|
InlineKeyboardButton::callback("Keep", generate_callback(true)),
|
||||||
|
InlineKeyboardButton::callback("Discard", generate_callback(false)),
|
||||||
|
]);
|
||||||
|
kb
|
||||||
|
});
|
||||||
|
let (sender, receiver) = oneshot::channel();
|
||||||
|
pending.lock().await.insert(uuid, sender);
|
||||||
|
client.send(message).await?;
|
||||||
|
Ok(receiver.await.unwrap())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
// /setmain
|
}
|
||||||
|
|
||||||
|
pub async fn get_chat_ref(
|
||||||
|
api: Arc<Api>,
|
||||||
|
mut chat_ref: Option<ChatRef>,
|
||||||
|
) -> Result<Arc<Mutex<ChatRef>>, telegram_bot::Error> {
|
||||||
|
let mut stream = api.stream();
|
||||||
|
|
||||||
|
if chat_ref.is_none() {
|
||||||
|
while let Some(data) = stream.next().await.transpose()? {
|
||||||
|
if let Update {
|
||||||
|
kind:
|
||||||
|
UpdateKind::Message(Message {
|
||||||
|
chat,
|
||||||
|
kind: MessageKind::Text { data, .. },
|
||||||
|
..
|
||||||
|
}),
|
||||||
|
..
|
||||||
|
} = data
|
||||||
|
{
|
||||||
|
if data.starts_with("/setmain") {
|
||||||
|
let chat_ref_temp = ChatRef::from_chat_id(chat.id());
|
||||||
|
chat_ref = Some(chat_ref_temp);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let chat_ref = Arc::new(Mutex::new(chat_ref.expect("bot API failed silently")));
|
||||||
|
|
||||||
|
Ok(chat_ref)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue