Zhenwei Jin, 2024-12-06 14:25:08 +08:00, committed by GitHub
commit 5309f15b35

@@ -1409,6 +1409,9 @@ struct llama_sampler_penalties {
     const bool ignore_eos;
 
     ring_buffer<llama_token> prev;
 
+    // Frequency map to count occurrences of each token in last_tokens
+    std::unordered_map<llama_token, size_t> token_count;
 };
 
 static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
@@ -1421,7 +1424,14 @@ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) {
         return;
     }
 
+    if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
+        assert(ctx->token_count.at(ctx->prev.front()) > 0);
+        if (--ctx->token_count[ctx->prev.front()] == 0) {
+            ctx->token_count.erase(ctx->prev.front());
+        }
+    }
     ctx->prev.push_back(token);
+    ctx->token_count[token]++;
 }
 
 static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
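Taken together, the two hunks above replace a per-call recount with a sliding-window frequency map kept in lockstep with the ring buffer: before the oldest token is evicted from the window, its count is decremented (and the entry erased when it hits zero), then the incoming token is counted. Erasing zero-count entries keeps the map bounded by the number of distinct tokens actually in the window. Below is a minimal, self-contained sketch of that technique, not the llama.cpp code itself: the name sliding_token_counter is hypothetical, and a std::deque stands in for the internal ring_buffer, which overwrites its oldest slot on push_back when full.

#include <cassert>
#include <cstdint>
#include <deque>
#include <unordered_map>

using llama_token = int32_t;

// Sliding-window token counter (hypothetical name) maintaining the same
// invariant as the patch: token_count always matches the multiset of tokens
// currently inside the window of the last `capacity` accepted tokens.
struct sliding_token_counter {
    size_t capacity;                    // plays the role of penalty_last_n
    std::deque<llama_token> window;     // stand-in for llama.cpp's ring_buffer
    std::unordered_map<llama_token, size_t> token_count;

    explicit sliding_token_counter(size_t cap) : capacity(cap) {}

    void accept(llama_token token) {
        if (capacity == 0) {
            return; // penalties disabled: keep no history
        }
        if (window.size() >= capacity) {
            // evict the oldest token: decrement its count, erase entries at zero
            const llama_token old = window.front();
            window.pop_front(); // a ring buffer would overwrite this slot on push_back instead
            assert(token_count.at(old) > 0);
            if (--token_count[old] == 0) {
                token_count.erase(old);
            }
        }
        window.push_back(token);
        token_count[token]++;
    }

    size_t count(llama_token token) const {
        const auto it = token_count.find(token);
        return it == token_count.end() ? 0 : it->second;
    }
};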
@@ -1473,23 +1483,14 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
         }
     }
 
-    // Create a frequency map to count occurrences of each token in last_tokens
-    // TODO: optimize this by maintaining the token count in the sampler context
-    using llama_token_cnt = std::unordered_map<llama_token, int>;
-    llama_token_cnt token_count;
-
-    for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
-        token_count[ctx->prev.rat(i)]++;
-    }
-
     // Apply frequency and presence penalties to the cur_p
     for (size_t i = 0; i < cur_p->size; ++i) {
-        const auto token_iter = token_count.find(cur_p->data[i].id);
-        if (token_iter == token_count.end()) {
+        const auto token_iter = ctx->token_count.find(cur_p->data[i].id);
+        if (token_iter == ctx->token_count.end()) {
             continue;
         }
 
-        const int count = token_iter->second;
+        const size_t count = token_iter->second;
 
         // The academic publication that described this technique only divided by the penalty, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
         // The common fix for this problem is to multiply by the penalty instead of dividing.
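The count looked up here feeds the penalty arithmetic that the two comments describe; the actual expressions sit just outside the lines shown in this hunk. The following is therefore a sketch of that logic rather than a quote of it, assuming llama.cpp-style names (logit, penalty_repeat, penalty_freq, penalty_present) and a simplified stand-in for llama_token_data.

#include <cstddef>
#include <cstdint>

// Simplified stand-in for the candidate entry used by the sampler.
struct llama_token_data_sketch { int32_t id; float logit; };

// Sketch of how a window count is turned into penalties.
static void penalize_sketch(llama_token_data_sketch & td, size_t count,
                            float penalty_repeat, float penalty_freq, float penalty_present) {
    // Repeat penalty: multiply instead of divide for negative logits, so
    // already-unlikely tokens do not become *more* likely (the fix the comment describes).
    if (td.logit <= 0.0f) {
        td.logit *= penalty_repeat;
    } else {
        td.logit /= penalty_repeat;
    }
    // Frequency penalty scales with how often the token appeared in the window;
    // presence penalty is a flat cost applied once the count is non-zero.
    td.logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
}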
@@ -1513,6 +1514,7 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
 static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
     auto * ctx = (llama_sampler_penalties *) smpl->ctx;
     ctx->prev.clear();
+    ctx->token_count.clear();
 }
 
 static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
@@ -1584,6 +1586,7 @@ struct llama_sampler * llama_sampler_init_penalties(
             /* .penalize_nl  = */ penalize_nl,
             /* .ignore_eos   = */ ignore_eos,
             /* .prev         = */ ring_buffer<llama_token>(penalty_last_n),
+            /* .token_count  = */ std::unordered_map<llama_token, size_t>(),
         },
     };
 }
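The payoff of the change: accept does O(1) amortized bookkeeping, and apply no longer rebuilds an O(penalty_last_n) map on every call, doing one hash lookup per candidate token instead, while reset and init keep the map in sync with prev as shown above. A quick check of the sliding_token_counter sketch from earlier, exercising eviction at the window boundary:

int main() {
    sliding_token_counter c(3); // window of the last 3 tokens, like penalty_last_n = 3
    for (llama_token t : {7, 8, 8, 9}) { // the fourth accept evicts the oldest token (7)
        c.accept(t);
    }
    assert(c.count(7) == 0); // evicted and erased from the map
    assert(c.count(8) == 2);
    assert(c.count(9) == 1);
    return 0;
}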