Skip to content

Commit

Permalink
Expanding tests to add chained-ambiguity stress tests as well as simp…
Browse files Browse the repository at this point in the history
…le timing metrics.
  • Loading branch information
HanClinto committed Apr 11, 2024
1 parent f4183af commit 490d06f
Showing 1 changed file with 128 additions and 4 deletions.
132 changes: 128 additions & 4 deletions tests/test-grammar-integration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,107 @@ ws ::= [ \t\n\r]?)""";
llama_grammar_free(grammar);
}

// Stress test for a grammar with chained ambiguity, written as a flat run of
// optionals: a digit, then any number of groups made of up to ten optional "a"
// terminals followed by a digit. Each 'a' in the input can be matched by more
// than one of the optionals, which is the ambiguity being exercised.
//
// The test asserts that the grammar parses and has a "root" symbol, that every
// character of the sample input is accepted, and that at least one parse stack
// is empty once the input has been consumed.
// (The explicitly nested formulation of the same language is covered by
// test_chained_ambiguity_grouped.)
static void test_chained_ambiguity() {
    const std::string grammar_str = R"""(root ::= [0-9] ("a"? "a"? "a"? "a"? "a"? "a"? "a"? "a"? "a"? "a"? [0-9])*)""";

    grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());

    // Ensure we parsed correctly
    assert(!parsed_grammar.rules.empty());

    // Ensure we have a root node
    assert(parsed_grammar.symbol_ids.find("root") != parsed_grammar.symbol_ids.end());

    std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
    llama_grammar* grammar = llama_grammar_init(
        grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));

    std::string input = "1aa2aa3aa4aa5";

    auto decoded = decode_utf8(input, {});

    const auto & code_points = decoded.first;

    // The loop bound below is end() - 1, which is only valid on a non-empty sequence.
    assert(!code_points.empty());

    // NOTE(review): the last element is deliberately not fed to the grammar;
    // presumably decode_utf8 appends a terminator code point -- confirm.
    //
    // The stacks are replaced in place on every step. No snapshot of the
    // previous stacks is kept: nothing reads it, and copying every stack per
    // character would distort the timing of this stress test.
    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
        grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
        if (grammar->stacks.empty()) {
            fprintf(stderr, "Unexpected character '%s'\n", unicode_cpt_to_utf8(*it).c_str());
        }
        assert(!grammar->stacks.empty());
    }

    // The input counts as fully matched when at least one stack is empty.
    bool completed_grammar = false;

    for (const auto & stack : grammar->stacks) {
        if (stack.empty()) {
            completed_grammar = true;
            break;
        }
    }

    assert(completed_grammar);

    // Clean up allocated memory
    llama_grammar_free(grammar);
}

// Stress test for a grammar with chained ambiguity, written with explicit
// nesting: a digit, then any number of groups made of an optional "a" that may
// itself be followed by another optional "a" (ten levels deep), followed by a
// digit.
//
// The test asserts that the grammar parses and has a "root" symbol, that every
// character of the sample input is accepted, and that at least one parse stack
// is empty once the input has been consumed.
static void test_chained_ambiguity_grouped() {
    const std::string grammar_str = R"""(root ::= [0-9] (("a" ("a" ("a" ("a" ("a" ("a" ("a" ("a" ("a" ("a")?)?)?)?)?)?)?)?)?)? [0-9])*)""";

    grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());

    // Ensure we parsed correctly
    assert(!parsed_grammar.rules.empty());

    // Ensure we have a root node
    assert(parsed_grammar.symbol_ids.find("root") != parsed_grammar.symbol_ids.end());

    std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
    llama_grammar* grammar = llama_grammar_init(
        grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));

    std::string input = "1aa2aa3aa4aa5";

    auto decoded = decode_utf8(input, {});

    const auto & code_points = decoded.first;

    // The loop bound below is end() - 1, which is only valid on a non-empty sequence.
    assert(!code_points.empty());

    // NOTE(review): the last element is deliberately not fed to the grammar;
    // presumably decode_utf8 appends a terminator code point -- confirm.
    //
    // Only the current stacks are kept. An unused per-character copy of the
    // previous stacks would add avoidable work to a test that is being timed.
    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
        grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
        if (grammar->stacks.empty()) {
            fprintf(stderr, "Unexpected character '%s'\n", unicode_cpt_to_utf8(*it).c_str());
        }
        assert(!grammar->stacks.empty());
    }

    // The input counts as fully matched when at least one stack is empty.
    bool completed_grammar = false;

    for (const auto & stack : grammar->stacks) {
        if (stack.empty()) {
            completed_grammar = true;
            break;
        }
    }

    assert(completed_grammar);

    // Clean up allocated memory
    llama_grammar_free(grammar);
}

static void test_failure_missing_root() {
// Test case for a grammar that is missing a root rule
const std::string grammar_str = R"""(rot ::= expr
Expand Down Expand Up @@ -234,10 +335,33 @@ number ::= [0-9]+)""";
fprintf(stderr, "End of expected error. Test successful.\n");
}

// Timing bookkeeping shared by bench() and main().
// `time_labels[i]` names the measurement whose timestamp (from ggml_time_us())
// is stored in `times[i]`; the two vectors are always appended to together.
static std::vector<std::string> time_labels;
static std::vector<int64_t> times;

// A test entry point that takes no arguments and returns nothing, as accepted by bench().
using bench_func = void (*)();

// Runs `func` to completion, then records the current ggml_time_us() timestamp
// in `times` and `label` in `time_labels`. Only the finish time is stored here;
// a duration is obtained by subtracting the preceding entry in `times`.
static void bench(bench_func func, const char* label = "") {
    func();

    const int64_t finished_at = ggml_time_us();
    times.push_back(finished_at);
    time_labels.emplace_back(label);
}

// Test driver: runs every grammar test exactly once under bench() and then
// prints how long each one took.
int main() {
    ggml_time_init();

    // Baseline timestamp: each later entry is reported relative to the one before it.
    times.push_back(ggml_time_us());
    time_labels.push_back("Start");

    // Each test runs only through bench(); calling the tests directly as well
    // would execute them twice and duplicate the failure tests' expected error output.
    bench(test_simple_grammar, "Simple grammar");
    bench(test_complex_grammar, "Complex grammar");
    bench(test_chained_ambiguity, "Chained ambiguity");
    bench(test_chained_ambiguity_grouped, "Chained ambiguity (grouped)");
    bench(test_failure_missing_root, "Failure missing root");
    bench(test_failure_missing_reference, "Failure missing reference");

    // Print timings. Index 0 is the "Start" baseline, so reporting begins at index 1.
    fprintf(stdout, "\nTimings:\n");
    for (size_t i = 1; i < times.size(); ++i) {
        // int64_t is not `long long` on every platform; cast so the argument matches %lld.
        const long long elapsed_us = static_cast<long long>(times[i] - times[i - 1]);
        fprintf(stdout, "%s: %lld us\n", time_labels[i].c_str(), elapsed_us);
    }

    return 0;
}

0 comments on commit 490d06f

Please sign in to comment.