From 8c972b69c167a10423c38c2e7f89f11e56d2cddf Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 6 Sep 2024 11:58:11 +0300 Subject: [PATCH] grammar : restore llama_grammar_accept signature ggml-ci --- examples/gbnf-validator/gbnf-validator.cpp | 12 ++++++------ src/llama-grammar.cpp | 19 ++++++++++--------- src/llama-grammar.h | 5 +++-- tests/test-grammar-integration.cpp | 18 +++++++++--------- 4 files changed, 28 insertions(+), 26 deletions(-) diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/examples/gbnf-validator/gbnf-validator.cpp index f439c0c56..7493af9d3 100644 --- a/examples/gbnf-validator/gbnf-validator.cpp +++ b/examples/gbnf-validator/gbnf-validator.cpp @@ -12,24 +12,24 @@ static bool llama_grammar_validate(struct llama_grammar * grammar, const std::st const auto cpts = unicode_cpts_from_utf8(input_str); const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); - llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar); + llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); size_t pos = 0; for (const auto & cpt : cpts) { - const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy + const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy - cur_stacks = llama_grammar_accept(rules, prev_stacks, cpt); + llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur); - if (cur_stacks.empty()) { + if (stacks_cur.empty()) { error_pos = pos; error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'"; - cur_stacks = prev_stacks; + stacks_cur = stacks_prev; return false; } ++pos; } - for (const auto & stack : cur_stacks) { + for (const auto & stack : stacks_cur) { if (stack.empty()) { return true; } diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 353cb398a..74e9f64b3 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -822,12 +822,13 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) return grammar->stacks; } -llama_grammar_stacks llama_grammar_accept( +void llama_grammar_accept( const llama_grammar_rules & rules, const llama_grammar_stacks & stacks, - const uint32_t chr) { - llama_grammar_stacks result; - result.reserve(stacks.size()); + const uint32_t chr, + llama_grammar_stacks & stacks_new) { + stacks_new.clear(); + stacks_new.reserve(stacks.size()); for (const auto & stack : stacks) { if (stack.empty()) { @@ -843,11 +844,9 @@ llama_grammar_stacks llama_grammar_accept( if (!llama_grammar_is_end_of_sequence(pos)) { new_stack.push_back(pos); } - llama_grammar_advance_stack(rules, new_stack, result); + llama_grammar_advance_stack(rules, new_stack, stacks_new); } } - - return result; } llama_grammar_candidates llama_grammar_reject_candidates_for_stack( @@ -1127,9 +1126,11 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token const auto decoded = decode_utf8(piece, grammar.partial_utf8); const auto & code_points = decoded.first; + llama_grammar_stacks stacks_new; + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { - llama_grammar_stacks new_stacks = llama_grammar_accept(grammar.rules, grammar.stacks, *it); - grammar.stacks = std::move(new_stacks); + llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new); + grammar.stacks = std::move(stacks_new); } grammar.partial_utf8 = decoded.second; diff --git a/src/llama-grammar.h b/src/llama-grammar.h index 419a616d6..f529ce351 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -65,10 +65,11 @@ const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar // be positioned at a character range (see `llama_grammar_advance_stack`), and // produces the N possible stacks if the given char is accepted at those // positions -llama_grammar_stacks llama_grammar_accept( +void llama_grammar_accept( const llama_grammar_rules & rules, const llama_grammar_stacks & stacks, - uint32_t chr); + uint32_t chr, + llama_grammar_stacks & stacks_new); std::vector llama_grammar_reject_candidates_for_stack( const llama_grammar_rules & rules, diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 788b02a6a..5cc0cdb04 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -33,20 +33,20 @@ static bool match_string(const std::string & input, llama_grammar * grammar) { const auto cpts = unicode_cpts_from_utf8(input); const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); - llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar); + llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); for (const auto & cpt : cpts) { - const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy + const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy - cur_stacks = llama_grammar_accept(rules, prev_stacks, cpt); + llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur); - if (cur_stacks.empty()) { + if (stacks_cur.empty()) { // no stacks means that the grammar failed to match at this point return false; } } - for (const auto & stack : cur_stacks) { + for (const auto & stack : stacks_cur) { if (stack.empty()) { // An empty stack means that the grammar has been completed return true; @@ -63,9 +63,9 @@ static void test(const std::string & test_desc, const std::string & grammar_str, auto * grammar = build_grammar(grammar_str); // Save the original grammar stacks so that we can reset after every new string we want to test - const llama_grammar_stacks original_stacks = llama_grammar_get_stacks(grammar); + const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar); - llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar); + llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); fprintf(stderr, " 🔵 Valid strings:\n"); @@ -102,7 +102,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str, assert(matched); // Reset the grammar stacks - cur_stacks = original_stacks; + stacks_cur = stacks_org; } fprintf(stderr, " 🟠 Invalid strings:\n"); @@ -122,7 +122,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str, assert(!matched); // Reset the grammar stacks - cur_stacks = original_stacks; + stacks_cur = stacks_org; } // Clean up allocated memory