sample: maintain token count in penalty sampler context

parent a89f75e1b7
commit a5452db6dd

1 changed file with 15 additions and 12 deletions
@@ -1386,6 +1386,9 @@ struct llama_sampler_penalties {
     const bool ignore_eos;

     ring_buffer<llama_token> prev;
+
+    // Frequency map to count occurrences of each token in last_tokens
+    std::unordered_map<llama_token, size_t> token_count;
 };

 static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
@@ -1398,7 +1401,14 @@ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) {
         return;
     }

+    if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
+        assert(ctx->token_count.at(ctx->prev.front()) > 0);
+        if (--ctx->token_count[ctx->prev.front()] == 0) {
+            ctx->token_count.erase(ctx->prev.front());
+        }
+    }
     ctx->prev.push_back(token);
+    ctx->token_count[token]++;
 }

 static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
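For reference, a minimal standalone sketch of the same bookkeeping: a fixed-size window plus an incrementally maintained frequency map, using std::deque in place of llama.cpp's ring_buffer. All names here are illustrative, not the actual implementation.

#include <cstddef>
#include <cstdint>
#include <deque>
#include <unordered_map>

using token = int32_t;

struct window_counter {
    size_t capacity;                          // analogous to penalty_last_n
    std::deque<token> prev;                   // oldest token at the front
    std::unordered_map<token, size_t> count;  // occurrences inside the window

    void accept(token t) {
        if (capacity == 0) {
            return; // mirrors the penalty_last_n == 0 early-out in the diff
        }
        if (prev.size() >= capacity) {
            // the oldest token falls out of the window: decrement its count
            if (--count[prev.front()] == 0) {
                count.erase(prev.front()); // keep the map sparse
            }
            prev.pop_front();
        }
        prev.push_back(t);
        count[t]++;
    }

    size_t occurrences(token t) const {
        const auto it = count.find(t);
        return it == count.end() ? 0 : it->second;
    }
};

Each accept is O(1) amortized, so the window counts are always ready when the penalties are applied, instead of being rebuilt from the ring buffer on every apply call.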
@@ -1450,23 +1460,14 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
         }
     }

-    // Create a frequency map to count occurrences of each token in last_tokens
-    // TODO: optimize this by maintaining the token count in the sampler context
-    using llama_token_cnt = std::unordered_map<llama_token, int>;
-    llama_token_cnt token_count;
-
-    for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
-        token_count[ctx->prev.rat(i)]++;
-    }
-
     // Apply frequency and presence penalties to the cur_p
     for (size_t i = 0; i < cur_p->size; ++i) {
-        const auto token_iter = token_count.find(cur_p->data[i].id);
-        if (token_iter == token_count.end()) {
+        const auto token_iter = ctx->token_count.find(cur_p->data[i].id);
+        if (token_iter == ctx->token_count.end()) {
             continue;
         }

-        const int count = token_iter->second;
+        const size_t count = token_iter->second;

         // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
         // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
@@ -1490,6 +1491,7 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
 static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
     auto * ctx = (llama_sampler_penalties *) smpl->ctx;
     ctx->prev.clear();
+    ctx->token_count.clear();
 }

 static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
@@ -1561,6 +1563,7 @@ struct llama_sampler * llama_sampler_init_penalties(
             /* .penalize_nl = */ penalize_nl,
             /* .ignore_eos  = */ ignore_eos,
             /* .prev        = */ ring_buffer<llama_token>(penalty_last_n),
+            /* .token_count = */ std::unordered_map<llama_token, size_t>(),
         },
     };
 }
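The assert added to llama_sampler_penalties_accept encodes the invariant the whole change depends on: the incrementally maintained map always agrees with a from-scratch recount of the window. A throwaway check of that invariant, assuming nothing beyond the standard library:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <deque>
#include <unordered_map>

using token = int32_t;

int main() {
    const size_t last_n = 4; // stands in for penalty_last_n
    std::deque<token> prev;
    std::unordered_map<token, size_t> count;

    for (token t : {7, 7, 3, 9, 7, 3, 3, 1, 9, 7}) {
        if (prev.size() >= last_n) {
            assert(count.at(prev.front()) > 0); // same check as the diff
            if (--count[prev.front()] == 0) {
                count.erase(prev.front());
            }
            prev.pop_front();
        }
        prev.push_back(t);
        count[t]++;

        // recount the window from scratch and compare
        std::unordered_map<token, size_t> expect;
        for (token u : prev) {
            expect[u]++;
        }
        assert(expect == count);
    }

    return 0;
}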