grammar : restore llama_grammar_accept signature

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-09-06 11:58:11 +03:00
parent 5b01cc8c8e
commit 8c972b69c1
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 28 additions and 26 deletions

View file

@ -12,24 +12,24 @@ static bool llama_grammar_validate(struct llama_grammar * grammar, const std::st
const auto cpts = unicode_cpts_from_utf8(input_str); const auto cpts = unicode_cpts_from_utf8(input_str);
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar); llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
size_t pos = 0; size_t pos = 0;
for (const auto & cpt : cpts) { for (const auto & cpt : cpts) {
const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
cur_stacks = llama_grammar_accept(rules, prev_stacks, cpt); llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
if (cur_stacks.empty()) { if (stacks_cur.empty()) {
error_pos = pos; error_pos = pos;
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'"; error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'";
cur_stacks = prev_stacks; stacks_cur = stacks_prev;
return false; return false;
} }
++pos; ++pos;
} }
for (const auto & stack : cur_stacks) { for (const auto & stack : stacks_cur) {
if (stack.empty()) { if (stack.empty()) {
return true; return true;
} }

View file

@ -822,12 +822,13 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
return grammar->stacks; return grammar->stacks;
} }
llama_grammar_stacks llama_grammar_accept( void llama_grammar_accept(
const llama_grammar_rules & rules, const llama_grammar_rules & rules,
const llama_grammar_stacks & stacks, const llama_grammar_stacks & stacks,
const uint32_t chr) { const uint32_t chr,
llama_grammar_stacks result; llama_grammar_stacks & stacks_new) {
result.reserve(stacks.size()); stacks_new.clear();
stacks_new.reserve(stacks.size());
for (const auto & stack : stacks) { for (const auto & stack : stacks) {
if (stack.empty()) { if (stack.empty()) {
@ -843,11 +844,9 @@ llama_grammar_stacks llama_grammar_accept(
if (!llama_grammar_is_end_of_sequence(pos)) { if (!llama_grammar_is_end_of_sequence(pos)) {
new_stack.push_back(pos); new_stack.push_back(pos);
} }
llama_grammar_advance_stack(rules, new_stack, result); llama_grammar_advance_stack(rules, new_stack, stacks_new);
} }
} }
return result;
} }
llama_grammar_candidates llama_grammar_reject_candidates_for_stack( llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
@ -1127,9 +1126,11 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
const auto decoded = decode_utf8(piece, grammar.partial_utf8); const auto decoded = decode_utf8(piece, grammar.partial_utf8);
const auto & code_points = decoded.first; const auto & code_points = decoded.first;
llama_grammar_stacks stacks_new;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
llama_grammar_stacks new_stacks = llama_grammar_accept(grammar.rules, grammar.stacks, *it); llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new);
grammar.stacks = std::move(new_stacks); grammar.stacks = std::move(stacks_new);
} }
grammar.partial_utf8 = decoded.second; grammar.partial_utf8 = decoded.second;

View file

@ -65,10 +65,11 @@ const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar
// be positioned at a character range (see `llama_grammar_advance_stack`), and // be positioned at a character range (see `llama_grammar_advance_stack`), and
// produces the N possible stacks if the given char is accepted at those // produces the N possible stacks if the given char is accepted at those
// positions // positions
llama_grammar_stacks llama_grammar_accept( void llama_grammar_accept(
const llama_grammar_rules & rules, const llama_grammar_rules & rules,
const llama_grammar_stacks & stacks, const llama_grammar_stacks & stacks,
uint32_t chr); uint32_t chr,
llama_grammar_stacks & stacks_new);
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack( std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
const llama_grammar_rules & rules, const llama_grammar_rules & rules,

View file

@ -33,20 +33,20 @@ static bool match_string(const std::string & input, llama_grammar * grammar) {
const auto cpts = unicode_cpts_from_utf8(input); const auto cpts = unicode_cpts_from_utf8(input);
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar); llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
for (const auto & cpt : cpts) { for (const auto & cpt : cpts) {
const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
cur_stacks = llama_grammar_accept(rules, prev_stacks, cpt); llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
if (cur_stacks.empty()) { if (stacks_cur.empty()) {
// no stacks means that the grammar failed to match at this point // no stacks means that the grammar failed to match at this point
return false; return false;
} }
} }
for (const auto & stack : cur_stacks) { for (const auto & stack : stacks_cur) {
if (stack.empty()) { if (stack.empty()) {
// An empty stack means that the grammar has been completed // An empty stack means that the grammar has been completed
return true; return true;
@ -63,9 +63,9 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
auto * grammar = build_grammar(grammar_str); auto * grammar = build_grammar(grammar_str);
// Save the original grammar stacks so that we can reset after every new string we want to test // Save the original grammar stacks so that we can reset after every new string we want to test
const llama_grammar_stacks original_stacks = llama_grammar_get_stacks(grammar); const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar);
llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar); llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
fprintf(stderr, " 🔵 Valid strings:\n"); fprintf(stderr, " 🔵 Valid strings:\n");
@ -102,7 +102,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
assert(matched); assert(matched);
// Reset the grammar stacks // Reset the grammar stacks
cur_stacks = original_stacks; stacks_cur = stacks_org;
} }
fprintf(stderr, " 🟠 Invalid strings:\n"); fprintf(stderr, " 🟠 Invalid strings:\n");
@ -122,7 +122,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
assert(!matched); assert(!matched);
// Reset the grammar stacks // Reset the grammar stacks
cur_stacks = original_stacks; stacks_cur = stacks_org;
} }
// Clean up allocated memory // Clean up allocated memory