From 8d503ef48223e372455abd319d09bd37a089914e Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 21 Apr 2024 15:52:25 +0100 Subject: [PATCH] grammars: faster llama_grammar_copy --- llama.cpp | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/llama.cpp b/llama.cpp index 704c5e24b..aaae00399 100644 --- a/llama.cpp +++ b/llama.cpp @@ -13061,16 +13061,22 @@ void llama_grammar_free(struct llama_grammar * grammar) { struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) { llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8, grammar->token_pieces, grammar->token_codepoints }; + std::unordered_map element_map; + element_map.reserve(std::accumulate( + grammar->rules.begin(), grammar->rules.end(), 0, + [](size_t acc, const std::vector & rule) { + return acc + rule.size(); + })); + for (size_t ir = 0; ir < grammar->rules.size(); ir++) { + for (size_t ie = 0; ie < grammar->rules[ir].size(); ie++) { + element_map[&grammar->rules[ir][ie]] = &result->rules[ir][ie]; + } + } + // redirect elements in stacks to point to new rules for (size_t is = 0; is < result->stacks.size(); is++) { for (size_t ie = 0; ie < result->stacks[is].size(); ie++) { - for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) { - for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) { - if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) { - result->stacks[is][ie] = &result->rules[ir0][ir1]; - } - } - } + result->stacks[is][ie] = element_map.at(grammar->stacks[is][ie]); } }