grammars: reuse new_stacks

This commit is contained in:
Olivier Chafik 2024-04-11 15:11:40 +01:00
parent 3732ad9c22
commit 47e37dd955

View file

@ -11912,12 +11912,13 @@ static void llama_grammar_advance_stack(
// be positioned at a character range (see `llama_grammar_advance_stack`), and // be positioned at a character range (see `llama_grammar_advance_stack`), and
// produces the N possible stacks if the given char is accepted at those // produces the N possible stacks if the given char is accepted at those
// positions // positions
std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept( void llama_grammar_accept(
const std::vector<std::vector<llama_grammar_element>> & rules, const std::vector<std::vector<llama_grammar_element>> & rules,
const std::vector<std::vector<const llama_grammar_element *>> & stacks, const std::vector<std::vector<const llama_grammar_element *>> & stacks,
const uint32_t chr) { const uint32_t chr,
std::vector<std::vector<const llama_grammar_element *>> & new_stacks) {
std::vector<std::vector<const llama_grammar_element *>> new_stacks; new_stacks.clear();
for (const auto & stack : stacks) { for (const auto & stack : stacks) {
if (stack.empty()) { if (stack.empty()) {
@ -11936,8 +11937,6 @@ std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
llama_grammar_advance_stack(rules, new_stack, new_stacks); llama_grammar_advance_stack(rules, new_stack, new_stacks);
} }
} }
return new_stacks;
} }
static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates( static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
@ -12774,8 +12773,10 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
// Note terminating 0 in decoded string // Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece, grammar->partial_utf8); const auto decoded = decode_utf8(piece, grammar->partial_utf8);
const auto & code_points = decoded.first; const auto & code_points = decoded.first;
std::vector<std::vector<const llama_grammar_element *>> tmp_new_stacks;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it); llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
tmp_new_stacks.swap(grammar->stacks);
} }
grammar->partial_utf8 = decoded.second; grammar->partial_utf8 = decoded.second;
GGML_ASSERT(!grammar->stacks.empty()); GGML_ASSERT(!grammar->stacks.empty());