Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for . (any character) token in grammar engine. #6467

Merged
merged 2 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions common/grammar-parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,10 @@ namespace grammar_parser {
throw std::runtime_error(std::string("expecting ')' at ") + pos);
}
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '.') { // any char
last_sym_start = out_elements.size();
out_elements.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
HanClinto marked this conversation as resolved.
Show resolved Hide resolved
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '*') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(0, -1);
Expand Down Expand Up @@ -401,6 +405,7 @@ namespace grammar_parser {
case LLAMA_GRETYPE_CHAR_NOT: return true;
case LLAMA_GRETYPE_CHAR_ALT: return true;
case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
case LLAMA_GRETYPE_CHAR_ANY: return true;
default: return false;
}
}
Expand All @@ -415,6 +420,7 @@ namespace grammar_parser {
case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break;
}
switch (elem.type) {
case LLAMA_GRETYPE_END:
Expand All @@ -426,6 +432,7 @@ namespace grammar_parser {
case LLAMA_GRETYPE_CHAR_NOT:
case LLAMA_GRETYPE_CHAR_RNG_UPPER:
case LLAMA_GRETYPE_CHAR_ALT:
case LLAMA_GRETYPE_CHAR_ANY:
fprintf(file, "(\"");
print_grammar_char(file, elem.value);
fprintf(file, "\") ");
Expand Down Expand Up @@ -483,11 +490,15 @@ namespace grammar_parser {
}
print_grammar_char(file, elem.value);
break;
case LLAMA_GRETYPE_CHAR_ANY:
fprintf(file, ".");
break;
}
if (is_char_element(elem)) {
switch (rule[i + 1].type) {
case LLAMA_GRETYPE_CHAR_ALT:
case LLAMA_GRETYPE_CHAR_RNG_UPPER:
case LLAMA_GRETYPE_CHAR_ANY:
break;
default:
fprintf(file, "] ");
Expand Down
12 changes: 10 additions & 2 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13640,7 +13640,7 @@ static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
const uint32_t chr) {

bool found = false;
bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR;
bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;

GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT

Expand All @@ -13649,6 +13649,10 @@ static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
// inclusive range, e.g. [a-z]
found = found || (pos->value <= chr && chr <= pos[1].value);
pos += 2;
} else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
// Any character matches "."
found = true;
pos += 1;
} else {
// exact char match, e.g. [a] or "a"
found = found || pos->value == chr;
Expand All @@ -13666,7 +13670,7 @@ static bool llama_grammar_match_partial_char(
const llama_grammar_element * pos,
const llama_partial_utf8 partial_utf8) {

bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR;
bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);

uint32_t partial_value = partial_utf8.value;
Expand Down Expand Up @@ -13696,6 +13700,9 @@ static bool llama_grammar_match_partial_char(
return is_positive_char;
}
pos += 2;
} else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
// Any character matches "."
return true;
} else {
// exact char match, e.g. [a] or "a"
if (low <= pos->value && pos->value <= high) {
Expand Down Expand Up @@ -13756,6 +13763,7 @@ static void llama_grammar_advance_stack(
}
case LLAMA_GRETYPE_CHAR:
case LLAMA_GRETYPE_CHAR_NOT:
case LLAMA_GRETYPE_CHAR_ANY:
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
// only add the stack if it's not a duplicate of one we already have
new_stacks.emplace_back(stack);
Expand Down
3 changes: 3 additions & 0 deletions llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,9 @@ extern "C" {
// modifies a preceding LLAMA_GRETYPE_CHAR or
// LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
LLAMA_GRETYPE_CHAR_ALT = 6,

// any character (.)
LLAMA_GRETYPE_CHAR_ANY = 7,
};

typedef struct llama_grammar_element {
Expand Down
28 changes: 28 additions & 0 deletions tests/test-grammar-integration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,33 @@ static void test_complex_grammar() {
);
}

static void test_special_chars() {
// A collection of tests to exercise special characters such as "."
test_grammar(
"special characters",
// Grammar
R"""(
root ::= ... "abc" ...
)""",
// Passing strings
{
"abcabcabc",
"aaaabcccc",
// NOTE: Also ensures that multi-byte characters still count as a single character
"🔵🟠✅abc❌🟠🔵"
HanClinto marked this conversation as resolved.
Show resolved Hide resolved
},
// Failing strings
{
"aaabcccc",
"aaaaabcccc",
"aaaabccc",
"aaaabccccc",
"🔵🟠✅❌abc❌✅🟠🔵"
"🔵🟠abc🟠🔵"
}
);
}

static void test_quantifiers() {
// A collection of tests to exercise * + and ? quantifiers

Expand Down Expand Up @@ -445,6 +472,7 @@ int main() {
fprintf(stdout, "Running grammar integration tests...\n");
test_simple_grammar();
test_complex_grammar();
test_special_chars();
test_quantifiers();
test_failure_missing_root();
test_failure_missing_reference();
Expand Down
Loading