New selector API, console implementation

This commit is contained in:
Izzy Swart 2021-07-04 17:43:11 -07:00
parent f513f1d591
commit 61cd40269a
6 changed files with 132 additions and 518 deletions

432
Cargo.lock generated
View File

@ -15,21 +15,6 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "aho-corasick"
version = "0.7.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e37cfd5e7657ada45f742d6e99ca5788580b5c529dc78faf11ece6dc702656f"
dependencies = [
"memchr",
]
[[package]]
name = "ascii"
version = "0.8.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97be891acc47ca214468e09425d02cef3af2c94d0d82081cd02061f996802f14"
[[package]]
name = "async-channel"
version = "1.6.1"
@ -91,7 +76,7 @@ dependencies = [
"concurrent-queue",
"futures-lite",
"libc",
"log 0.4.14",
"log",
"once_cell",
"parking",
"polling",
@ -165,7 +150,7 @@ dependencies = [
"futures-lite",
"gloo-timers",
"kv-log-macro",
"log 0.4.14",
"log",
"memchr",
"num_cpus",
"once_cell",
@ -239,12 +224,6 @@ version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b41b7ea54a0c9d92199de89e20e58d49f02f8e699814ef3fdf266f6f748d15c7"
[[package]]
name = "base64"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd"
[[package]]
name = "bitflags"
version = "1.2.1"
@ -286,16 +265,6 @@ dependencies = [
"once_cell",
]
[[package]]
name = "buf_redux"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b953a6887648bb07a535631f2bc00fbdb2a2216f135552cb3f534ed136b9c07f"
dependencies = [
"memchr",
"safemem",
]
[[package]]
name = "bumpalo"
version = "3.7.0"
@ -403,18 +372,6 @@ dependencies = [
"winapi 0.3.9",
]
[[package]]
name = "chunked_transfer"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "498d20a7aaf62625b9bf26e637cf7736417cde1d0c99f1d04d1170229a85cf87"
[[package]]
name = "chunked_transfer"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fff857943da45f546682664a79488be82e69e43c1a7a2307679ab9afb3a66d2e"
[[package]]
name = "cloudabi"
version = "0.0.3"
@ -452,7 +409,7 @@ dependencies = [
"cookie",
"failure",
"idna 0.1.5",
"log 0.4.14",
"log",
"publicsuffix",
"serde",
"serde_json",
@ -597,7 +554,7 @@ dependencies = [
"doc-comment",
"hyper-old-types",
"isolang",
"log 0.4.14",
"log",
"reqwest",
"serde",
"serde_derive",
@ -628,7 +585,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d2f06b9cac1506ece98fe3231e3cc9c4410ec3d5b1f24ae1c8946f0742cdefc"
dependencies = [
"backtrace",
"version_check 0.9.3",
"version_check",
]
[[package]]
@ -717,19 +674,6 @@ dependencies = [
"percent-encoding 2.1.0",
]
[[package]]
name = "frankenstein"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92ecce8fc4d7ca2bf5c34f81342d527ee0c47de27c13c8b57df7f5431589e156"
dependencies = [
"mime_guess 2.0.3",
"multipart",
"serde",
"serde_json",
"ureq",
]
[[package]]
name = "fuchsia-cprng"
version = "0.1.1"
@ -939,12 +883,6 @@ dependencies = [
"web-sys",
]
[[package]]
name = "groupable"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32619942b8be646939eaf3db0602b39f5229b74575b67efc897811ded1db4e57"
[[package]]
name = "h2"
version = "0.1.26"
@ -957,7 +895,7 @@ dependencies = [
"futures 0.1.31",
"http 0.1.21",
"indexmap",
"log 0.4.14",
"log",
"slab",
"string",
"tokio-io",
@ -1018,25 +956,6 @@ version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3a87b616e37e93c22fb19bcd386f02f3af5ea98a25670ad0fce773de23c5e68"
[[package]]
name = "hyper"
version = "0.10.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0a0652d9a2609a968c14be1a9ea00bf4b1d64e2e1f53a1b51b6fff3a6e829273"
dependencies = [
"base64 0.9.3",
"httparse",
"language-tags",
"log 0.3.9",
"mime 0.2.6",
"num_cpus",
"time",
"traitobject",
"typeable",
"unicase 1.4.2",
"url 1.7.2",
]
[[package]]
name = "hyper"
version = "0.12.36"
@ -1052,7 +971,7 @@ dependencies = [
"httparse",
"iovec",
"itoa",
"log 0.4.14",
"log",
"net2",
"rustc_version",
"time",
@ -1077,11 +996,11 @@ dependencies = [
"bytes 0.4.12",
"httparse",
"language-tags",
"log 0.4.14",
"mime 0.3.16",
"log",
"mime",
"percent-encoding 1.0.1",
"time",
"unicase 2.6.0",
"unicase",
]
[[package]]
@ -1092,7 +1011,7 @@ checksum = "3a800d6aa50af4b5850b2b0f659625ce9504df908e9733b635720483be26174f"
dependencies = [
"bytes 0.4.12",
"futures 0.1.31",
"hyper 0.12.36",
"hyper",
"native-tls",
"tokio-io",
]
@ -1156,22 +1075,6 @@ dependencies = [
"libc",
]
[[package]]
name = "iron"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c6d308ca2d884650a8bf9ed2ff4cb13fbb2207b71f64cda11dc9b892067295e8"
dependencies = [
"hyper 0.10.16",
"log 0.3.9",
"mime_guess 1.8.8",
"modifier",
"num_cpus",
"plugin",
"typemap",
"url 1.7.2",
]
[[package]]
name = "isolang"
version = "1.0.0"
@ -1196,7 +1099,6 @@ dependencies = [
"async-std",
"chrono",
"elefren",
"frankenstein",
"futures 0.3.15",
"futures-timer",
"rand 0.8.4",
@ -1230,7 +1132,7 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0de8b303297635ad57c9f5059fd9cee7a47f8e8daa09df0fcd07dd39fb22977f"
dependencies = [
"log 0.4.14",
"log",
]
[[package]]
@ -1260,15 +1162,6 @@ dependencies = [
"scopeguard",
]
[[package]]
name = "log"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e19e8d5c34a3e0e2223db8e060f9e8264aeeb5c5fc64a4ee9965c062211c024b"
dependencies = [
"log 0.4.14",
]
[[package]]
name = "log"
version = "0.4.14"
@ -1306,41 +1199,20 @@ dependencies = [
"autocfg 1.0.1",
]
[[package]]
name = "mime"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba626b8a6de5da682e1caa06bdb42a335aee5a84db8e5046a3e8ab17ba0a3ae0"
dependencies = [
"log 0.3.9",
]
[[package]]
name = "mime"
version = "0.3.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d"
[[package]]
name = "mime_guess"
version = "1.8.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "216929a5ee4dd316b1702eedf5e74548c123d370f47841ceaac38ca154690ca3"
dependencies = [
"mime 0.2.6",
"phf",
"phf_codegen",
"unicase 1.4.2",
]
[[package]]
name = "mime_guess"
version = "2.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2684d4c2e97d99848d30b324b00c8fcc7e5c897b7cbb5819b09e7c90e8baf212"
dependencies = [
"mime 0.3.16",
"unicase 2.6.0",
"mime",
"unicase",
]
[[package]]
@ -1365,7 +1237,7 @@ dependencies = [
"iovec",
"kernel32-sys",
"libc",
"log 0.4.14",
"log",
"miow",
"net2",
"slab",
@ -1384,44 +1256,6 @@ dependencies = [
"ws2_32-sys",
]
[[package]]
name = "modifier"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41f5c9112cb662acd3b204077e0de5bc66305fa8df65c8019d5adb10e9ab6e58"
[[package]]
name = "multipart"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00dec633863867f29cb39df64a397cdf4a6354708ddd7759f70c7fb51c5f9182"
dependencies = [
"buf_redux",
"httparse",
"hyper 0.10.16",
"iron",
"log 0.4.14",
"mime 0.3.16",
"mime_guess 2.0.3",
"nickel",
"quick-error",
"rand 0.8.4",
"safemem",
"tempfile",
"tiny_http",
"twoway",
]
[[package]]
name = "mustache"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51956ef1c5d20a1384524d91e616fb44dfc7d8f249bf696d49c97dd3289ecab5"
dependencies = [
"log 0.3.9",
"serde",
]
[[package]]
name = "native-tls"
version = "0.2.7"
@ -1430,7 +1264,7 @@ checksum = "b8d96b2e1c8da3957d58100b09f102c6d9cfdfced01b7ec5a8974044bb09dbd4"
dependencies = [
"lazy_static",
"libc",
"log 0.4.14",
"log",
"openssl",
"openssl-probe",
"openssl-sys",
@ -1451,27 +1285,6 @@ dependencies = [
"winapi 0.3.9",
]
[[package]]
name = "nickel"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5061a832728db2dacb61cefe0ce303b58f85764ec680e71d9138229640a46d9"
dependencies = [
"groupable",
"hyper 0.10.16",
"lazy_static",
"log 0.3.9",
"modifier",
"mustache",
"plugin",
"regex",
"serde",
"serde_json",
"time",
"typemap",
"url 1.7.2",
]
[[package]]
name = "num-integer"
version = "0.1.44"
@ -1644,7 +1457,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "234f71a15de2288bcb7e3b6515828d22af7ec8598ee6d24c3b526fa0a80b67a0"
dependencies = [
"siphasher",
"unicase 1.4.2",
]
[[package]]
@ -1665,15 +1477,6 @@ version = "0.3.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3831453b3449ceb48b6d9c7ad7c96d5ea673e9b470a1dc578c2ce6521230884c"
[[package]]
name = "plugin"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a6a0dc3910bc8db877ffed8e457763b317cf880df4ae19109b9f77d277cf6e0"
dependencies = [
"typemap",
]
[[package]]
name = "polling"
version = "2.1.0"
@ -1682,7 +1485,7 @@ checksum = "92341d779fa34ea8437ef4d82d440d5e1ce3f3ff7f824aa64424cd481f9a1f25"
dependencies = [
"cfg-if 1.0.0",
"libc",
"log 0.4.14",
"log",
"wepoll-ffi",
"winapi 0.3.9",
]
@ -1732,15 +1535,9 @@ checksum = "ffade02495f22453cd593159ea2f59827aae7f53fa8323f756799b670881dcf8"
dependencies = [
"bitflags",
"memchr",
"unicase 2.6.0",
"unicase",
]
[[package]]
name = "quick-error"
version = "1.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
[[package]]
name = "quote"
version = "1.0.9"
@ -1961,23 +1758,6 @@ dependencies = [
"bitflags",
]
[[package]]
name = "regex"
version = "1.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d07a8629359eb56f1e2fb1652bb04212c072a87ba68546a04065d525673ac461"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax",
]
[[package]]
name = "regex-syntax"
version = "0.6.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b"
[[package]]
name = "remove_dir_all"
version = "0.5.3"
@ -2001,11 +1781,11 @@ dependencies = [
"flate2",
"futures 0.1.31",
"http 0.1.21",
"hyper 0.12.36",
"hyper",
"hyper-tls",
"log 0.4.14",
"mime 0.3.16",
"mime_guess 2.0.3",
"log",
"mime",
"mime_guess",
"native-tls",
"serde",
"serde_json",
@ -2021,21 +1801,6 @@ dependencies = [
"winreg",
]
[[package]]
name = "ring"
version = "0.16.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc"
dependencies = [
"cc",
"libc",
"once_cell",
"spin",
"untrusted",
"web-sys",
"winapi 0.3.9",
]
[[package]]
name = "rustc-demangle"
version = "0.1.20"
@ -2051,19 +1816,6 @@ dependencies = [
"semver 0.9.0",
]
[[package]]
name = "rustls"
version = "0.19.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35edb675feee39aec9c99fa5ff985081995a06d594114ae14cbe797ad7b7a6d7"
dependencies = [
"base64 0.13.0",
"log 0.4.14",
"ring",
"sct",
"webpki",
]
[[package]]
name = "ryu"
version = "1.0.5"
@ -2101,16 +1853,6 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
[[package]]
name = "sct"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b362b83898e0e69f38515b82ee15aa80636befe47c3b6d3d89a911e78fc228ce"
dependencies = [
"ring",
"untrusted",
]
[[package]]
name = "security-framework"
version = "2.3.1"
@ -2330,12 +2072,6 @@ dependencies = [
"winapi 0.3.9",
]
[[package]]
name = "spin"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
[[package]]
name = "string"
version = "0.2.1"
@ -2398,19 +2134,6 @@ dependencies = [
"winapi 0.3.9",
]
[[package]]
name = "tiny_http"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e22cb179b63e5fc2d0b5be237dc107da072e2407809ac70a8ce85b93fe8f562"
dependencies = [
"ascii",
"chrono",
"chunked_transfer 0.3.1",
"log 0.4.14",
"url 1.7.2",
]
[[package]]
name = "tinyvec"
version = "1.2.0"
@ -2484,7 +2207,7 @@ checksum = "57fc868aae093479e3131e3d165c93b1c7474109d13c90ec0dda2a1bbfff0674"
dependencies = [
"bytes 0.4.12",
"futures 0.1.31",
"log 0.4.14",
"log",
]
[[package]]
@ -2496,7 +2219,7 @@ dependencies = [
"crossbeam-utils 0.7.2",
"futures 0.1.31",
"lazy_static",
"log 0.4.14",
"log",
"mio",
"num_cpus",
"parking_lot",
@ -2541,7 +2264,7 @@ dependencies = [
"crossbeam-utils 0.7.2",
"futures 0.1.31",
"lazy_static",
"log 0.4.14",
"log",
"num_cpus",
"slab",
"tokio-executor",
@ -2568,12 +2291,6 @@ dependencies = [
"serde",
]
[[package]]
name = "traitobject"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "efd1f82c56340fdf16f2a953d7bda4f8fdffba13d93b00844c25572110b26079"
[[package]]
name = "try-lock"
version = "0.2.3"
@ -2601,7 +2318,7 @@ dependencies = [
"http 0.2.4",
"httparse",
"input_buffer",
"log 0.4.14",
"log",
"native-tls",
"rand 0.7.3",
"sha-1",
@ -2609,30 +2326,6 @@ dependencies = [
"utf-8",
]
[[package]]
name = "twoway"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59b11b2b5241ba34be09c3cc85a36e56e48f9888862e19cedf23336d35316ed1"
dependencies = [
"memchr",
]
[[package]]
name = "typeable"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1410f6f91f21d1612654e7cc69193b0334f909dcf2c790c4826254fbb86f8887"
[[package]]
name = "typemap"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "653be63c80a3296da5551e1bfd2cca35227e13cdd08c6668903ae2f4f77aa1f6"
dependencies = [
"unsafe-any",
]
[[package]]
name = "typenum"
version = "1.13.0"
@ -2645,22 +2338,13 @@ version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56dee185309b50d1f11bfedef0fe6d036842e3fb77413abef29f8f8d1c5d4c1c"
[[package]]
name = "unicase"
version = "1.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f4765f83163b74f957c797ad9253caf97f103fb064d3999aea9568d09fc8a33"
dependencies = [
"version_check 0.1.5",
]
[[package]]
name = "unicase"
version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50f37be617794602aabbeee0be4f259dc1778fabe05e2d67ee8f79326d5cb4f6"
dependencies = [
"version_check 0.9.3",
"version_check",
]
[[package]]
@ -2687,37 +2371,6 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3"
[[package]]
name = "unsafe-any"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f30360d7979f5e9c6e6cea48af192ea8fab4afb3cf72597154b8f08935bc9c7f"
dependencies = [
"traitobject",
]
[[package]]
name = "untrusted"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
[[package]]
name = "ureq"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2475a6781e9bc546e7b64f4013d2f4032c8c6a40fcffd7c6f4ee734a890972ab"
dependencies = [
"base64 0.13.0",
"chunked_transfer 1.4.0",
"log 0.4.14",
"once_cell",
"rustls",
"url 2.2.2",
"webpki",
"webpki-roots",
]
[[package]]
name = "url"
version = "1.7.2"
@ -2763,7 +2416,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd320e1520f94261153e96f7534476ad869c14022aee1e59af7c778075d840ae"
dependencies = [
"ctor",
"version_check 0.9.3",
"version_check",
]
[[package]]
@ -2772,12 +2425,6 @@ version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
[[package]]
name = "version_check"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "914b1a6776c4c929a602fafd8bc742e06365d4bcbe48c30f9cca5824f70dc9dd"
[[package]]
name = "version_check"
version = "0.9.3"
@ -2808,7 +2455,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6395efa4784b027708f7451087e647ec73cc74f5d9bc2e418404248d679a230"
dependencies = [
"futures 0.1.31",
"log 0.4.14",
"log",
"try-lock",
]
@ -2842,7 +2489,7 @@ checksum = "3b33f6a0694ccfea53d94db8b2ed1c3a8a4c86dd936b13b9f0a15ec4a451b900"
dependencies = [
"bumpalo",
"lazy_static",
"log 0.4.14",
"log",
"proc-macro2",
"quote",
"syn",
@ -2900,25 +2547,6 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "webpki"
version = "0.21.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8e38c0608262c46d4a56202ebabdeb094cef7e560ca7a226c6bf055188aa4ea"
dependencies = [
"ring",
"untrusted",
]
[[package]]
name = "webpki-roots"
version = "0.21.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aabe153544e473b775453675851ecc86863d2a81d786d741f6b76778f2a48940"
dependencies = [
"webpki",
]
[[package]]
name = "wepoll-ffi"
version = "0.1.2"

View File

@ -10,7 +10,6 @@ edition = "2018"
async-std = { version = "1.9.0", features = ["unstable"] }
chrono = "0.4.19"
elefren = { version = "0.22.0", features = ["toml"] }
frankenstein = "0.4.0"
futures = "0.3.15"
futures-timer = "3.0.2"
rand = "0.8.4"

View File

@ -3,7 +3,11 @@ use std::{error::Error, process, time::Duration};
use chrono::Local;
use rand::Rng;
use crate::{bot::IzzilisBot, publish::FediversePublisher, selection::ConsoleSelector};
use crate::{
bot::IzzilisBot,
publish::FediversePublisher,
selection::{ConsoleSelector, SelectorExt},
};
use futures::StreamExt;
use futures_timer::Delay;
@ -34,23 +38,25 @@ fn main() -> Result<(), Box<dyn Error>> {
}
};
let mut gpt_model = model::GPTSampleModel::new(
cfg.python_path(),
cfg.gpt_code_path(),
vec![
"generate_unconditional_samples.py".to_string(),
"--model_name".to_string(),
cfg.model_name(),
"--temperature".to_string(),
cfg.temperature(),
"--top_k".to_string(),
cfg.top_k(),
"--nsamples".to_string(),
"1".to_string(),
],
)
.into_stream()
.take(60);
let mut gpt_model = ConsoleSelector.filter(
model::GPTSampleModel::new(
cfg.python_path(),
cfg.gpt_code_path(),
vec![
"generate_unconditional_samples.py".to_string(),
"--model_name".to_string(),
cfg.model_name(),
"--temperature".to_string(),
cfg.temperature(),
"--top_k".to_string(),
cfg.top_k(),
"--nsamples".to_string(),
"1".to_string(),
],
)
.into_stream()
.take(10),
);
while let Some(Ok(sample)) = gpt_model.next().await {
println!("{}", sample);

View File

@ -1,97 +0,0 @@
use frankenstein::{
Api, GetUpdatesParams, KeyboardButton, ReplyKeyboardMarkup, ReplyMarkup, TelegramApi,
};
use futures::Future;
use std::{error::Error, io, thread::JoinHandle};
pub trait Selector {
fn send_for_review(&mut self, message: String) -> Result<(), Box<dyn Error>>;
fn collect_selected_samples(&mut self) -> Vec<String>;
}
// pub trait Selector {
// type Error;
// type Response: Future<Output = Result<bool, Self::Error>>;
// fn review(self, data: String) -> Self::Response;
// }
pub struct TelegramSelector {
client: frankenstein::Api,
dest_chat_id: String,
listener_handle: Option<JoinHandle<()>>,
}
pub struct ConsoleSelector {
selected_samples: Vec<String>,
}
impl Selector for ConsoleSelector {
fn send_for_review(&mut self, message: String) -> Result<(), Box<dyn Error>> {
println!("generated sample [y+enter to accept]: {}", &message);
let mut choice = String::new();
io::stdin().read_line(&mut choice).expect("cum");
if choice.to_lowercase().contains("y") {
println!("accepted");
self.selected_samples.push(message);
}
Ok(())
}
fn collect_selected_samples(&mut self) -> Vec<String> {
let cloned_samples = self.selected_samples.to_owned();
self.selected_samples = Vec::new();
cloned_samples
}
}
impl ConsoleSelector {
pub fn new() -> ConsoleSelector {
Self {
selected_samples: Vec::new(),
}
}
}
const KEEP_BUTTON: &str = "Keep";
const TOSS_BUTTON: &str = "Toss";
impl Selector for TelegramSelector {
fn send_for_review(&mut self, message: String) -> Result<(), Box<dyn Error>> {
todo!();
if !self.listener_handle.is_none() {
todo!();
}
let mut message_def = frankenstein::SendMessageParams::new(
frankenstein::ChatId::String(self.dest_chat_id.clone()),
message,
);
message_def.reply_markup = Some(ReplyMarkup::ReplyKeyboardMarkup(
ReplyKeyboardMarkup::new(vec![
KeyboardButton::new(KEEP_BUTTON.to_string()),
KeyboardButton::new(TOSS_BUTTON.to_string()),
]),
));
self.client
.send_message(&message_def)
.expect("TODO handle this properly (doesn't implement std error for some reason)");
Ok(())
}
fn collect_selected_samples(&mut self) -> Vec<String> {
todo!()
}
}
impl TelegramSelector {
pub fn new(token: String, dest_chat_id: String) -> TelegramSelector {
let api = Api::new(&token);
Self {
client: api,
dest_chat_id: dest_chat_id,
listener_handle: None,
}
}
}

29
src/selection/console.rs Normal file
View File

@ -0,0 +1,29 @@
use std::error::Error;
use async_std::io::stdin;
use futures::future::BoxFuture;
use super::Selector;
#[derive(Debug, Copy, Clone)]
pub struct ConsoleSelector;
impl Selector for ConsoleSelector {
type Error = Box<dyn Error>;
type Response = BoxFuture<'static, Result<bool, Self::Error>>;
fn review(&self, message: String) -> Self::Response {
println!("{} (y/N) ", message);
let stdin = stdin();
Box::pin(async move {
let mut buffer = String::new();
stdin.read_line(&mut buffer).await?;
Ok(
match buffer.chars().next().unwrap_or('n').to_ascii_lowercase() {
'y' => true,
_ => false,
},
)
})
}
}

49
src/selection/mod.rs Normal file
View File

@ -0,0 +1,49 @@
use futures::{stream::BoxStream, Future, Stream, TryStreamExt};
mod console;
pub use console::ConsoleSelector;
pub trait Selector {
type Error;
type Response: Future<Output = Result<bool, Self::Error>>;
fn review(&self, data: String) -> Self::Response;
}
pub trait SelectorExt<E, S: Stream<Item = Result<String, E>>>: Selector {
type Stream: Stream<Item = Result<String, FilterError<E, Self::Error>>>;
fn filter(self, stream: S) -> Self::Stream;
}
pub enum FilterError<T, U> {
Model(T),
Filter(U),
}
impl<
E: 'static,
T: Selector + Sync + Send + Clone + 'static,
S: Send + Stream<Item = Result<String, E>> + 'static,
> SelectorExt<E, S> for T
where
T::Response: Send,
{
type Stream = BoxStream<'static, Result<String, FilterError<E, Self::Error>>>;
fn filter(self, stream: S) -> Self::Stream {
Box::pin(
stream
.map_err(FilterError::Model)
.try_filter_map(move |item| {
let this = self.clone();
async move {
this.review(item.clone())
.await
.map_err(FilterError::Filter)
.map(move |keep| if keep { Some(item) } else { None })
}
}),
)
}
}