grammar : restore llama_grammar_accept signature
ggml-ci
This commit is contained in:
parent
5b01cc8c8e
commit
8c972b69c1
4 changed files with 28 additions and 26 deletions
|
@ -12,24 +12,24 @@ static bool llama_grammar_validate(struct llama_grammar * grammar, const std::st
|
||||||
const auto cpts = unicode_cpts_from_utf8(input_str);
|
const auto cpts = unicode_cpts_from_utf8(input_str);
|
||||||
|
|
||||||
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
|
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
|
||||||
llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
|
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
|
||||||
|
|
||||||
size_t pos = 0;
|
size_t pos = 0;
|
||||||
for (const auto & cpt : cpts) {
|
for (const auto & cpt : cpts) {
|
||||||
const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
|
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
|
||||||
|
|
||||||
cur_stacks = llama_grammar_accept(rules, prev_stacks, cpt);
|
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
|
||||||
|
|
||||||
if (cur_stacks.empty()) {
|
if (stacks_cur.empty()) {
|
||||||
error_pos = pos;
|
error_pos = pos;
|
||||||
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'";
|
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'";
|
||||||
cur_stacks = prev_stacks;
|
stacks_cur = stacks_prev;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
++pos;
|
++pos;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const auto & stack : cur_stacks) {
|
for (const auto & stack : stacks_cur) {
|
||||||
if (stack.empty()) {
|
if (stack.empty()) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -822,12 +822,13 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
|
||||||
return grammar->stacks;
|
return grammar->stacks;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_grammar_stacks llama_grammar_accept(
|
void llama_grammar_accept(
|
||||||
const llama_grammar_rules & rules,
|
const llama_grammar_rules & rules,
|
||||||
const llama_grammar_stacks & stacks,
|
const llama_grammar_stacks & stacks,
|
||||||
const uint32_t chr) {
|
const uint32_t chr,
|
||||||
llama_grammar_stacks result;
|
llama_grammar_stacks & stacks_new) {
|
||||||
result.reserve(stacks.size());
|
stacks_new.clear();
|
||||||
|
stacks_new.reserve(stacks.size());
|
||||||
|
|
||||||
for (const auto & stack : stacks) {
|
for (const auto & stack : stacks) {
|
||||||
if (stack.empty()) {
|
if (stack.empty()) {
|
||||||
|
@ -843,11 +844,9 @@ llama_grammar_stacks llama_grammar_accept(
|
||||||
if (!llama_grammar_is_end_of_sequence(pos)) {
|
if (!llama_grammar_is_end_of_sequence(pos)) {
|
||||||
new_stack.push_back(pos);
|
new_stack.push_back(pos);
|
||||||
}
|
}
|
||||||
llama_grammar_advance_stack(rules, new_stack, result);
|
llama_grammar_advance_stack(rules, new_stack, stacks_new);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
||||||
|
@ -1127,9 +1126,11 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
||||||
const auto decoded = decode_utf8(piece, grammar.partial_utf8);
|
const auto decoded = decode_utf8(piece, grammar.partial_utf8);
|
||||||
const auto & code_points = decoded.first;
|
const auto & code_points = decoded.first;
|
||||||
|
|
||||||
|
llama_grammar_stacks stacks_new;
|
||||||
|
|
||||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||||
llama_grammar_stacks new_stacks = llama_grammar_accept(grammar.rules, grammar.stacks, *it);
|
llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new);
|
||||||
grammar.stacks = std::move(new_stacks);
|
grammar.stacks = std::move(stacks_new);
|
||||||
}
|
}
|
||||||
|
|
||||||
grammar.partial_utf8 = decoded.second;
|
grammar.partial_utf8 = decoded.second;
|
||||||
|
|
|
@ -65,10 +65,11 @@ const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar
|
||||||
// be positioned at a character range (see `llama_grammar_advance_stack`), and
|
// be positioned at a character range (see `llama_grammar_advance_stack`), and
|
||||||
// produces the N possible stacks if the given char is accepted at those
|
// produces the N possible stacks if the given char is accepted at those
|
||||||
// positions
|
// positions
|
||||||
llama_grammar_stacks llama_grammar_accept(
|
void llama_grammar_accept(
|
||||||
const llama_grammar_rules & rules,
|
const llama_grammar_rules & rules,
|
||||||
const llama_grammar_stacks & stacks,
|
const llama_grammar_stacks & stacks,
|
||||||
uint32_t chr);
|
uint32_t chr,
|
||||||
|
llama_grammar_stacks & stacks_new);
|
||||||
|
|
||||||
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
|
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
|
||||||
const llama_grammar_rules & rules,
|
const llama_grammar_rules & rules,
|
||||||
|
|
|
@ -33,20 +33,20 @@ static bool match_string(const std::string & input, llama_grammar * grammar) {
|
||||||
const auto cpts = unicode_cpts_from_utf8(input);
|
const auto cpts = unicode_cpts_from_utf8(input);
|
||||||
|
|
||||||
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
|
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
|
||||||
llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
|
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
|
||||||
|
|
||||||
for (const auto & cpt : cpts) {
|
for (const auto & cpt : cpts) {
|
||||||
const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
|
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
|
||||||
|
|
||||||
cur_stacks = llama_grammar_accept(rules, prev_stacks, cpt);
|
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
|
||||||
|
|
||||||
if (cur_stacks.empty()) {
|
if (stacks_cur.empty()) {
|
||||||
// no stacks means that the grammar failed to match at this point
|
// no stacks means that the grammar failed to match at this point
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const auto & stack : cur_stacks) {
|
for (const auto & stack : stacks_cur) {
|
||||||
if (stack.empty()) {
|
if (stack.empty()) {
|
||||||
// An empty stack means that the grammar has been completed
|
// An empty stack means that the grammar has been completed
|
||||||
return true;
|
return true;
|
||||||
|
@ -63,9 +63,9 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
|
||||||
auto * grammar = build_grammar(grammar_str);
|
auto * grammar = build_grammar(grammar_str);
|
||||||
|
|
||||||
// Save the original grammar stacks so that we can reset after every new string we want to test
|
// Save the original grammar stacks so that we can reset after every new string we want to test
|
||||||
const llama_grammar_stacks original_stacks = llama_grammar_get_stacks(grammar);
|
const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar);
|
||||||
|
|
||||||
llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
|
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
|
||||||
|
|
||||||
fprintf(stderr, " 🔵 Valid strings:\n");
|
fprintf(stderr, " 🔵 Valid strings:\n");
|
||||||
|
|
||||||
|
@ -102,7 +102,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
|
||||||
assert(matched);
|
assert(matched);
|
||||||
|
|
||||||
// Reset the grammar stacks
|
// Reset the grammar stacks
|
||||||
cur_stacks = original_stacks;
|
stacks_cur = stacks_org;
|
||||||
}
|
}
|
||||||
|
|
||||||
fprintf(stderr, " 🟠 Invalid strings:\n");
|
fprintf(stderr, " 🟠 Invalid strings:\n");
|
||||||
|
@ -122,7 +122,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
|
||||||
assert(!matched);
|
assert(!matched);
|
||||||
|
|
||||||
// Reset the grammar stacks
|
// Reset the grammar stacks
|
||||||
cur_stacks = original_stacks;
|
stacks_cur = stacks_org;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clean up allocated memory
|
// Clean up allocated memory
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue