From 97f84dd49b993060d63ed091d5b811cf53ef64f0 Mon Sep 17 00:00:00 2001 From: Izzy Swart Date: Mon, 5 Jul 2021 18:39:09 -0700 Subject: [PATCH] Finished implementing TelegramSelector --- .gitignore | 3 +- Cargo.lock | 106 ++++++++++++++---------- Cargo.toml | 9 +- src/bot.rs | 111 ------------------------- src/config.rs | 112 ++++++++----------------- src/main.rs | 144 ++++++++++++++++---------------- src/publish.rs | 83 ------------------- src/selection/mod.rs | 3 +- src/selection/telegram.rs | 170 +++++++++++++++++++++++++++++++------- 9 files changed, 318 insertions(+), 423 deletions(-) delete mode 100644 src/bot.rs delete mode 100644 src/publish.rs diff --git a/.gitignore b/.gitignore index 031cd9e..c43e075 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ gpt/ gpt bot_config.json -fediverse.toml \ No newline at end of file +fediverse.toml +secrets \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index a49643b..12d0b9d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,5 +1,7 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. +version = 3 + [[package]] name = "addr2line" version = "0.15.2" @@ -15,6 +17,12 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "anyhow" +version = "1.0.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15af2628f6890fe2609a3b91bef4c83450512802e59489f9c1cb1fa5df064a61" + [[package]] name = "async-channel" version = "1.6.1" @@ -40,17 +48,6 @@ dependencies = [ "slab", ] -[[package]] -name = "async-fs" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b3ca4f8ff117c37c278a2f7415ce9be55560b846b5bc4412aaa5d29c1c3dae2" -dependencies = [ - "async-lock", - "blocking", - "futures-lite", -] - [[package]] name = "async-global-executor" version = "2.0.2" @@ -104,17 +101,6 @@ dependencies = [ "event-listener", ] -[[package]] -name = "async-net" -version = "1.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5373304df79b9b4395068fb080369ec7178608827306ce4d081cba51cac551df" -dependencies = [ - "async-io", - "blocking", - "futures-lite", -] - [[package]] name = "async-process" version = "1.1.0" @@ -1169,6 +1155,7 @@ checksum = "dd25036021b0de88a0aff6b850051563c6516d0bf53f8638938edbb9de732736" name = "izzilis" version = "0.1.0" dependencies = [ + "anyhow", "async-std", "chrono", "elefren", @@ -1177,8 +1164,8 @@ dependencies = [ "rand 0.8.4", "serde", "serde_json", - "smol", "telegram-bot", + "tokio 0.2.25", "uuid 0.8.2", ] @@ -1343,12 +1330,35 @@ dependencies = [ "kernel32-sys", "libc", "log 0.4.14", - "miow", + "miow 0.2.2", "net2", "slab", "winapi 0.2.8", ] +[[package]] +name = "mio-named-pipes" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0840c1c50fd55e521b247f949c241c9997709f23bd7f023b9762cd561e935656" +dependencies = [ + "log 0.4.14", + "mio", + "miow 0.3.7", + "winapi 0.3.9", +] + +[[package]] +name = "mio-uds" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "afcb699eb26d4332647cc848492bbc15eafb26f08d0304550d5aa1f612e066f0" +dependencies = [ + "iovec", + "libc", + "mio", +] + [[package]] name = "miow" version = "0.2.2" @@ -1361,6 +1371,15 @@ dependencies = [ "ws2_32-sys", ] +[[package]] +name = "miow" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9f1c5b025cda876f66ef43a113f91ebc9f4ccef34843000e0adf6ebbab84e21" +dependencies = [ + "winapi 0.3.9", +] + [[package]] name = "multipart" version = "0.16.1" @@ -2208,24 +2227,6 @@ dependencies = [ "maybe-uninit", ] -[[package]] -name = "smol" -version = "1.2.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85cf3b5351f3e783c1d79ab5fc604eeed8b8ae9abd36b166e8b87a089efd85e4" -dependencies = [ - "async-channel", - "async-executor", - "async-fs", - "async-io", - "async-lock", - "async-net", - "async-process", - "blocking", - "futures-lite", - "once_cell", -] - [[package]] name = "socket2" version = "0.3.19" @@ -2384,10 +2385,17 @@ dependencies = [ "futures-core", "iovec", "lazy_static", + "libc", "memchr", "mio", + "mio-named-pipes", + "mio-uds", + "num_cpus", "pin-project-lite 0.1.12", + "signal-hook-registry", "slab", + "tokio-macros", + "winapi 0.3.9", ] [[package]] @@ -2432,6 +2440,17 @@ dependencies = [ "log 0.4.14", ] +[[package]] +name = "tokio-macros" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e44da00bfc73a25f814cd8d7e57a68a5c31b74b3152a0a1d1f590c97ed06265a" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tokio-reactor" version = "0.1.12" @@ -2718,6 +2737,9 @@ name = "uuid" version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7" +dependencies = [ + "getrandom 0.2.3", +] [[package]] name = "value-bag" diff --git a/Cargo.toml b/Cargo.toml index f06365c..46ef8dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,11 +1,9 @@ [package] name = "izzilis" version = "0.1.0" -authors = ["Emilis "] +authors = ["Emilis ", "Izzy Swart "] edition = "2018" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] async-std = { version = "1.9.0", features = ["unstable"] } chrono = "0.4.19" @@ -15,6 +13,7 @@ futures-timer = "3.0.2" rand = "0.8.4" serde = "1.0.126" serde_json = "1.0.64" -smol = "1.2.5" +tokio = { version = "0.2", features = ["full"] } telegram-bot = "0.8.0" -uuid = "0.8.2" +uuid = { version = "0.8.2", features = ["v4"] } +anyhow = "1.0.41" diff --git a/src/bot.rs b/src/bot.rs deleted file mode 100644 index 2df53a5..0000000 --- a/src/bot.rs +++ /dev/null @@ -1,111 +0,0 @@ -use rand::Rng; -use std::error::Error; - -use crate::{ - model::{self, SampleModel}, - publish, selection, -}; - -pub struct IzzilisBot { - model: T, - publisher: U, - selector: V, - loaded_samples: Vec, -} - -impl IzzilisBot -where - T: model::SampleModel, - U: publish::Publisher, - V: selection::Selector, -{ - pub fn new(model: T, publisher: U, selector: V) -> IzzilisBot { - Self { - model, - publisher, - loaded_samples: Vec::new(), - selector: selector, - } - } - - pub fn generate_samples(&mut self) -> Result<(), Box> { - // let lines = self.model.generate_sample_lines()?; - // for line in lines { - // self.selector.send_for_review(line)?; - // } - - // self.loaded_samples = self.selector.collect_selected_samples(); // wtf happens to the original self.loaded_samples??????? - - // Ok(()) - todo!() - } - - pub fn publish(&mut self) -> Result<(), U::Error> { - if self.loaded_samples.len() < 5 { - // Refresh samples. Either none have been generated so far, - // or generated ones are stale. - // - // This is a shit solution, but I'm going with it for v1 - // purely because I don't know the language well enough to be - // confident in doing this via threads. Yet. - // TODO handle errors here - let _ = self.generate_samples(); - } - let sample_index = rand::thread_rng().gen_range(0..self.loaded_samples.len() - 1); - let content = self.loaded_samples[sample_index].clone(); - self.loaded_samples.remove(sample_index); - - self.publisher.publish(content) - } -} - -#[cfg(tests)] -mod tests { - use std::io::{self, ErrorKind}; - - use crate::{generator, model, publish}; - - struct fake_sampler { - should_ok: bool, - ok_str: String, - } - - struct fake_publisher { - should_ok: bool, - } - - impl model::SampleModel for fake_sampler { - fn get_sample(&self) -> Result { - if self.should_ok { - return Ok(self.ok_str.clone()); - } - Err(io::Error::new(ErrorKind::NotFound, "error")) - } - } - - impl publish::Publisher for fake_publisher { - fn publish(&self, content: String) -> Option> { - if self.should_ok { - return None; - } - Some(Box::new(io::Error::new(ErrorKind::NotFound, "error"))) - } - } - - #[test] - fn generate_samples_populates() { - let model_ok_string = String::from("model_ok"); - let model = fake_sampler { - should_ok: true, - ok_str: model_ok_string, - }; - let gen = generator::Generator::new(model); - let publish = fake_publisher { should_ok: true }; - let mut bot = super::IzzilisBot::new(gen, publish); - - match bot.publish() { - Some(_) => panic!("publish failed"), - None => (), - } - } -} diff --git a/src/config.rs b/src/config.rs index 214d4bb..cfd09f1 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,36 +1,29 @@ -use std::error::Error; +use std::{error::Error, path::Path}; use serde::{Deserialize, Serialize}; +use telegram_bot::ChatId; -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub struct Config { - python_path: String, - model_name: String, - temperature: String, - top_k: String, - gpt_code_path: String, - fediverse_base_url: String, - interval_seconds: MinMax, + pub python_path: String, + pub model_name: String, + pub temperature: String, + pub top_k: String, + pub gpt_code_path: String, + pub fediverse_base_url: String, + pub interval_seconds: MinMax, + pub bot_token: String, + pub chat_ref: ChatId, } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub struct MinMax { - min: u64, - max: u64, + pub min: u64, + pub max: u64, } -impl MinMax { - pub fn min(&self) -> u64 { - self.min - } - - pub fn max(&self) -> u64 { - self.max - } -} - -impl Config { - pub fn default() -> Config { +impl Default for Config { + fn default() -> Self { Config { python_path: String::from("/usr/bin/python3"), model_name: String::from("117M"), @@ -42,60 +35,23 @@ impl Config { min: 60 * 30, max: 60 * 90, }, + bot_token: "".to_owned(), + chat_ref: ChatId::new(0), } } - - pub fn from(path: String) -> Result> { - let file_bytes = std::fs::read(path)?; - match serde_json::from_slice(&file_bytes) { - Ok(res) => Ok(res), - Err(err) => Err(Box::new(err)), - } - } - - pub fn save(&self, path: String) -> Option> { - let cfg_json = match serde_json::to_vec(self) { - Ok(res) => res, - Err(err) => return Some(Box::new(err)), - }; - - match std::fs::write(path, &cfg_json) { - Ok(_) => None, - Err(err) => Some(Box::new(err)), - } - } - - /// Get a reference to the config's python path. - pub fn python_path(&self) -> String { - self.python_path.clone() - } - - /// Get a reference to the config's model name. - pub fn model_name(&self) -> String { - self.model_name.clone() - } - - /// Get a reference to the config's temperature. - pub fn temperature(&self) -> String { - self.temperature.clone() - } - - /// Get a reference to the config's top k. - pub fn top_k(&self) -> String { - self.top_k.clone() - } - - /// Get a reference to the config's gpt code path. - pub fn gpt_code_path(&self) -> String { - self.gpt_code_path.clone() - } - - pub fn fediverse_base_url(&self) -> String { - self.fediverse_base_url.clone() - } - - /// Get a reference to the config's inverval seconds. - pub fn interval_seconds(&self) -> &MinMax { - &self.interval_seconds - } +} + +impl Config { + pub fn from>(path: P) -> Result> { + let file_bytes = std::fs::read(path)?; + + Ok(serde_json::from_slice(&file_bytes)?) + } + + pub fn save>(&self, path: P) -> Result<(), Box> { + let cfg_json = serde_json::to_vec(self)?; + std::fs::write(path, &cfg_json)?; + + Ok(()) + } } diff --git a/src/main.rs b/src/main.rs index 38ad79e..7ce2eff 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,86 +1,88 @@ -use std::{error::Error, process, time::Duration}; +use std::{convert::Infallible, error::Error, sync::Arc}; -use chrono::Local; -use rand::Rng; +use telegram_bot::{Api, ChatRef}; -use crate::{ - bot::IzzilisBot, - publish::FediversePublisher, - selection::{ConsoleSelector, SelectorExt}, -}; +use crate::selection::{telegram::get_chat_ref, SelectorExt, TelegramSelector}; -use futures::StreamExt; -use futures_timer::Delay; +use futures::{sink::unfold, StreamExt}; use model::{EmptyModel, SampleModelExt}; -mod bot; mod config; mod model; -mod publish; mod selection; const CONFIG_PATH: &str = "bot_config.json"; -fn main() -> Result<(), Box> { - smol::block_on(async { - let cfg = match config::Config::from(CONFIG_PATH.to_string()) { - Ok(cfg) => cfg, - Err(_) => { - println!( - "Failed reading config at [{}], writing default", - CONFIG_PATH - ); - match config::Config::default().save(CONFIG_PATH.to_string()) { - Some(err) => println!("Failed writing file to {}: {}", CONFIG_PATH, err), - None => (), - } - process::exit(1); - } - }; - - // let mut gpt_model = ConsoleSelector.filter( - // model::GPTSampleModel::new( - // cfg.python_path(), - // cfg.gpt_code_path(), - // vec![ - // "generate_unconditional_samples.py".to_string(), - // "--model_name".to_string(), - // cfg.model_name(), - // "--temperature".to_string(), - // cfg.temperature(), - // "--top_k".to_string(), - // cfg.top_k(), - // "--nsamples".to_string(), - // "1".to_string(), - // ], - // ) - // .into_stream() - // .take(10), - // ); - let mut gpt_model = ConsoleSelector.filter(EmptyModel.into_stream().take(5)); - - while let Some(Ok(sample)) = gpt_model.next().await { - println!("{}", sample); +#[tokio::main] +async fn main() -> Result<(), Box> { + let mut cfg = match config::Config::from(CONFIG_PATH) { + Ok(cfg) => Ok(cfg), + Err(_) => { + println!( + "Failed reading config at [{}], writing default", + CONFIG_PATH + ); + config::Config::default().save(CONFIG_PATH)?; + Err(anyhow::anyhow!("Invalid configuration")) } + }?; - return Ok(()); - // let publisher = FediversePublisher::new(cfg.fediverse_base_url())?; - // // let publisher = ConsolePublisher::new(); - // let console_selector = ConsoleSelector::new(); - // let mut bot = IzzilisBot::new(gen, publisher, console_selector); - // bot.generate_samples(); + // let mut gpt_model = ConsoleSelector.filter( + // model::GPTSampleModel::new( + // cfg.python_path(), + // cfg.gpt_code_path(), + // vec![ + // "generate_unconditional_samples.py".to_string(), + // "--model_name".to_string(), + // cfg.model_name(), + // "--temperature".to_string(), + // cfg.temperature(), + // "--top_k".to_string(), + // cfg.top_k(), + // "--nsamples".to_string(), + // "1".to_string(), + // ], + // ) + // .into_stream() + // .take(10), + // ); - // let cfg_interval = cfg.interval_seconds(); - // loop { - // let wait_seconds = rand::thread_rng().gen_range(cfg_interval.min()..cfg_interval.max()); - // let wait_time = Duration::from_secs(wait_seconds); - // let now = Local::now(); - // println!("[{}] Next post is in [{}] seconds", now, wait_seconds); - // Delay::new(wait_time).await; - // match bot.publish() { - // Err(err) => println!("Got error from publish: [{}]; continuing", err), - // Ok(()) => println!("publish() call successful"), - // } - // } - }) + let api = Arc::new(Api::new( + std::env::var("TELEGRAM_BOT_KEY").expect("bot key not present"), + )); + + let chat = get_chat_ref(api.clone(), None) + .await + .expect("could not get chat ref"); + + let chat_ref = chat.lock().await.clone(); + + if let ChatRef::Id(id) = chat_ref { + cfg.chat_ref = id; + cfg.save(CONFIG_PATH)?; + } + + let cfg_clone = cfg.clone(); + let mut gpt_model = TelegramSelector::new( + api, + chat, + Box::pin(unfold((), move |_, chat_ref| { + let mut cfg_clone = cfg_clone.clone(); + async move { + if let ChatRef::Id(id) = &chat_ref { + cfg_clone.chat_ref = id.clone(); + let _ = cfg_clone.save(CONFIG_PATH); + } + Ok::<_, Infallible>(()) + } + })), + ) + .filter(EmptyModel.into_stream()) + .take(5); + + while let Some(Ok(sample)) = gpt_model.next().await { + println!("{}", sample); + } + + return Ok(()); } diff --git a/src/publish.rs b/src/publish.rs deleted file mode 100644 index ff0c006..0000000 --- a/src/publish.rs +++ /dev/null @@ -1,83 +0,0 @@ -use std::{convert::Infallible, error::Error}; - -use elefren::{ - helpers::{cli, toml}, - scopes::Scopes, - status_builder::Visibility, - Language, Mastodon, MastodonClient, Registration, StatusBuilder, -}; - -const FEDIVERSE_TOML_PATH: &str = "fediverse.toml"; - -pub trait Publisher { - type Error; - - fn publish(&self, content: String) -> Result<(), Self::Error>; -} - -pub struct FediversePublisher { - client: Mastodon, -} - -pub struct ConsolePublisher; - -impl Publisher for ConsolePublisher { - type Error = Infallible; - - fn publish(&self, content: String) -> Result<(), Self::Error> { - println!("Publishing content to stdout: {}", content); - Ok(()) - } -} - -impl ConsolePublisher { - pub fn new() -> ConsolePublisher { - ConsolePublisher {} - } -} - -impl FediversePublisher { - pub fn new(fedi_url: String) -> Result> { - Ok(Self { - client: toml::from_file(FEDIVERSE_TOML_PATH) - .map(|data| Ok(Mastodon::from(data))) - .unwrap_or_else(|_| register(fedi_url))?, - }) - } -} - -impl Publisher for FediversePublisher { - type Error = Box; - - fn publish(&self, content: String) -> Result<(), Self::Error> { - let status = StatusBuilder::new() - .status(&content) - // .visibility(Visibility::Direct) - .visibility(Visibility::Public) - .sensitive(false) - .language(Language::Eng) - .build() - .map_err(|e| Box::new(e) as Box)?; - - println!("Posting status [{}] to fediverse", &content); - - self.client - .new_status(status) - .map_err(|e| Box::new(e) as Box)?; - - Ok(()) - } -} - -fn register(fedi_url: String) -> Result> { - let registration = Registration::new(fedi_url) - .client_name("izzilis") - .scopes(Scopes::write_all()) - .build()?; - let fediverse = cli::authenticate(registration)?; - - // Save app data for using on the next run. - toml::to_file(&*fediverse, FEDIVERSE_TOML_PATH.to_string())?; - - Ok(fediverse) -} diff --git a/src/selection/mod.rs b/src/selection/mod.rs index e298808..7890405 100644 --- a/src/selection/mod.rs +++ b/src/selection/mod.rs @@ -1,8 +1,9 @@ use futures::{stream::BoxStream, Future, Stream, TryStreamExt}; mod console; -mod telegram; +pub mod telegram; pub use console::ConsoleSelector; +pub use telegram::TelegramSelector; pub trait Selector { type Error; diff --git a/src/selection/telegram.rs b/src/selection/telegram.rs index f1786a4..0f30d93 100644 --- a/src/selection/telegram.rs +++ b/src/selection/telegram.rs @@ -1,14 +1,20 @@ -use async_std::future::pending; -use futures::{channel::oneshot::Sender, future::BoxFuture, lock::Mutex, Sink, StreamExt}; -use std::{collections::HashMap, sync::Arc}; +use futures::{ + channel::oneshot::{self, Sender}, + future::BoxFuture, + lock::Mutex, + Sink, SinkExt, StreamExt, +}; +use std::{collections::HashMap, fmt::Debug, sync::Arc}; use telegram_bot::{ - self, requests, Api, CallbackQuery, ChatRef, InlineKeyboardButton, InlineKeyboardMarkup, - SendMessage, Update, UpdateKind, + self, types::requests::answer_callback_query::CanAnswerCallbackQuery, Api, ChatRef, + DeleteMessage, EditMessageReplyMarkup, InlineKeyboardButton, InlineKeyboardMarkup, Message, + MessageKind, SendMessage, ToMessageId, ToSourceChat, Update, UpdateKind, }; use uuid::Uuid; use super::Selector; +#[derive(Clone)] pub struct TelegramSelector { client: Arc, chat_ref: Arc>, @@ -16,32 +22,86 @@ pub struct TelegramSelector { } impl TelegramSelector { - pub fn new(api: Arc, chat_ref: Arc>) -> Self { + pub fn new + Send + Unpin + 'static>( + api: Arc, + chat_ref: Arc>, + mut updates: S, + ) -> Self + where + S::Error: Debug, + { let pending: Arc>>> = Arc::new(Mutex::new(HashMap::new())); - let stream = api.clone(); + let api_clone = api.clone(); + let chat_ref_clone = chat_ref.clone(); let pending_clone = pending.clone(); - smol::spawn(async move { - let stream = stream.stream(); - while let Some(Ok(Update { - kind: - UpdateKind::CallbackQuery(CallbackQuery { - data: Some(query), .. - }), - .. - })) = stream.next().await - { - let uuid_bytes = query.as_bytes().get(0..32); - let bool_byte = query.as_bytes().get(32); - if let (Some(uuid), Some(keep)) = (uuid_bytes, bool_byte) { - let uuid = Uuid::parse_str(&String::from_utf8_lossy(uuid)); - if let Ok(uuid) = uuid { - if let Some(sender) = pending_clone.lock().await.get(&uuid) { - sender.send(match *keep as char { - 't' => true, - _ => false, - }); + tokio::spawn(async move { + let mut stream = api_clone.stream(); + while let Some(Ok(data)) = stream.next().await { + if let Update { + kind: UpdateKind::CallbackQuery(query), + .. + } = data + { + if let Some(data) = query.data.clone() { + let uuid_bytes = data.as_bytes().get(0..32); + let bool_byte = data.as_bytes().get(32); + if let (Some(uuid), Some(keep)) = (uuid_bytes, bool_byte) { + let uuid = Uuid::parse_str(&String::from_utf8_lossy(uuid)); + if let Ok(uuid) = uuid { + if let Some(sender) = pending_clone.lock().await.remove(&uuid) { + let keep = match *keep as char { + 't' => true, + _ => false, + }; + let _ = sender.send(keep); + let _ = api_clone + .send(query.answer(if keep { + "Kept!" + } else { + "Discarded!" + })) + .await; + + if let Some(message) = query.message { + let _ = api_clone + .send(EditMessageReplyMarkup::new( + message.to_source_chat(), + message.to_message_id(), + None::, + )) + .await; + if !keep { + let _ = api_clone + .send(DeleteMessage::new( + message.to_source_chat(), + message.to_message_id(), + )) + .await; + } + } + + continue; + } + } } } + } else if let Update { + kind: + UpdateKind::Message(Message { + chat, + kind: MessageKind::Text { data, .. }, + .. + }), + .. + } = data + { + if data.starts_with("/setmain") { + let new_chat_ref = ChatRef::from_chat_id(chat.id()); + if let Err(e) = updates.send(new_chat_ref.clone()).await { + println!("failed to send updated chat ref: {:?}", e); + } + *chat_ref_clone.lock().await = new_chat_ref; + } } } }); @@ -64,10 +124,58 @@ impl Selector for TelegramSelector { let pending = self.pending.clone(); Box::pin(async move { let chat_ref = chat_ref.lock().await.clone(); - let message = SendMessage::new(chat_ref, message).reply_markup( - InlineKeyboardMarkup::new().add_row(vec![InlineKeyboardButton::callback("Keep")]), - ); + let uuid = Uuid::new_v4(); + let generate_callback = |keep| { + let mut buffer = vec![0u8; 33]; + uuid.to_simple().encode_lower(&mut buffer); + buffer[32] = if keep { 't' } else { 'f' } as u8; + String::from_utf8_lossy(&buffer).to_string() + }; + let mut message = SendMessage::new(chat_ref, message); + message.reply_markup({ + let mut kb = InlineKeyboardMarkup::new(); + kb.add_row(vec![ + InlineKeyboardButton::callback("Keep", generate_callback(true)), + InlineKeyboardButton::callback("Discard", generate_callback(false)), + ]); + kb + }); + let (sender, receiver) = oneshot::channel(); + pending.lock().await.insert(uuid, sender); + client.send(message).await?; + Ok(receiver.await.unwrap()) }) } - // /setmain +} + +pub async fn get_chat_ref( + api: Arc, + mut chat_ref: Option, +) -> Result>, telegram_bot::Error> { + let mut stream = api.stream(); + + if chat_ref.is_none() { + while let Some(data) = stream.next().await.transpose()? { + if let Update { + kind: + UpdateKind::Message(Message { + chat, + kind: MessageKind::Text { data, .. }, + .. + }), + .. + } = data + { + if data.starts_with("/setmain") { + let chat_ref_temp = ChatRef::from_chat_id(chat.id()); + chat_ref = Some(chat_ref_temp); + break; + } + } + } + } + + let chat_ref = Arc::new(Mutex::new(chat_ref.expect("bot API failed silently"))); + + Ok(chat_ref) }