This commit is contained in:
Zhenwei Jin 2024-12-06 14:25:08 +08:00 committed by GitHub
commit 5309f15b35
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1409,6 +1409,9 @@ struct llama_sampler_penalties {
const bool ignore_eos; const bool ignore_eos;
ring_buffer<llama_token> prev; ring_buffer<llama_token> prev;
// Frequency map counting occurrences of each token currently held in prev
std::unordered_map<llama_token, size_t> token_count;
}; };
static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) { static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
@ -1421,7 +1424,14 @@ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_to
return; return;
} }
if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
assert(ctx->token_count.at(ctx->prev.front()) > 0);
if (--ctx->token_count[ctx->prev.front()] == 0) {
ctx->token_count.erase(ctx->prev.front());
}
}
ctx->prev.push_back(token); ctx->prev.push_back(token);
ctx->token_count[token]++;
} }
static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@ -1473,23 +1483,14 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok
} }
} }
// Create a frequency map to count occurrences of each token in last_tokens
// TODO: optimize this by maintaining the token count in the sampler context
using llama_token_cnt = std::unordered_map<llama_token, int>;
llama_token_cnt token_count;
for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
token_count[ctx->prev.rat(i)]++;
}
// Apply frequency and presence penalties to the cur_p // Apply frequency and presence penalties to the cur_p
for (size_t i = 0; i < cur_p->size; ++i) { for (size_t i = 0; i < cur_p->size; ++i) {
const auto token_iter = token_count.find(cur_p->data[i].id); const auto token_iter = ctx->token_count.find(cur_p->data[i].id);
if (token_iter == token_count.end()) { if (token_iter == ctx->token_count.end()) {
continue; continue;
} }
const int count = token_iter->second; const size_t count = token_iter->second;
// The academic publication that described this technique only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. // The academic publication that described this technique only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
// This is a common fix for this problem, which is to multiply by the penalty instead of dividing. // This is a common fix for this problem, which is to multiply by the penalty instead of dividing.
@ -1513,6 +1514,7 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok
static void llama_sampler_penalties_reset(struct llama_sampler * smpl) { static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_penalties *) smpl->ctx; auto * ctx = (llama_sampler_penalties *) smpl->ctx;
ctx->prev.clear(); ctx->prev.clear();
ctx->token_count.clear();
} }
static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) { static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
@ -1584,6 +1586,7 @@ struct llama_sampler * llama_sampler_init_penalties(
/* .penalize_nl = */ penalize_nl, /* .penalize_nl = */ penalize_nl,
/* .ignore_eos = */ ignore_eos, /* .ignore_eos = */ ignore_eos,
/* .prev = */ ring_buffer<llama_token>(penalty_last_n), /* .prev = */ ring_buffer<llama_token>(penalty_last_n),
/* .token_count = */ std::unordered_map<llama_token, size_t>(),
}, },
}; };
} }