
Commit

Revert "grammars: 1.5x faster inference w/ complex grammars (vector r…
Browse files Browse the repository at this point in the history
…eserves / reuses) (ggerganov#6609)"

This reverts commit cbaadc9.
Nexesenex committed Apr 11, 2024
1 parent f45d7eb commit 95c8115
Showing 4 changed files with 12 additions and 17 deletions.
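
The reverted change (ggerganov#6609) had switched llama_grammar_accept from returning a fresh vector of stacks on every call to filling a caller-provided vector that is cleared and reused, and added reserve() calls in the candidate-rejection path; this revert restores the return-by-value form shown in the hunks below. A minimal standalone sketch of the two shapes, with hypothetical names and a stand-in acceptance test rather than the real grammar matching:

#include <cstdint>
#include <vector>

using stacks_t = std::vector<std::vector<int>>;

// Restored shape: every call builds and returns a new vector
// (one allocation per call, plus growth while pushing).
static stacks_t advance_by_value(const stacks_t & stacks, uint32_t chr) {
    stacks_t next;
    for (const auto & s : stacks) {
        if (!s.empty() && s.back() == (int) chr) {   // stand-in acceptance test
            next.push_back(s);
        }
    }
    return next;
}

// Reverted shape: the caller owns a scratch vector; clear() keeps its
// capacity, so a hot loop reuses the same storage across calls.
static void advance_reuse(const stacks_t & stacks, uint32_t chr, stacks_t & out) {
    out.clear();
    for (const auto & s : stacks) {
        if (!s.empty() && s.back() == (int) chr) {   // stand-in acceptance test
            out.push_back(s);
        }
    }
}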
examples/gbnf-validator/gbnf-validator.cpp: 2 changes (1 addition, 1 deletion)
@@ -17,7 +17,7 @@ static bool llama_sample_grammar_string(struct llama_grammar * grammar, const st
     size_t pos = 0;
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
         auto prev_stacks = grammar->stacks;
-        llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
+        grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
         if (grammar->stacks.empty()) {
             error_pos = pos;
             error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'";
llama.cpp: 16 changes (6 additions, 10 deletions)
@@ -12202,13 +12202,12 @@ static void llama_grammar_advance_stack(
 // be positioned at a character range (see `llama_grammar_advance_stack`), and
 // produces the N possible stacks if the given char is accepted at those
 // positions
-void llama_grammar_accept(
+std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
         const std::vector<std::vector<llama_grammar_element>> & rules,
         const std::vector<std::vector<const llama_grammar_element *>> & stacks,
-        const uint32_t chr,
-        std::vector<std::vector<const llama_grammar_element *>> & new_stacks) {
+        const uint32_t chr) {
 
-    new_stacks.clear();
+    std::vector<std::vector<const llama_grammar_element *>> new_stacks;
 
     for (const auto & stack : stacks) {
         if (stack.empty()) {
@@ -12227,6 +12226,8 @@ void llama_grammar_accept(
             llama_grammar_advance_stack(rules, new_stack, new_stacks);
         }
     }
+
+    return new_stacks;
 }
 
 static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
@@ -12240,7 +12241,6 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
         const std::vector<llama_grammar_candidate> & candidates) {
 
     std::vector<llama_grammar_candidate> rejects;
-    rejects.reserve(candidates.size());
 
     if (stack.empty()) {
         for (const auto & tok : candidates) {
@@ -12254,8 +12254,6 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
     const llama_grammar_element * stack_pos = stack.back();
 
     std::vector<llama_grammar_candidate> next_candidates;
-    next_candidates.reserve(candidates.size());
-
     for (const auto & tok : candidates) {
         if (*tok.code_points == 0) {
             // reached end of full codepoints in token, reject iff it ended in a partial sequence
@@ -13077,10 +13075,8 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
     // Note terminating 0 in decoded string
     const auto decoded = decode_utf8(piece, grammar->partial_utf8);
     const auto & code_points = decoded.first;
-    std::vector<std::vector<const llama_grammar_element *>> tmp_new_stacks;
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
-        llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
-        grammar->stacks = tmp_new_stacks;
+        grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
     }
     grammar->partial_utf8 = decoded.second;
     GGML_ASSERT(!grammar->stacks.empty());
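
The comment above describes llama_grammar_accept as taking a set of possible pushdown stacks and producing the stacks that survive consuming one character. A toy sketch of that idea and of the driving loop restored at the call sites (simplified types and a hypothetical toy_accept, not the grammar machinery in llama.cpp):

#include <cstdint>
#include <vector>

// Each parse state is a pushdown stack; here the top is simply the next
// expected code point (the real code matches character ranges instead).
using toy_stack  = std::vector<uint32_t>;
using toy_stacks = std::vector<toy_stack>;

static toy_stacks toy_accept(const toy_stacks & stacks, uint32_t chr) {
    toy_stacks next;
    for (const auto & stack : stacks) {
        if (stack.empty() || stack.back() != chr) {
            continue;                       // this parse cannot consume chr
        }
        toy_stack advanced = stack;
        advanced.pop_back();                // consume the expected character
        next.push_back(advanced);           // survivors form the new stack set
    }
    return next;
}

// Driving loop, mirroring the restored call sites: replace the stack set for
// each decoded code point; an empty result means the input cannot be parsed.
static bool toy_matches(toy_stacks stacks, const std::vector<uint32_t> & code_points) {
    for (const uint32_t cp : code_points) {
        stacks = toy_accept(stacks, cp);
        if (stacks.empty()) {
            return false;
        }
    }
    return true;
}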
llama.h: 5 changes (2 additions, 3 deletions)
@@ -1101,11 +1101,10 @@ const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal
     struct llama_context * ctx
 );
 
-void llama_grammar_accept(
+std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
         const std::vector<std::vector<llama_grammar_element>> & rules,
         const std::vector<std::vector<const llama_grammar_element *>> & stacks,
-        const uint32_t chr,
-        std::vector<std::vector<const llama_grammar_element *>> & new_stacks);
+        const uint32_t chr);
 
 std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
         const std::string & src,
tests/test-grammar-integration.cpp: 6 changes (3 additions, 3 deletions)
@@ -38,7 +38,7 @@ number ::= [0-9]+)""";
 
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
         auto prev_stacks = grammar->stacks;
-        llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
+        grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
         assert(!grammar->stacks.empty());
     }
 
@@ -138,7 +138,7 @@ ws ::= [ \t\n\r]?)""";
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
         ++pos;
         auto prev_stacks = grammar->stacks;
-        llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
+        grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
 
         // Expect that each code point will not cause the grammar to fail
         if (grammar->stacks.empty()) {
@@ -173,7 +173,7 @@ ws ::= [ \t\n\r]?)""";
 
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
         auto prev_stacks = grammar->stacks;
-        llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
+        grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
         if (grammar->stacks.empty()) {
             parse_failed = true;
             break;

