grammar : restore llama_grammar_accept signature
ggml-ci
This commit is contained in:
parent
5b01cc8c8e
commit
8c972b69c1
4 changed files with 28 additions and 26 deletions
|
@ -12,24 +12,24 @@ static bool llama_grammar_validate(struct llama_grammar * grammar, const std::st
|
|||
const auto cpts = unicode_cpts_from_utf8(input_str);
|
||||
|
||||
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
|
||||
llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
|
||||
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
|
||||
|
||||
size_t pos = 0;
|
||||
for (const auto & cpt : cpts) {
|
||||
const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
|
||||
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
|
||||
|
||||
cur_stacks = llama_grammar_accept(rules, prev_stacks, cpt);
|
||||
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
|
||||
|
||||
if (cur_stacks.empty()) {
|
||||
if (stacks_cur.empty()) {
|
||||
error_pos = pos;
|
||||
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'";
|
||||
cur_stacks = prev_stacks;
|
||||
stacks_cur = stacks_prev;
|
||||
return false;
|
||||
}
|
||||
++pos;
|
||||
}
|
||||
|
||||
for (const auto & stack : cur_stacks) {
|
||||
for (const auto & stack : stacks_cur) {
|
||||
if (stack.empty()) {
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -822,12 +822,13 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
|
|||
return grammar->stacks;
|
||||
}
|
||||
|
||||
llama_grammar_stacks llama_grammar_accept(
|
||||
void llama_grammar_accept(
|
||||
const llama_grammar_rules & rules,
|
||||
const llama_grammar_stacks & stacks,
|
||||
const uint32_t chr) {
|
||||
llama_grammar_stacks result;
|
||||
result.reserve(stacks.size());
|
||||
const uint32_t chr,
|
||||
llama_grammar_stacks & stacks_new) {
|
||||
stacks_new.clear();
|
||||
stacks_new.reserve(stacks.size());
|
||||
|
||||
for (const auto & stack : stacks) {
|
||||
if (stack.empty()) {
|
||||
|
@ -843,11 +844,9 @@ llama_grammar_stacks llama_grammar_accept(
|
|||
if (!llama_grammar_is_end_of_sequence(pos)) {
|
||||
new_stack.push_back(pos);
|
||||
}
|
||||
llama_grammar_advance_stack(rules, new_stack, result);
|
||||
llama_grammar_advance_stack(rules, new_stack, stacks_new);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
||||
|
@ -1127,9 +1126,11 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
|||
const auto decoded = decode_utf8(piece, grammar.partial_utf8);
|
||||
const auto & code_points = decoded.first;
|
||||
|
||||
llama_grammar_stacks stacks_new;
|
||||
|
||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||
llama_grammar_stacks new_stacks = llama_grammar_accept(grammar.rules, grammar.stacks, *it);
|
||||
grammar.stacks = std::move(new_stacks);
|
||||
llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new);
|
||||
grammar.stacks = std::move(stacks_new);
|
||||
}
|
||||
|
||||
grammar.partial_utf8 = decoded.second;
|
||||
|
|
|
@ -65,10 +65,11 @@ const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar
|
|||
// be positioned at a character range (see `llama_grammar_advance_stack`), and
|
||||
// produces the N possible stacks if the given char is accepted at those
|
||||
// positions
|
||||
llama_grammar_stacks llama_grammar_accept(
|
||||
void llama_grammar_accept(
|
||||
const llama_grammar_rules & rules,
|
||||
const llama_grammar_stacks & stacks,
|
||||
uint32_t chr);
|
||||
uint32_t chr,
|
||||
llama_grammar_stacks & stacks_new);
|
||||
|
||||
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
|
||||
const llama_grammar_rules & rules,
|
||||
|
|
|
@ -33,20 +33,20 @@ static bool match_string(const std::string & input, llama_grammar * grammar) {
|
|||
const auto cpts = unicode_cpts_from_utf8(input);
|
||||
|
||||
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
|
||||
llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
|
||||
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
|
||||
|
||||
for (const auto & cpt : cpts) {
|
||||
const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
|
||||
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
|
||||
|
||||
cur_stacks = llama_grammar_accept(rules, prev_stacks, cpt);
|
||||
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
|
||||
|
||||
if (cur_stacks.empty()) {
|
||||
if (stacks_cur.empty()) {
|
||||
// no stacks means that the grammar failed to match at this point
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto & stack : cur_stacks) {
|
||||
for (const auto & stack : stacks_cur) {
|
||||
if (stack.empty()) {
|
||||
// An empty stack means that the grammar has been completed
|
||||
return true;
|
||||
|
@ -63,9 +63,9 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
|
|||
auto * grammar = build_grammar(grammar_str);
|
||||
|
||||
// Save the original grammar stacks so that we can reset after every new string we want to test
|
||||
const llama_grammar_stacks original_stacks = llama_grammar_get_stacks(grammar);
|
||||
const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar);
|
||||
|
||||
llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
|
||||
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
|
||||
|
||||
fprintf(stderr, " 🔵 Valid strings:\n");
|
||||
|
||||
|
@ -102,7 +102,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
|
|||
assert(matched);
|
||||
|
||||
// Reset the grammar stacks
|
||||
cur_stacks = original_stacks;
|
||||
stacks_cur = stacks_org;
|
||||
}
|
||||
|
||||
fprintf(stderr, " 🟠 Invalid strings:\n");
|
||||
|
@ -122,7 +122,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
|
|||
assert(!matched);
|
||||
|
||||
// Reset the grammar stacks
|
||||
cur_stacks = original_stacks;
|
||||
stacks_cur = stacks_org;
|
||||
}
|
||||
|
||||
// Clean up allocated memory
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue