izzilis/src/model.rs

99 lines
2.7 KiB
Rust

use std::convert::Infallible;
use async_std::{io, process::Command};
use futures::{
future::{self, BoxFuture},
stream::BoxStream,
Future, Stream, TryStreamExt,
};
/// A source of raw text samples.
///
/// Every call to [`SampleModel::get_sample`] hands back a future; awaiting it
/// yields one raw sample string, or the model's [`SampleModel::Error`] when the
/// sample could not be produced.
pub trait SampleModel {
    /// Failure type reported by the futures returned from `get_sample`.
    type Error;

    /// Future type returned by [`SampleModel::get_sample`].
    type Sample: Future<Output = Result<String, <Self as SampleModel>::Error>>;

    /// Returns a future that resolves to a single raw sample.
    fn get_sample(&self) -> Self::Sample;
}
/// A trivial model that always produces the same fixed sentence.
///
/// Useful as a stand-in when no real text generator is available.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct EmptyModel;
// Generation cannot fail, so the error type is uninhabited and the future is
// already complete when it is returned.
impl SampleModel for EmptyModel {
    type Error = Infallible;
    type Sample = future::Ready<Result<String, Infallible>>;

    /// Immediately resolves to the same pangram on every call.
    fn get_sample(&self) -> Self::Sample {
        let pangram = "The quick brown fox jumps over the lazy dog.";
        future::ok(pangram.to_owned())
    }
}
/// A model that obtains samples by running an external Python command.
///
/// Each requested sample runs `python_command` with `command_args` in the
/// directory `command_working_path` and uses the command's standard output as
/// the raw sample text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GPTSampleModel {
    /// Program to execute (e.g. the Python interpreter).
    python_command: String,
    /// Working directory the program is started in.
    command_working_path: String,
    /// Arguments passed to the program, in order.
    command_args: Vec<String>,
}
impl SampleModel for GPTSampleModel {
type Error = io::Error;
type Sample = BoxFuture<'static, Result<String, Self::Error>>;
fn get_sample(&self) -> Self::Sample {
let cmd = Command::new(&self.python_command)
.current_dir(&self.command_working_path)
.args(&self.command_args)
.output();
Box::pin(async { Ok(String::from_utf8_lossy(&cmd.await?.stdout).to_string()) })
}
}
impl GPTSampleModel {
pub fn new(
python_command: String,
command_working_path: String,
command_args: Vec<String>,
) -> GPTSampleModel {
Self {
python_command: python_command,
command_working_path: command_working_path,
command_args: command_args,
}
}
}
pub trait SampleModelExt: SampleModel {
type Stream: Stream<Item = Result<String, Self::Error>>;
fn into_stream(self) -> Self::Stream;
}
/// Banner line stripped from raw model output before it is split into samples.
const SAMPLE_SAMPLE_LINE: &str =
    "======================================== SAMPLE 1 ========================================";

/// Token separating individual texts within one raw model output.
const SAMPLE_SPLIT_WORD: &str = "<|endoftext|>";
impl<T: SampleModel + Send + Sync + 'static> SampleModelExt for T
where
    Self::Sample: Send,
{
    type Stream = BoxStream<'static, Result<String, Self::Error>>;

    /// Turns the model into a stream of individual, trimmed samples.
    ///
    /// `get_sample` is awaited repeatedly, one call at a time. Each raw output
    /// is processed in four steps:
    /// - the `SAMPLE_SAMPLE_LINE` banner is removed;
    /// - the rest is split on `SAMPLE_SPLIT_WORD`;
    /// - each piece is trimmed;
    /// - pieces that are empty after trimming (e.g. the text after a trailing
    ///   separator) are dropped.
    ///
    /// The stream never finishes on its own. If `get_sample` fails, that error
    /// is yielded once and the stream then ends; the model is dropped with the
    /// failed future.
    ///
    /// NOTE(review): only the literal `SAMPLE 1` banner is stripped. If the
    /// command prints banners for further samples per run, those would remain
    /// in the text — TODO confirm against the generator's arguments.
    fn into_stream(self) -> Self::Stream {
        Box::pin(
            futures::stream::try_unfold(self, |this| async move {
                Ok(Some((this.get_sample().await?, this)))
            })
            .map_ok(|raw| {
                // The pieces borrow from the temporary produced by `replace`,
                // so they are collected into owned strings before it is dropped.
                let samples: Vec<String> = raw
                    .replace(SAMPLE_SAMPLE_LINE, "")
                    .split(SAMPLE_SPLIT_WORD)
                    .map(str::trim)
                    .filter(|piece| !piece.is_empty())
                    .map(String::from)
                    .collect();
                futures::stream::iter(samples.into_iter().map(Ok))
            })
            .try_flatten(),
        )
    }
}