grammar : restore llama_grammar_accept signature

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-09-06 11:58:11 +03:00
parent 5b01cc8c8e
commit 8c972b69c1
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 28 additions and 26 deletions

View file

@ -12,24 +12,24 @@ static bool llama_grammar_validate(struct llama_grammar * grammar, const std::st
const auto cpts = unicode_cpts_from_utf8(input_str);
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
size_t pos = 0;
for (const auto & cpt : cpts) {
const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
cur_stacks = llama_grammar_accept(rules, prev_stacks, cpt);
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
if (cur_stacks.empty()) {
if (stacks_cur.empty()) {
error_pos = pos;
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'";
cur_stacks = prev_stacks;
stacks_cur = stacks_prev;
return false;
}
++pos;
}
for (const auto & stack : cur_stacks) {
for (const auto & stack : stacks_cur) {
if (stack.empty()) {
return true;
}

View file

@ -822,12 +822,13 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
return grammar->stacks;
}
llama_grammar_stacks llama_grammar_accept(
void llama_grammar_accept(
const llama_grammar_rules & rules,
const llama_grammar_stacks & stacks,
const uint32_t chr) {
llama_grammar_stacks result;
result.reserve(stacks.size());
const uint32_t chr,
llama_grammar_stacks & stacks_new) {
stacks_new.clear();
stacks_new.reserve(stacks.size());
for (const auto & stack : stacks) {
if (stack.empty()) {
@ -843,11 +844,9 @@ llama_grammar_stacks llama_grammar_accept(
if (!llama_grammar_is_end_of_sequence(pos)) {
new_stack.push_back(pos);
}
llama_grammar_advance_stack(rules, new_stack, result);
llama_grammar_advance_stack(rules, new_stack, stacks_new);
}
}
return result;
}
llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
@ -1127,9 +1126,11 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
const auto decoded = decode_utf8(piece, grammar.partial_utf8);
const auto & code_points = decoded.first;
llama_grammar_stacks stacks_new;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
llama_grammar_stacks new_stacks = llama_grammar_accept(grammar.rules, grammar.stacks, *it);
grammar.stacks = std::move(new_stacks);
llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new);
grammar.stacks = std::move(stacks_new);
}
grammar.partial_utf8 = decoded.second;

View file

@ -65,10 +65,11 @@ const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar
// be positioned at a character range (see `llama_grammar_advance_stack`), and
// produces the N possible stacks if the given char is accepted at those
// positions
llama_grammar_stacks llama_grammar_accept(
void llama_grammar_accept(
const llama_grammar_rules & rules,
const llama_grammar_stacks & stacks,
uint32_t chr);
uint32_t chr,
llama_grammar_stacks & stacks_new);
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
const llama_grammar_rules & rules,

View file

@ -33,20 +33,20 @@ static bool match_string(const std::string & input, llama_grammar * grammar) {
const auto cpts = unicode_cpts_from_utf8(input);
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
for (const auto & cpt : cpts) {
const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
cur_stacks = llama_grammar_accept(rules, prev_stacks, cpt);
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
if (cur_stacks.empty()) {
if (stacks_cur.empty()) {
// no stacks means that the grammar failed to match at this point
return false;
}
}
for (const auto & stack : cur_stacks) {
for (const auto & stack : stacks_cur) {
if (stack.empty()) {
// An empty stack means that the grammar has been completed
return true;
@ -63,9 +63,9 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
auto * grammar = build_grammar(grammar_str);
// Save the original grammar stacks so that we can reset after every new string we want to test
const llama_grammar_stacks original_stacks = llama_grammar_get_stacks(grammar);
const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar);
llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
fprintf(stderr, " 🔵 Valid strings:\n");
@ -102,7 +102,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
assert(matched);
// Reset the grammar stacks
cur_stacks = original_stacks;
stacks_cur = stacks_org;
}
fprintf(stderr, " 🟠 Invalid strings:\n");
@ -122,7 +122,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
assert(!matched);
// Reset the grammar stacks
cur_stacks = original_stacks;
stacks_cur = stacks_org;
}
// Clean up allocated memory