Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

grammars: 1.5x faster inference w/ complex grammars (vector reserves / reuses) #6609

Merged
merged 6 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/gbnf-validator/gbnf-validator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ static bool llama_sample_grammar_string(struct llama_grammar * grammar, const st
size_t pos = 0;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
if (grammar->stacks.empty()) {
error_pos = pos;
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'";
Expand Down
16 changes: 10 additions & 6 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11912,12 +11912,13 @@ static void llama_grammar_advance_stack(
// be positioned at a character range (see `llama_grammar_advance_stack`), and
// produces the N possible stacks if the given char is accepted at those
// positions
std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
void llama_grammar_accept(
const std::vector<std::vector<llama_grammar_element>> & rules,
const std::vector<std::vector<const llama_grammar_element *>> & stacks,
const uint32_t chr) {
const uint32_t chr,
std::vector<std::vector<const llama_grammar_element *>> & new_stacks) {

std::vector<std::vector<const llama_grammar_element *>> new_stacks;
new_stacks.clear();

for (const auto & stack : stacks) {
if (stack.empty()) {
Expand All @@ -11936,8 +11937,6 @@ std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
llama_grammar_advance_stack(rules, new_stack, new_stacks);
}
}

return new_stacks;
}

static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
Expand All @@ -11951,6 +11950,7 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
const std::vector<llama_grammar_candidate> & candidates) {

std::vector<llama_grammar_candidate> rejects;
rejects.reserve(candidates.size());

if (stack.empty()) {
for (const auto & tok : candidates) {
Expand All @@ -11964,6 +11964,8 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
const llama_grammar_element * stack_pos = stack.back();

std::vector<llama_grammar_candidate> next_candidates;
next_candidates.reserve(candidates.size());

for (const auto & tok : candidates) {
if (*tok.code_points == 0) {
// reached end of full codepoints in token, reject iff it ended in a partial sequence
Expand Down Expand Up @@ -12771,8 +12773,10 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
// Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece, grammar->partial_utf8);
const auto & code_points = decoded.first;
std::vector<std::vector<const llama_grammar_element *>> tmp_new_stacks;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
grammar->stacks = tmp_new_stacks;
}
grammar->partial_utf8 = decoded.second;
GGML_ASSERT(!grammar->stacks.empty());
Expand Down
5 changes: 3 additions & 2 deletions llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -1097,10 +1097,11 @@ const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal
struct llama_context * ctx
);

std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
void llama_grammar_accept(
const std::vector<std::vector<llama_grammar_element>> & rules,
const std::vector<std::vector<const llama_grammar_element *>> & stacks,
const uint32_t chr);
const uint32_t chr,
std::vector<std::vector<const llama_grammar_element *>> & new_stacks);

std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
const std::string & src,
Expand Down
6 changes: 3 additions & 3 deletions tests/test-grammar-integration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ number ::= [0-9]+)""";

for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
assert(!grammar->stacks.empty());
}

Expand Down Expand Up @@ -138,7 +138,7 @@ ws ::= [ \t\n\r]?)""";
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
++pos;
auto prev_stacks = grammar->stacks;
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);

// Expect that each code point will not cause the grammar to fail
if (grammar->stacks.empty()) {
Expand Down Expand Up @@ -173,7 +173,7 @@ ws ::= [ \t\n\r]?)""";

for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
if (grammar->stacks.empty()) {
parse_failed = true;
break;
Expand Down
Loading