From da3dc77989aa52b2dcb5e135e3b8143c5424f9a6 Mon Sep 17 00:00:00 2001 From: Clint Herron Date: Wed, 3 Apr 2024 22:44:19 -0400 Subject: [PATCH] Added support for . (any characer) token in grammar engine. --- common/grammar-parser.cpp | 11 +++++++++++ llama.cpp | 12 ++++++++++-- llama.h | 3 +++ 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/common/grammar-parser.cpp b/common/grammar-parser.cpp index 2a1301569793ad..0f787bf421c5d4 100644 --- a/common/grammar-parser.cpp +++ b/common/grammar-parser.cpp @@ -188,6 +188,10 @@ namespace grammar_parser { throw std::runtime_error(std::string("expecting ')' at ") + pos); } pos = parse_space(pos + 1, is_nested); + } else if (*pos == '.') { // any char + last_sym_start = out_elements.size(); + out_elements.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); + pos = parse_space(pos + 1, is_nested); } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator if (last_sym_start == out_elements.size()) { throw std::runtime_error(std::string("expecting preceding item to */+/? at ") + pos); @@ -316,6 +320,7 @@ namespace grammar_parser { case LLAMA_GRETYPE_CHAR_NOT: return true; case LLAMA_GRETYPE_CHAR_ALT: return true; case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; + case LLAMA_GRETYPE_CHAR_ANY: return true; default: return false; } } @@ -330,6 +335,7 @@ namespace grammar_parser { case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break; case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; + case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break; } switch (elem.type) { case LLAMA_GRETYPE_END: @@ -341,6 +347,7 @@ namespace grammar_parser { case LLAMA_GRETYPE_CHAR_NOT: case LLAMA_GRETYPE_CHAR_RNG_UPPER: case LLAMA_GRETYPE_CHAR_ALT: + case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "(\""); print_grammar_char(file, elem.value); fprintf(file, "\") "); @@ -398,11 +405,15 @@ namespace grammar_parser { } print_grammar_char(file, elem.value); break; + case LLAMA_GRETYPE_CHAR_ANY: + fprintf(file, "."); + break; } if (is_char_element(elem)) { switch (rule[i + 1].type) { case LLAMA_GRETYPE_CHAR_ALT: case LLAMA_GRETYPE_CHAR_RNG_UPPER: + case LLAMA_GRETYPE_CHAR_ANY: break; default: fprintf(file, "] "); diff --git a/llama.cpp b/llama.cpp index 267ac4cc022a18..83f8885ce30d83 100644 --- a/llama.cpp +++ b/llama.cpp @@ -11716,7 +11716,7 @@ static std::pair llama_grammar_match_char( const uint32_t chr) { bool found = false; - bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR; + bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY; GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT @@ -11725,6 +11725,10 @@ static std::pair llama_grammar_match_char( // inclusive range, e.g. [a-z] found = found || (pos->value <= chr && chr <= pos[1].value); pos += 2; + } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) { + // Any character matches "." + found = true; + pos += 1; } else { // exact char match, e.g. [a] or "a" found = found || pos->value == chr; @@ -11742,7 +11746,7 @@ static bool llama_grammar_match_partial_char( const llama_grammar_element * pos, const llama_partial_utf8 partial_utf8) { - bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR; + bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY; GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); uint32_t partial_value = partial_utf8.value; @@ -11772,6 +11776,9 @@ static bool llama_grammar_match_partial_char( return is_positive_char; } pos += 2; + } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) { + // Any character matches "." + return true; } else { // exact char match, e.g. [a] or "a" if (low <= pos->value && pos->value <= high) { @@ -11830,6 +11837,7 @@ static void llama_grammar_advance_stack( } case LLAMA_GRETYPE_CHAR: case LLAMA_GRETYPE_CHAR_NOT: + case LLAMA_GRETYPE_CHAR_ANY: new_stacks.emplace_back(stack); break; default: diff --git a/llama.h b/llama.h index f061d014ca8eb0..cb83170218be4f 100644 --- a/llama.h +++ b/llama.h @@ -315,6 +315,9 @@ extern "C" { // modifies a preceding LLAMA_GRETYPE_CHAR or // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) LLAMA_GRETYPE_CHAR_ALT = 6, + + // any character (.) + LLAMA_GRETYPE_CHAR_ANY = 7, }; typedef struct llama_grammar_element {