From b40e1c7c75a38b9ed590a39c971e49831b634507 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 8 Jan 2025 12:34:51 +0100 Subject: [PATCH] allow for using -1 for eos token (if missing); untested --- parser/src/earley/parser.rs | 7 ++++--- parser/src/tokenparser.rs | 4 ++-- toktrie/src/lib.rs | 1 + toktrie/src/toktree.rs | 4 +++- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/parser/src/earley/parser.rs b/parser/src/earley/parser.rs index e9fb965..e033f5d 100644 --- a/parser/src/earley/parser.rs +++ b/parser/src/earley/parser.rs @@ -16,7 +16,7 @@ use derivre::{AlphabetInfo, RegexAst, StateID}; use hashbrown::HashSet; use instant::Instant; use serde::{Deserialize, Serialize}; -use toktrie::{Recognizer, SimpleVob, TokEnv, TokTrie}; +use toktrie::{Recognizer, SimpleVob, TokEnv, TokTrie, INVALID_TOKEN}; use crate::{ api::{ParserLimits, StopReason}, @@ -663,8 +663,9 @@ impl ParserState { let _ = self.flush_lexer(); } - if start.is_empty() && self.lexer_allows_eos() { - set.allow_token(computer.trie().eos_token()); + let eos = computer.trie().eos_token(); + if eos != INVALID_TOKEN && start.is_empty() && self.lexer_allows_eos() { + set.allow_token(eos); } self.stats.compute_time_us += t0.elapsed().as_micros() as u64; diff --git a/parser/src/tokenparser.rs b/parser/src/tokenparser.rs index e02826e..aea0c4b 100644 --- a/parser/src/tokenparser.rs +++ b/parser/src/tokenparser.rs @@ -9,7 +9,7 @@ use crate::{ }; use anyhow::{ensure, Result}; use serde_json::json; -use toktrie::{InferenceCapabilities, SimpleVob, TokEnv, TokenId}; +use toktrie::{InferenceCapabilities, SimpleVob, TokEnv, TokenId, INVALID_TOKEN}; #[derive(Clone)] pub struct TokenParser { @@ -386,7 +386,7 @@ impl TokenParser { return Err(self.stop_for_parser_error("", s)); } - if self.is_accepting() { + if self.eos_token != INVALID_TOKEN && self.is_accepting() { allowed_tokens.allow_token(self.eos_token); } diff --git a/toktrie/src/lib.rs b/toktrie/src/lib.rs index e3f98f3..ee4ffb2 100644 --- a/toktrie/src/lib.rs +++ b/toktrie/src/lib.rs @@ -9,6 +9,7 @@ mod toktree; pub use svob::{SimpleVob, SimpleVobIter}; pub use toktree::{ Recognizer, TokEnv, TokEnvWithTrie, TokRxInfo, TokTrie, TokenId, TokenizerEnv, TrieNode, + INVALID_TOKEN, }; /// Defines what is allowed in Branch diff --git a/toktrie/src/toktree.rs b/toktrie/src/toktree.rs index 5605f78..e566f42 100644 --- a/toktrie/src/toktree.rs +++ b/toktrie/src/toktree.rs @@ -211,6 +211,8 @@ pub struct TrieNode { bits2: u32, } +pub const INVALID_TOKEN: TokenId = 0xffff_ffff; + const NO_TOKEN: u32 = 0xffffff; impl TrieNode { @@ -345,7 +347,7 @@ impl TokTrie { let max_tok = std::cmp::min(max_examples, num_set); let mut token_names = Vec::new(); // make sure we include EOS first if it's allowed - if ts1.is_allowed(self.info.tok_eos) { + if self.info.tok_eos != INVALID_TOKEN && ts1.is_allowed(self.info.tok_eos) { token_names.push("EOS".to_string()); } for idx in 0..self.vocab_size() {