This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

fix(llama): buffer tokens until valid UTF-8 #122

Merged

2 commits merged on Apr 13, 2023
Changes from all commits
ggml/src/lib.rs (2 changes: 1 addition & 1 deletion)
@@ -406,7 +406,7 @@ impl Tensor {
}
}

fn with_alive_ctx<U>(&self, f: impl Fn() -> U) -> U {
fn with_alive_ctx<U>(&self, mut f: impl FnMut() -> U) -> U {
if let Some(_ctx) = self.ctx.upgrade() {
f()
} else {
llama-cli/src/cli_args.rs (3 changes: 0 additions & 3 deletions)
@@ -268,9 +268,6 @@ impl ModelLoad {
LoadProgress::HyperparametersLoaded(hparams) => {
log::debug!("Loaded hyperparameters {hparams:#?}")
}
LoadProgress::BadToken { index } => {
log::info!("Warning: Bad token in vocab at index {index}")
}
LoadProgress::ContextSize { bytes } => log::info!(
"ggml ctx size = {:.2} MB\n",
bytes as f64 / (1024.0 * 1024.0)
llama-rs/src/convert.rs (23 changes: 14 additions & 9 deletions)
@@ -49,11 +49,12 @@ fn load_vocabulary(path: &Path) -> Vocabulary {
let mut token_to_id = HashMap::new();
let mut max_token_length = 0;

// TODO: Does the original model use valid UTF-8 for its tokens? This seems a little suspect to me.
for (idx, piece) in proto.get_pieces().iter().enumerate() {
let word = piece.get_piece().to_string();
let word = piece.get_piece().as_bytes();
max_token_length = max_token_length.max(word.len());
id_to_token.push(word.clone());
token_to_id.insert(word, idx as i32);
id_to_token.push(word.to_owned());
token_to_id.insert(word.to_owned(), idx as i32);
id_to_token_score.push(piece.get_score());
}
Vocabulary {
@@ -128,13 +129,17 @@ fn write_header(fout: &mut File, hparams: &Hyperparameters) -> Result<(), String
fn write_tokens(file: &mut File, vocab: &Vocabulary) -> Result<(), String> {
let mut values: Vec<u8> = vec![];
for (i, token) in vocab.id_to_token.iter().enumerate() {
let text = match token {
_ if token.contains("<unk>") => " \u{2047} ".as_bytes().to_vec(),
_ if token.contains("s>") => vec![],
_ if token.len() == 6 && token.contains("<0x") => {
vec![u8::from_str_radix(&token[3..5], 16).unwrap()]
let text = if let Ok(token) = std::str::from_utf8(token) {
match token {
_ if token.contains("<unk>") => " \u{2047} ".as_bytes().to_vec(),
_ if token.contains("s>") => vec![],
_ if token.len() == 6 && token.contains("<0x") => {
vec![u8::from_str_radix(&token[3..5], 16).unwrap()]
}
_ => token.replace('\u{2581}', " ").as_bytes().to_vec(),
}
_ => token.replace('\u{2581}', " ").as_bytes().to_vec(),
} else {
token.clone()
};
values.extend((text.len() as i32).to_le_bytes());
values.extend(&text);
llama-rs/src/lib.rs (166 changes: 126 additions & 40 deletions)
@@ -289,7 +289,7 @@ impl Display for InferenceStats {
}

type TokenId = i32;
type Token = String;
type Token = Vec<u8>;
type TokenScore = f32;

/// The vocabulary used by a model.
@@ -309,7 +309,7 @@ pub struct Vocabulary {
max_token_length: usize,
}
impl Vocabulary {
fn token(&self, idx: usize) -> &str {
fn token(&self, idx: usize) -> &[u8] {
&self.id_to_token[idx]
}
}
@@ -386,14 +386,6 @@ impl std::fmt::Display for TokenBias {
pub enum LoadProgress<'a> {
/// The hyperparameters have been loaded from the model.
HyperparametersLoaded(&'a Hyperparameters),
/// A bad token was encountered during the loading process.
///
/// This can be ignored, but invalid tokens will be replaced with
/// the `�` character.
BadToken {
/// The index within the vocabulary.
index: usize,
},
/// The context has been created.
ContextSize {
/// The size of the context.
@@ -578,7 +570,7 @@ impl Model {
pub fn load(
path: impl AsRef<Path>,
n_context_tokens: usize,
load_progress_callback: impl Fn(LoadProgress),
mut load_progress_callback: impl FnMut(LoadProgress),
) -> Result<(Model, Vocabulary), LoadError> {
use std::fs::File;
use std::io::BufReader;
@@ -604,6 +596,20 @@ impl Model {
Ok(bytes)
}

fn read_bytes_with_len(
reader: &mut impl BufRead,
len: usize,
) -> Result<Vec<u8>, LoadError> {
let mut bytes = vec![0u8; len];
reader
.read_exact(&mut bytes)
.map_err(|e| LoadError::ReadExactFailed {
source: e,
bytes: len,
})?;
Ok(bytes)
}

fn read_i32(reader: &mut impl BufRead) -> Result<i32, LoadError> {
Ok(i32::from_le_bytes(read_bytes::<4>(reader)?))
}
@@ -618,15 +624,7 @@ impl Model {

/// Helper function. Reads a string from the buffer and returns it.
fn read_string(reader: &mut BufReader<File>, len: usize) -> Result<String, LoadError> {
let mut buf = vec![0; len];
reader
.read_exact(&mut buf)
.map_err(|e| LoadError::ReadExactFailed {
source: e,
bytes: buf.len(),
})?;
let s = String::from_utf8(buf)?;
Ok(s)
Ok(String::from_utf8(read_bytes_with_len(reader, len)?)?)
}

// Verify magic
@@ -682,14 +680,10 @@ impl Model {

for i in 0..hparams.n_vocab {
let len = read_i32(&mut reader)?;
if let Ok(word) = read_string(&mut reader, len as usize) {
max_token_length = max_token_length.max(word.len());
id_to_token.push(word.clone());
token_to_id.insert(word, TokenId::try_from(i)?);
} else {
load_progress_callback(LoadProgress::BadToken { index: i });
id_to_token.push("�".to_string());
}
let token = read_bytes_with_len(&mut reader, len as usize)?;
max_token_length = max_token_length.max(token.len());
id_to_token.push(token.clone());
token_to_id.insert(token, TokenId::try_from(i)?);

// Token score, currently unused
if !is_legacy_model {
@@ -1427,7 +1421,7 @@ impl InferenceSession {
vocab: &Vocabulary,
params: &InferenceParameters,
prompt: &str,
callback: impl Fn(&str) -> Result<(), E>,
mut callback: impl FnMut(&[u8]) -> Result<(), E>,
) -> Result<(), InferenceError> {
let beginning_of_sentence = self.n_past == 0;
let prompt_tokens: Vec<TokenId> = vocab
@@ -1464,7 +1458,7 @@ impl InferenceSession {
vocab: &'v Vocabulary,
params: &InferenceParameters,
rng: &mut impl rand::Rng,
) -> Result<&'v str, InferenceError> {
) -> Result<&'v [u8], InferenceError> {
if self.n_past + 1 >= model.hparams.n_ctx {
return Err(InferenceError::ContextFull);
}
@@ -1505,15 +1499,19 @@
prompt: &str,
maximum_token_count: Option<usize>,
rng: &mut impl rand::Rng,
callback: impl Fn(&str) -> Result<(), E>,
mut callback: impl FnMut(&str) -> Result<(), E>,
) -> Result<InferenceStats, InferenceError> {
let maximum_token_count = maximum_token_count.unwrap_or(usize::MAX);
if params.play_back_previous_tokens {
// "Play back" the existing tokens, so that loading from an inference snapshot works
// as expected.
let mut token_utf8_buf = TokenUtf8Buffer::new();
for token_id in &self.tokens {
if let Err(e) = callback(vocab.token(*token_id as usize)) {
return Err(InferenceError::UserCallback(Box::new(e)));
// Buffer the token until it's valid UTF-8, then call the callback.
if let Some(tokens) = token_utf8_buf.push(vocab.token(*token_id as usize)) {
if let Err(e) = callback(&tokens) {
return Err(InferenceError::UserCallback(Box::new(e)));
}
}
}
}
@@ -1524,7 +1522,13 @@ impl InferenceSession {

// Feed the initial prompt through the transformer, to update its
// context window with new data.
self.feed_prompt(model, vocab, params, prompt, |tk| callback(tk))?;
self.feed_prompt(
model,
vocab,
params,
prompt,
TokenUtf8Buffer::adapt_callback(&mut callback),
)?;
stats.feed_prompt_duration = start_at.elapsed().unwrap();
stats.prompt_tokens = self.n_past;

@@ -1533,15 +1537,19 @@ impl InferenceSession {
// EndOfText token, or we run out of space in the context window,
// or we reach the specified limit.
let mut tokens_processed = 0;
let mut token_utf8_buf = TokenUtf8Buffer::new();
while tokens_processed < maximum_token_count {
let token = match self.infer_next_token(model, vocab, params, rng) {
Ok(token) => token,
Err(InferenceError::EndOfText) => break,
Err(e) => return Err(e),
};

if let Err(e) = callback(token) {
return Err(InferenceError::UserCallback(Box::new(e)));
// Buffer the token until it's valid UTF-8, then call the callback.
if let Some(tokens) = token_utf8_buf.push(token) {
if let Err(e) = callback(&tokens) {
return Err(InferenceError::UserCallback(Box::new(e)));
}
}

tokens_processed += 1;
@@ -1688,7 +1696,7 @@ impl Vocabulary {
&'a self,
text: &str,
bos: bool,
) -> Result<Vec<(&'a str, TokenId)>, InferenceError> {
) -> Result<Vec<(&'a [u8], TokenId)>, InferenceError> {
let len = text.len();

let mut score = vec![0usize; len + 1];
@@ -1698,7 +1706,6 @@ impl Vocabulary {
let max_len = (len - i).min(self.max_token_length);
for sub_len in 1..=max_len {
let sub = &text.as_bytes()[i..i + sub_len];
let Ok(sub) = std::str::from_utf8(sub) else { continue; };
let token = self.token_to_id.get(sub);

if let Some(token) = token {
@@ -1722,14 +1729,14 @@ impl Vocabulary {
if token_id == 0 {
return Err(InferenceError::TokenizationFailed);
}
let token = self.id_to_token[token_id as usize].as_str();
let token = self.id_to_token[token_id as usize].as_slice();
res.push((token, token_id));
i -= token.len();
}

if bos {
// TODO: replace with vocab.bos
res.push(("", 1));
res.push((&[], 1));
}

// Pieces are in reverse order so correct that
Expand All @@ -1738,3 +1745,82 @@ impl Vocabulary {
Ok(res)
}
}

/// Used to buffer incoming tokens until they produce a valid string of UTF-8 text.
///
/// Tokens are *not* valid UTF-8 by themselves. However, the LLM will produce valid UTF-8
/// from multiple tokens. This helps alleviate that issue.
#[derive(Clone, PartialEq, Default)]
pub struct TokenUtf8Buffer(Vec<u8>);
impl TokenUtf8Buffer {
/// Create a new buffer.
pub const fn new() -> Self {
Self(vec![])
}

/// Add a token to the buffer. If the buffer contains a valid string of UTF-8 text,
/// it is returned and the buffer is cleared for next use.
pub fn push(&mut self, token: &[u8]) -> Option<String> {
self.0.extend_from_slice(token);
match std::str::from_utf8(&self.0) {
Ok(s) => {
let out = s.to_owned();
self.0 = vec![];
Some(out)
}
Err(..) => {
for i in 1..self.0.len() {
let slice = &self.0[i..];
if slice.is_empty() {
break;
}

if let Ok(s) = std::str::from_utf8(slice) {
let out = s.to_owned();
self.0 = vec![];
return Some(out);
}
}
None
}
}
}

/// Adapt a `&str` callback so that it can be used in a `&[u8]` context.
fn adapt_callback<'a, E: std::error::Error + 'static>(
mut callback: impl FnMut(&str) -> Result<(), E> + 'a,
) -> impl FnMut(&[u8]) -> Result<(), E> + 'a {
let mut buffer = Self::new();
move |token| match buffer.push(token) {
Some(tokens) => callback(&tokens),
None => Ok(()),
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_valid_utf8() {
let mut buffer = TokenUtf8Buffer::new();
assert_eq!(buffer.push(b"hello").as_deref(), Some("hello"));
assert_eq!(buffer.push(&[0xE2, 0x82, 0xAC]).as_deref(), Some("€"));
}

#[test]
fn test_partial_utf8() {
let mut buffer = TokenUtf8Buffer::new();
assert_eq!(buffer.push(&[0xE2, 0x82]).as_deref(), None);
assert_eq!(buffer.push(&[0xAC]).as_deref(), Some("€"));
}

#[test]
fn test_invalid_prelude_for_valid_utf8() {
let mut buffer = TokenUtf8Buffer::new();
assert_eq!(buffer.push(&[0xD8]).as_deref(), None);
assert_eq!(buffer.push(&[0xE2, 0x82]).as_deref(), None);
assert_eq!(buffer.push(&[0xAC]).as_deref(), Some("€"));
}
}
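
The doc comment on TokenUtf8Buffer above describes the intent: raw token bytes are accumulated until they form valid UTF-8, and only then surfaced as text. As a rough illustration of how a caller might consume the new type, here is a minimal sketch that is not part of this PR. It assumes TokenUtf8Buffer is reachable as llama_rs::TokenUtf8Buffer (it is declared pub at the crate root in the diff above), and the token byte sequences are invented for the example.

// Minimal usage sketch (not part of this PR). Assumes TokenUtf8Buffer is
// exported from the llama_rs crate root, as its pub declaration above suggests.
use llama_rs::TokenUtf8Buffer;
use std::io::Write;

fn stream_tokens(tokens: &[&[u8]]) -> std::io::Result<()> {
    let mut buffer = TokenUtf8Buffer::new();
    let mut stdout = std::io::stdout();
    for &token in tokens {
        // push() returns Some(text) only once the buffered bytes form valid
        // UTF-8; a partial multi-byte sequence stays buffered until completed.
        if let Some(text) = buffer.push(token) {
            stdout.write_all(text.as_bytes())?;
            stdout.flush()?;
        }
    }
    Ok(())
}

fn main() -> std::io::Result<()> {
    // Hypothetical token stream: "€" (0xE2 0x82 0xAC) is split across two
    // tokens, so nothing is printed for it until its final byte arrives.
    let tokens: &[&[u8]] = &[b"price: ", &[0xE2, 0x82], &[0xAC], b"5"];
    stream_tokens(tokens)
}

This mirrors what inference_with_prompt now does internally: it keeps one TokenUtf8Buffer per token stream, and uses TokenUtf8Buffer::adapt_callback to wrap the caller's &str callback so that feed_prompt can drive it with &[u8] tokens.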