From a5452db6dd3c7af14c42a2ed797359062e9aec2d Mon Sep 17 00:00:00 2001
From: zhenweijin
Date: Thu, 17 Oct 2024 16:19:57 +0800
Subject: [PATCH] sample: maintain token count in penalty sampler context

---
 src/llama-sampling.cpp | 27 +++++++++++++++------------
 1 file changed, 15 insertions(+), 12 deletions(-)

diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index e255a8fc4..da2feb2cc 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1386,6 +1386,9 @@ struct llama_sampler_penalties {
     const bool    ignore_eos;
 
     ring_buffer<llama_token> prev;
+
+    // Frequency map to count occurrences of each token in last_tokens
+    std::unordered_map<llama_token, int> token_count;
 };
 
 static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
@@ -1398,7 +1401,14 @@ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_to
         return;
     }
 
+    if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
+        assert(ctx->token_count.at(ctx->prev.front()) > 0);
+        if (--ctx->token_count[ctx->prev.front()] == 0) {
+            ctx->token_count.erase(ctx->prev.front());
+        }
+    }
     ctx->prev.push_back(token);
+    ctx->token_count[token]++;
 }
 
 static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -1450,23 +1460,14 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok
         }
     }
 
-    // Create a frequency map to count occurrences of each token in last_tokens
-    // TODO: optimize this by maintaining the token count in the sampler context
-    using llama_token_cnt = std::unordered_map<llama_token, int>;
-    llama_token_cnt token_count;
-
-    for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
-        token_count[ctx->prev.rat(i)]++;
-    }
-
     // Apply frequency and presence penalties to the cur_p
     for (size_t i = 0; i < cur_p->size; ++i) {
-        const auto token_iter = token_count.find(cur_p->data[i].id);
-        if (token_iter == token_count.end()) {
+        const auto token_iter = ctx->token_count.find(cur_p->data[i].id);
+        if (token_iter == ctx->token_count.end()) {
             continue;
         }
 
-        const int count = token_iter->second;
+        const size_t count = token_iter->second;
 
         // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
         // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
@@ -1490,6 +1491,7 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok
 static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
     auto * ctx = (llama_sampler_penalties *) smpl->ctx;
     ctx->prev.clear();
+    ctx->token_count.clear();
 }
 
 static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
@@ -1561,6 +1563,7 @@ struct llama_sampler * llama_sampler_init_penalties(
             /* .penalize_nl  = */ penalize_nl,
             /* .ignore_eos   = */ ignore_eos,
             /* .prev         = */ ring_buffer<llama_token>(penalty_last_n),
+            /* .token_count  = */ std::unordered_map<llama_token, int>(),
         },
     };
 }
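
Note (illustration, not part of the patch): the change moves the frequency map from a per-call rebuild in llama_sampler_penalties_apply (O(penalty_last_n) work on every sample) to incremental maintenance in llama_sampler_penalties_accept. When the ring buffer is full, the count of the token about to be evicted is decremented (and erased at zero) before the new token is pushed and counted. Below is a minimal self-contained sketch of the same sliding-window counting scheme; token_window and the use of std::deque in place of llama.cpp's internal ring_buffer<llama_token> are illustrative stand-ins, not code from the patch.

#include <cassert>
#include <cstdio>
#include <deque>
#include <initializer_list>
#include <unordered_map>

// Sliding window of the last `last_n` accepted tokens plus a frequency map
// kept in sync with it, mirroring the patch's bookkeeping.
struct token_window {
    size_t last_n;
    std::deque<int> prev;                // most recent `last_n` tokens
    std::unordered_map<int, int> count;  // token -> occurrences within `prev`

    explicit token_window(size_t n) : last_n(n) {}

    void accept(int token) {
        // evict the oldest token's count before it falls out of the window
        if (prev.size() >= last_n) {
            const int old = prev.front();
            assert(count.at(old) > 0);
            if (--count[old] == 0) {
                count.erase(old);
            }
            prev.pop_front();
        }
        prev.push_back(token);
        count[token]++;
    }
};

int main() {
    token_window w(3);
    for (int t : {5, 5, 7, 5, 9}) {
        w.accept(t);
    }
    // window is now {7, 5, 9}: token 5 occurs once, the first two 5s were evicted
    printf("count(5) = %d\n", w.count.at(5)); // prints 1
    printf("count(7) = %d\n", w.count.at(7)); // prints 1
}

Each accept is amortized O(1), so the cost of the old recount loop no longer scales with penalty_last_n on every sampling step.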
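For context, the maintained counts feed the penalty arithmetic that the surviving context comments describe: the repeat penalty multiplies non-positive logits and divides positive ones (the multiply-instead-of-divide fix the comment mentions), while frequency and presence penalties are subtracted. A hedged sketch of that arithmetic under those assumptions; the function and parameter names are illustrative, not taken from llama.cpp.

#include <cstdio>

// Sketch of the penalty formula the counts are consumed by; names are
// illustrative, not the patch's code.
float penalize_logit(float logit, int count, float repeat, float freq, float present) {
    if (count == 0) {
        return logit; // token absent from the window: no penalty
    }
    // multiply-instead-of-divide fix: dividing a negative logit by the
    // penalty would make the token *more* likely, so multiply instead
    logit = (logit <= 0.0f) ? logit * repeat : logit / repeat;
    // frequency penalty scales with the count; presence penalty is flat
    return logit - (float) count * freq - present;
}

int main() {
    // a token seen twice in the window, with typical penalty settings
    printf("%.3f\n", penalize_logit(2.0f, 2, 1.1f, 0.1f, 0.2f)); // ~1.418
}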