grammars: keep the non-quadratic llama_grammar_copy optimization for later

This commit is contained in:
Olivier Chafik 2024-04-24 14:28:16 +01:00
parent 24769f9a80
commit 05efa34d92

View file

@@ -12818,22 +12818,16 @@ void llama_grammar_free(struct llama_grammar * grammar) {
struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) { struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) {
llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8, grammar->token_pieces, grammar->token_codepoints }; llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8, grammar->token_pieces, grammar->token_codepoints };
std::unordered_map<const llama_grammar_element *, const llama_grammar_element *> element_map;
element_map.reserve(std::accumulate(
grammar->rules.begin(), grammar->rules.end(), 0,
[](size_t acc, const std::vector<llama_grammar_element> & rule) {
return acc + rule.size();
}));
for (size_t ir = 0; ir < grammar->rules.size(); ir++) {
for (size_t ie = 0; ie < grammar->rules[ir].size(); ie++) {
element_map[&grammar->rules[ir][ie]] = &result->rules[ir][ie];
}
}
// redirect elements in stacks to point to new rules // redirect elements in stacks to point to new rules
for (size_t is = 0; is < result->stacks.size(); is++) { for (size_t is = 0; is < result->stacks.size(); is++) {
for (size_t ie = 0; ie < result->stacks[is].size(); ie++) { for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
result->stacks[is][ie] = element_map.at(grammar->stacks[is][ie]); for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) {
for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) {
if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) {
result->stacks[is][ie] = &result->rules[ir0][ir1];
}
}
}
} }
} }