//! Asynchronous text-sample models and an adapter that turns a model into a
//! stream of individual samples.
use std::convert::Infallible;
|
|
|
|
use async_std::{io, process::Command};
|
|
use futures::{
|
|
future::{self, BoxFuture},
|
|
stream::BoxStream,
|
|
Future, Stream, TryStreamExt,
|
|
};
|
|
|
|
/// A source of text samples that are produced asynchronously.
///
/// Every call to [`SampleModel::get_sample`] returns a future that resolves to
/// either one `String` or an error of the implementor's [`SampleModel::Error`]
/// type.
pub trait SampleModel {
    /// Error reported when a sample cannot be produced.
    type Error;

    /// Future returned by [`SampleModel::get_sample`].
    type Sample: Future<Output = Result<String, Self::Error>>;

    /// Returns a future that resolves to one sample.
    ///
    /// # Errors
    ///
    /// The future resolves to `Err(Self::Error)` when the implementor fails to
    /// produce a sample.
    fn get_sample(&self) -> Self::Sample;
}
|
|
|
|
/// Sample model without any state; its [`SampleModel`] implementation always
/// yields the same fixed sentence.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct EmptyModel;
|
|
|
|
impl SampleModel for EmptyModel {
|
|
type Error = Infallible;
|
|
type Sample = future::Ready<Result<String, Self::Error>>;
|
|
|
|
fn get_sample(&self) -> Self::Sample {
|
|
future::ready(Ok(String::from(
|
|
"The quick brown fox jumps over the lazy dog.",
|
|
)))
|
|
}
|
|
}
|
|
|
|
/// Sample model that obtains its text from the standard output of an external
/// command.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GPTSampleModel {
    /// Program that is executed to produce samples.
    python_command: String,
    /// Working directory the program is started in.
    command_working_path: String,
    /// Arguments passed to the program.
    command_args: Vec<String>,
}
|
|
|
|
impl SampleModel for GPTSampleModel {
|
|
type Error = io::Error;
|
|
type Sample = BoxFuture<'static, Result<String, Self::Error>>;
|
|
|
|
fn get_sample(&self) -> Self::Sample {
|
|
let cmd = Command::new(&self.python_command)
|
|
.current_dir(&self.command_working_path)
|
|
.args(&self.command_args)
|
|
.output();
|
|
Box::pin(async { Ok(String::from_utf8_lossy(&cmd.await?.stdout).to_string()) })
|
|
}
|
|
}
|
|
|
|
impl GPTSampleModel {
|
|
pub fn new(
|
|
python_command: String,
|
|
command_working_path: String,
|
|
command_args: Vec<String>,
|
|
) -> GPTSampleModel {
|
|
Self {
|
|
python_command: python_command,
|
|
command_working_path: command_working_path,
|
|
command_args: command_args,
|
|
}
|
|
}
|
|
}
|
|
|
|
pub trait SampleModelExt: SampleModel {
|
|
type Stream: Stream<Item = Result<String, Self::Error>>;
|
|
|
|
fn into_stream(self) -> Self::Stream;
|
|
}
|
|
|
|
/// Delimiter on which raw model output is split into individual samples.
const SAMPLE_SPLIT_WORD: &str = "<|endoftext|>";

/// Banner line that is removed from raw model output before it is split.
const SAMPLE_SAMPLE_LINE: &str =
    "======================================== SAMPLE 1 ========================================";
|
|
|
|
impl<T: SampleModel + Send + Sync + 'static> SampleModelExt for T
|
|
where
|
|
Self::Sample: Send,
|
|
{
|
|
type Stream = BoxStream<'static, Result<String, Self::Error>>;
|
|
|
|
fn into_stream(self) -> Self::Stream {
|
|
Box::pin(
|
|
futures::stream::try_unfold(self, |this| async {
|
|
Ok(Some((this.get_sample().await?, this)))
|
|
})
|
|
.map_ok(|samples| {
|
|
futures::stream::iter(
|
|
samples
|
|
.replace(SAMPLE_SAMPLE_LINE, "")
|
|
.split(SAMPLE_SPLIT_WORD)
|
|
.map(|elem| elem.to_owned())
|
|
.collect::<Vec<String>>()
|
|
.into_iter()
|
|
.map(|elem| Ok(elem.trim().to_owned())),
|
|
)
|
|
})
|
|
.try_flatten(),
|
|
)
|
|
}
|
|
}
|