From 160993be9bda052e48d12ddf32df1829402f4acd Mon Sep 17 00:00:00 2001 From: Philpax Date: Fri, 7 Apr 2023 19:01:24 +0200 Subject: [PATCH 1/2] fix(llama): buffer tokens until valid UTF-8 --- llama-cli/src/cli_args.rs | 3 - llama-rs/src/convert.rs | 26 +++++--- llama-rs/src/lib.rs | 123 ++++++++++++++++++++++++++------------ 3 files changed, 101 insertions(+), 51 deletions(-) diff --git a/llama-cli/src/cli_args.rs b/llama-cli/src/cli_args.rs index e4b54d35..5dd663db 100644 --- a/llama-cli/src/cli_args.rs +++ b/llama-cli/src/cli_args.rs @@ -268,9 +268,6 @@ impl ModelLoad { LoadProgress::HyperparametersLoaded(hparams) => { log::debug!("Loaded hyperparameters {hparams:#?}") } - LoadProgress::BadToken { index } => { - log::info!("Warning: Bad token in vocab at index {index}") - } LoadProgress::ContextSize { bytes } => log::info!( "ggml ctx size = {:.2} MB\n", bytes as f64 / (1024.0 * 1024.0) diff --git a/llama-rs/src/convert.rs b/llama-rs/src/convert.rs index 285cf3c0..f83346c0 100644 --- a/llama-rs/src/convert.rs +++ b/llama-rs/src/convert.rs @@ -49,11 +49,12 @@ fn load_vocabulary(path: &Path) -> Vocabulary { let mut token_to_id = HashMap::new(); let mut max_token_length = 0; + // TODO: Does the original model use valid UTF-8 for its tokens? This seems a little suspect to me. 
for (idx, piece) in proto.get_pieces().iter().enumerate() { - let word = piece.get_piece().to_string(); + let word = piece.get_piece().as_bytes(); max_token_length = max_token_length.max(word.len()); - id_to_token.push(word.clone()); - token_to_id.insert(word, idx as i32); + id_to_token.push(word.to_owned()); + token_to_id.insert(word.to_owned(), idx as i32); id_to_token_score.push(piece.get_score()); } Vocabulary { @@ -128,13 +129,20 @@ fn write_header(fout: &mut File, hparams: &Hyperparameters) -> Result<(), String fn write_tokens(file: &mut File, vocab: &Vocabulary) -> Result<(), String> { let mut values: Vec<u8> = vec![]; for (i, token) in vocab.id_to_token.iter().enumerate() { - let text = match token { - _ if token.contains("<unk>") => " \u{2047} ".as_bytes().to_vec(), - _ if token.contains("s>") => vec![], - _ if token.len() == 6 && token.contains("<0x") => { - vec![u8::from_str_radix(&token[3..5], 16).unwrap()] + // TODO: Not sure what the behaviour should be if the token is not valid UTF-8. + // + // Switching to the HF tokenizer should fix this. + let text = if let Ok(token) = std::str::from_utf8(token) { + match token { + _ if token.contains("<unk>") => " \u{2047} ".as_bytes().to_vec(), + _ if token.contains("s>") => vec![], + _ if token.len() == 6 && token.contains("<0x") => { + vec![u8::from_str_radix(&token[3..5], 16).unwrap()] + } + _ => token.replace('\u{2581}', " ").as_bytes().to_vec(), } - _ => token.replace('\u{2581}', " ").as_bytes().to_vec(), + } else { + token.clone() }; values.extend((text.len() as i32).to_le_bytes()); values.extend(&text); diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index 14553379..cfddad6d 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -289,7 +289,7 @@ impl Display for InferenceStats { } type TokenId = i32; -type Token = String; +type Token = Vec<u8>; type TokenScore = f32; /// The vocabulary used by a model.
@@ -309,7 +309,7 @@ pub struct Vocabulary { max_token_length: usize, } impl Vocabulary { - fn token(&self, idx: usize) -> &str { + fn token(&self, idx: usize) -> &[u8] { &self.id_to_token[idx] } } @@ -386,14 +386,6 @@ impl std::fmt::Display for TokenBias { pub enum LoadProgress<'a> { /// The hyperparameters have been loaded from the model. HyperparametersLoaded(&'a Hyperparameters), - /// A bad token was encountered during the loading process. - /// - /// This can be ignored, but invalid tokens will be replaced with - /// the `�` character. - BadToken { - /// The index within the vocabulary. - index: usize, - }, /// The context has been created. ContextSize { /// The size of the context. @@ -604,6 +596,20 @@ impl Model { Ok(bytes) } + fn read_bytes_with_len( + reader: &mut impl BufRead, + len: usize, + ) -> Result<Vec<u8>, LoadError> { + let mut bytes = vec![0u8; len]; + reader + .read_exact(&mut bytes) + .map_err(|e| LoadError::ReadExactFailed { + source: e, + bytes: len, + })?; + Ok(bytes) + } + fn read_i32(reader: &mut impl BufRead) -> Result<i32, LoadError> { Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) } @@ -618,15 +624,7 @@ impl Model { /// Helper function. Reads a string from the buffer and returns it. fn read_string(reader: &mut BufReader<File>, len: usize) -> Result<String, LoadError> { - let mut buf = vec![0; len]; - reader - .read_exact(&mut buf) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: buf.len(), - })?; - let s = String::from_utf8(buf)?; - Ok(s) + Ok(String::from_utf8(read_bytes_with_len(reader, len)?)?)
} // Verify magic @@ -682,14 +680,10 @@ impl Model { for i in 0..hparams.n_vocab { let len = read_i32(&mut reader)?; - if let Ok(word) = read_string(&mut reader, len as usize) { - max_token_length = max_token_length.max(word.len()); - id_to_token.push(word.clone()); - token_to_id.insert(word, TokenId::try_from(i)?); - } else { - load_progress_callback(LoadProgress::BadToken { index: i }); - id_to_token.push("�".to_string()); - } + let token = read_bytes_with_len(&mut reader, len as usize)?; + max_token_length = max_token_length.max(token.len()); + id_to_token.push(token.clone()); + token_to_id.insert(token, TokenId::try_from(i)?); // Token score, currently unused if !is_legacy_model { @@ -1427,7 +1421,7 @@ impl InferenceSession { vocab: &Vocabulary, params: &InferenceParameters, prompt: &str, - callback: impl Fn(&str) -> Result<(), E>, + mut callback: impl FnMut(&[u8]) -> Result<(), E>, ) -> Result<(), InferenceError> { let beginning_of_sentence = self.n_past == 0; let prompt_tokens: Vec<TokenId> = vocab @@ -1464,7 +1458,7 @@ impl InferenceSession { vocab: &'v Vocabulary, params: &InferenceParameters, rng: &mut impl rand::Rng, - ) -> Result<&'v str, InferenceError> { + ) -> Result<&'v [u8], InferenceError> { if self.n_past + 1 >= model.hparams.n_ctx { return Err(InferenceError::ContextFull); } @@ -1505,15 +1499,19 @@ impl InferenceSession { prompt: &str, maximum_token_count: Option<usize>, rng: &mut impl rand::Rng, - callback: impl Fn(&str) -> Result<(), E>, + mut callback: impl FnMut(&str) -> Result<(), E>, ) -> Result<InferenceStats, InferenceError> { let maximum_token_count = maximum_token_count.unwrap_or(usize::MAX); if params.play_back_previous_tokens { // "Play back" the existing tokens, so that loading from an inference snapshot works // as expected.
+ let mut token_utf8_buf = TokenUtf8Buffer::new(); for token_id in &self.tokens { - if let Err(e) = callback(vocab.token(*token_id as usize)) { - return Err(InferenceError::UserCallback(Box::new(e))); + // Buffer the token until it's valid UTF-8, then call the callback. + if let Some(tokens) = token_utf8_buf.push(vocab.token(*token_id as usize)) { + if let Err(e) = callback(&tokens) { + return Err(InferenceError::UserCallback(Box::new(e))); + } } } } @@ -1524,7 +1522,13 @@ impl InferenceSession { // Feed the initial prompt through the transformer, to update its // context window with new data. - self.feed_prompt(model, vocab, params, prompt, |tk| callback(tk))?; + self.feed_prompt( + model, + vocab, + params, + prompt, + TokenUtf8Buffer::adapt_callback(&mut callback), + )?; stats.feed_prompt_duration = start_at.elapsed().unwrap(); stats.prompt_tokens = self.n_past; @@ -1533,6 +1537,7 @@ impl InferenceSession { // EndOfText token, or we run out of space in the context window, // or we reach the specified limit. let mut tokens_processed = 0; + let mut token_utf8_buf = TokenUtf8Buffer::new(); while tokens_processed < maximum_token_count { let token = match self.infer_next_token(model, vocab, params, rng) { Ok(token) => token, @@ -1540,8 +1545,11 @@ impl InferenceSession { Err(e) => return Err(e), }; - if let Err(e) = callback(token) { - return Err(InferenceError::UserCallback(Box::new(e))); + // Buffer the token until it's valid UTF-8, then call the callback. 
+ if let Some(tokens) = token_utf8_buf.push(token) { + if let Err(e) = callback(&tokens) { + return Err(InferenceError::UserCallback(Box::new(e))); + } } tokens_processed += 1; @@ -1688,7 +1696,7 @@ impl Vocabulary { &'a self, text: &str, bos: bool, - ) -> Result<Vec<(&'a str, TokenId)>, InferenceError> { + ) -> Result<Vec<(&'a [u8], TokenId)>, InferenceError> { let len = text.len(); let mut score = vec![0usize; len + 1]; @@ -1698,7 +1706,6 @@ impl Vocabulary { let max_len = (len - i).min(self.max_token_length); for sub_len in 1..=max_len { let sub = &text.as_bytes()[i..i + sub_len]; - let Ok(sub) = std::str::from_utf8(sub) else { continue; }; let token = self.token_to_id.get(sub); if let Some(token) = token { @@ -1722,14 +1729,14 @@ impl Vocabulary { if token_id == 0 { return Err(InferenceError::TokenizationFailed); } - let token = self.id_to_token[token_id as usize].as_str(); + let token = self.id_to_token[token_id as usize].as_slice(); res.push((token, token_id)); i -= token.len(); } if bos { // TODO: replace with vocab.bos - res.push(("", 1)); + res.push((&[], 1)); } // Pieces are in reverse order so correct that @@ -1738,3 +1745,41 @@ impl Vocabulary { Ok(res) } } + +/// Used to buffer incoming tokens until they produce a valid string of UTF-8 text. +/// +/// Tokens are *not* valid UTF-8 by themselves. However, the LLM will produce valid UTF-8 +/// from multiple tokens. This helps alleviate that issue. +#[derive(Clone, PartialEq, Default)] +pub struct TokenUtf8Buffer(Vec<u8>); +impl TokenUtf8Buffer { + /// Create a new buffer. + pub const fn new() -> Self { + Self(vec![]) + } + + /// Add a token to the buffer. If the buffer contains a valid string of UTF-8 text, + /// it is returned and the buffer is cleared for next use. + pub fn push(&mut self, token: &[u8]) -> Option<String> { + self.0.extend_from_slice(token); + match std::str::from_utf8(&self.0) { + Ok(s) => { + let out = s.to_owned(); + self.0 = vec![]; + Some(out) + } + Err(..)
=> None, + } + } + + /// Adapt a `&str` callback so that it can be used in a `&[u8]` context. + fn adapt_callback<'a, E: std::error::Error + 'static>( + mut callback: impl FnMut(&str) -> Result<(), E> + 'a, + ) -> impl FnMut(&[u8]) -> Result<(), E> + 'a { + let mut buffer = Self::new(); + move |token| match buffer.push(token) { + Some(tokens) => callback(&tokens), + None => Ok(()), + } + } +} From 6b1488f655a65c7d5be244efc456ba011e017072 Mon Sep 17 00:00:00 2001 From: Philpax Date: Thu, 13 Apr 2023 02:03:42 +0200 Subject: [PATCH 2/2] Address review feedback --- ggml/src/lib.rs | 2 +- llama-rs/src/convert.rs | 3 --- llama-rs/src/lib.rs | 45 +++++++++++++++++++++++++++++++++++++++-- 3 files changed, 44 insertions(+), 6 deletions(-) diff --git a/ggml/src/lib.rs b/ggml/src/lib.rs index 76a7e4ab..142dfc79 100644 --- a/ggml/src/lib.rs +++ b/ggml/src/lib.rs @@ -406,7 +406,7 @@ impl Tensor { } } - fn with_alive_ctx(&self, f: impl Fn() -> U) -> U { + fn with_alive_ctx(&self, mut f: impl FnMut() -> U) -> U { if let Some(_ctx) = self.ctx.upgrade() { f() } else { diff --git a/llama-rs/src/convert.rs b/llama-rs/src/convert.rs index f83346c0..fc562d48 100644 --- a/llama-rs/src/convert.rs +++ b/llama-rs/src/convert.rs @@ -129,9 +129,6 @@ fn write_header(fout: &mut File, hparams: &Hyperparameters) -> Result<(), String fn write_tokens(file: &mut File, vocab: &Vocabulary) -> Result<(), String> { let mut values: Vec = vec![]; for (i, token) in vocab.id_to_token.iter().enumerate() { - // TODO: Not sure what the behaviour should be if the token is not valid UTF-8. - // - // Switching to the HF tokenizer should fix this. 
let text = if let Ok(token) = std::str::from_utf8(token) { match token { _ if token.contains("<unk>") => " \u{2047} ".as_bytes().to_vec(), diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index cfddad6d..4a4c07e8 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -570,7 +570,7 @@ impl Model { pub fn load( path: impl AsRef<Path>, n_context_tokens: usize, - load_progress_callback: impl Fn(LoadProgress), + mut load_progress_callback: impl FnMut(LoadProgress), ) -> Result<(Model, Vocabulary), LoadError> { use std::fs::File; use std::io::BufReader; @@ -1768,7 +1768,21 @@ impl TokenUtf8Buffer { self.0 = vec![]; Some(out) } - Err(..) => None, + Err(..) => { + for i in 1..self.0.len() { + let slice = &self.0[i..]; + if slice.is_empty() { + break; + } + + if let Ok(s) = std::str::from_utf8(slice) { + let out = s.to_owned(); + self.0 = vec![]; + return Some(out); + } + } + None + } } } @@ -1783,3 +1797,30 @@ impl TokenUtf8Buffer { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_valid_utf8() { + let mut buffer = TokenUtf8Buffer::new(); + assert_eq!(buffer.push(b"hello").as_deref(), Some("hello")); + assert_eq!(buffer.push(&[0xE2, 0x82, 0xAC]).as_deref(), Some("€")); + } + + #[test] + fn test_partial_utf8() { + let mut buffer = TokenUtf8Buffer::new(); + assert_eq!(buffer.push(&[0xE2, 0x82]).as_deref(), None); + assert_eq!(buffer.push(&[0xAC]).as_deref(), Some("€")); + } + + #[test] + fn test_invalid_prelude_for_valid_utf8() { + let mut buffer = TokenUtf8Buffer::new(); + assert_eq!(buffer.push(&[0xD8]).as_deref(), None); + assert_eq!(buffer.push(&[0xE2, 0x82]).as_deref(), None); + assert_eq!(buffer.push(&[0xAC]).as_deref(), Some("€")); + }