From 2b131ad1d23bca3629dacca3cc1412982c529f50 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Tue, 28 Nov 2023 16:57:16 +0800 Subject: [PATCH] refactor: handle max output length in StopCondition (#910) * refactor: handle max output length in StopCondition * trim stop words * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- crates/llama-cpp-bindings/src/lib.rs | 24 ++++++++++++----- crates/llama-cpp-bindings/src/llama.rs | 7 +++-- crates/tabby-inference/src/decoding.rs | 36 ++++++++++++++++++++++---- 3 files changed, 51 insertions(+), 16 deletions(-) diff --git a/crates/llama-cpp-bindings/src/lib.rs b/crates/llama-cpp-bindings/src/lib.rs index cd2938cf58c9..0dbdaf9cbea2 100644 --- a/crates/llama-cpp-bindings/src/lib.rs +++ b/crates/llama-cpp-bindings/src/lib.rs @@ -72,8 +72,19 @@ impl LlamaTextGeneration { #[async_trait] impl TextGeneration for LlamaTextGeneration { async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String { + let language = options.language; let s = self.generate_stream(prompt, options).await; - helpers::stream_to_string(s).await + let text = helpers::stream_to_string(s).await; + + let Some(language) = language else { + return text; + }; + + let Some(trimmed) = self.stop_condition_factory.trim_stop_words(language, &text) else { + return text; + }; + + trimmed } async fn generate_stream( @@ -81,7 +92,11 @@ impl TextGeneration for LlamaTextGeneration { prompt: &str, options: TextGenerationOptions, ) -> BoxStream<String> { - let stop_condition = self.stop_condition_factory.create(prompt, options.language); + let stop_condition = self.stop_condition_factory.create( + prompt, + options.max_decoding_length, + options.language, + ); let mut rx = self .service @@ -89,13 +104,8 @@ impl TextGeneration for LlamaTextGeneration { .await; let s = stream! 
{ - let mut length = 0; while let Some(new_text) = rx.recv().await { yield new_text; - length += 1; - if length >= options.max_decoding_length { - break; - } } rx.close(); diff --git a/crates/llama-cpp-bindings/src/llama.rs b/crates/llama-cpp-bindings/src/llama.rs index 67d9a51f938e..ae2e0357197f 100644 --- a/crates/llama-cpp-bindings/src/llama.rs +++ b/crates/llama-cpp-bindings/src/llama.rs @@ -87,14 +87,13 @@ impl LlamaServiceImpl { if tx.is_closed() || text.is_empty() { // Cancelled by client side or hit eos. stopped = true; - } else if !stop_condition.should_stop(&text) { + } else { + stopped = stop_condition.should_stop(&text); + match tx.send(text).await { Ok(_) => (), Err(_) => stopped = true, } - } else { - // Stoop words stopped - stopped = true; } if stopped { diff --git a/crates/tabby-inference/src/decoding.rs b/crates/tabby-inference/src/decoding.rs index cbbac9532a89..d40c571929ac 100644 --- a/crates/tabby-inference/src/decoding.rs +++ b/crates/tabby-inference/src/decoding.rs @@ -22,11 +22,16 @@ impl Default for StopConditionFactory { } impl StopConditionFactory { - pub fn create(&self, text: &str, language: Option<&'static Language>) -> StopCondition { + pub fn create( + &self, + text: &str, + max_decoding_length: usize, + language: Option<&'static Language>, + ) -> StopCondition { if let Some(language) = language { - StopCondition::new(self.get_re(language), text) + StopCondition::new(self.get_re(language), max_decoding_length, text) } else { - StopCondition::new(None, text) + StopCondition::new(None, max_decoding_length, text) } } @@ -45,6 +50,22 @@ impl StopConditionFactory { re.map(|x| x.value().clone()) } } + + pub fn trim_stop_words(&self, language: &'static Language, text: &str) -> Option<String> { + let Some(re) = self.get_re(language) else { + return None; + }; + + let text = reverse(text); + + let text = if let Some(m) = re.find_at(&text, 0) { + &text[m.end()..] 
+ } else { + &text + }; + + Some(reverse(text)) + } } fn create_stop_regex(stop_words: Vec<String>) -> Regex { @@ -60,14 +81,18 @@ fn create_stop_regex(stop_words: Vec<String>) -> Regex { pub struct StopCondition { stop_re: Option<Regex>, + max_decoding_length: usize, reversed_text: String, + num_decoded: usize, } impl StopCondition { - pub fn new(stop_re: Option<Regex>, text: &str) -> Self { + pub fn new(stop_re: Option<Regex>, max_decoding_length: usize, text: &str) -> Self { Self { stop_re, + max_decoding_length, reversed_text: reverse(text), + num_decoded: 0, } } @@ -82,7 +107,8 @@ impl StopCondition { } } - false + self.num_decoded += 1; + self.num_decoded >= self.max_decoding_length } }