diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 8bbe751e0..bebff77cf 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1421,15 +1421,25 @@ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_to
 
     // if the ring buffer is full, remove the oldest token
     if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
-        const auto pop = ctx->prev.front();
+        const auto old = ctx->prev.front();
 
-        ctx->token_count[pop]--;
-        if (ctx->token_count[pop] == 0) {
-            ctx->token_count.erase(pop);
+        ctx->token_count[old]--;
+        if (ctx->token_count[old] == 0) {
+            ctx->token_count.erase(old);
         }
     }
 
     ctx->prev.push_back(token);
+
+#if 0
+    // sanity check
+    std::unordered_map<llama_token, int> tmp;
+    for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
+        tmp[ctx->prev.rat(i)]++;
+    }
+
+    assert(ctx->token_count == tmp);
+#endif
 }
 
 static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -1449,7 +1459,7 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok
 
         const int count = token_iter->second;
 
-        assert(count > 0);
+        assert(count > 0 && count <= ctx->penalty_last_n);
 
         // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
         // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
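
For reference, a minimal standalone sketch (not the llama.cpp implementation; `TokenWindow` and `accept` are illustrative names, and a `std::deque` stands in for the ring buffer) of the bookkeeping that the `#if 0` sanity check verifies: the incrementally maintained `token_count` map must always equal the counts rebuilt from the sliding window of the last `penalty_last_n` tokens.

```cpp
// Minimal sketch (not llama.cpp code): keep a count map in lockstep with a
// sliding window of the last N tokens, mirroring the invariant that the
// diff's #if 0 sanity check asserts. Names here are illustrative only.
#include <cassert>
#include <cstdint>
#include <deque>
#include <unordered_map>

using llama_token = int32_t;

struct TokenWindow {
    size_t capacity;                                  // plays the role of penalty_last_n
    std::deque<llama_token> prev;                     // stand-in for the ring buffer
    std::unordered_map<llama_token, int> token_count; // counts of tokens inside the window

    void accept(llama_token token) {
        // if the window is full, drop the oldest token and decrement its count
        if (prev.size() >= capacity) {
            const llama_token old = prev.front();
            prev.pop_front();
            if (--token_count[old] == 0) {
                token_count.erase(old);
            }
        }

        prev.push_back(token);
        token_count[token]++;

        // the invariant checked by the sanity block: counts rebuilt from the
        // window must equal the incrementally maintained map
        std::unordered_map<llama_token, int> tmp;
        for (const llama_token t : prev) {
            tmp[t]++;
        }
        assert(token_count == tmp);
    }
};

int main() {
    TokenWindow w{/*capacity=*/3};
    for (llama_token t : {1, 2, 2, 3, 1, 1}) {
        w.accept(t);
    }
    // window now holds {3, 1, 1}; every stored count is > 0 and <= capacity
    assert(w.token_count.at(1) == 2 && w.token_count.at(3) == 1);
    return 0;
}
```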
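
And a small sketch of the penalty rule described in the comment of the second hunk (the `apply_repeat_penalty` helper is hypothetical, with `penalty > 1` assumed): dividing a negative logit by the penalty would pull it toward zero and make the token more likely, so negative logits are multiplied instead.

```cpp
// Minimal sketch (not llama.cpp code) of the fix the comment describes:
// with penalty > 1, dividing a negative logit would move it toward zero and
// make the token MORE likely, so negative logits are multiplied instead.
#include <cstdio>

// hypothetical helper, illustration only
static float apply_repeat_penalty(float logit, float penalty) {
    return logit <= 0.0f ? logit * penalty   // push negative logits further down
                         : logit / penalty;  // pull positive logits toward zero
}

int main() {
    const float penalty = 1.5f;
    std::printf("%.2f -> %.2f\n",  2.0f, apply_repeat_penalty( 2.0f, penalty)); //  2.00 ->  1.33
    std::printf("%.2f -> %.2f\n", -2.0f, apply_repeat_penalty(-2.0f, penalty)); // -2.00 -> -3.00
    return 0;
}
```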