sample: maintain token count in penalty sampler context
parent a89f75e1b7
commit a5452db6dd

1 changed file with 15 additions and 12 deletions
@@ -1386,6 +1386,9 @@ struct llama_sampler_penalties {
     const bool ignore_eos;
 
     ring_buffer<llama_token> prev;
+
+    // Frequency map to count occurrences of each token in last_tokens
+    std::unordered_map<llama_token, size_t> token_count;
 };
 
 static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
@@ -1398,7 +1401,14 @@ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) {
         return;
     }
 
+    if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
+        assert(ctx->token_count.at(ctx->prev.front()) > 0);
+        if (--ctx->token_count[ctx->prev.front()] == 0) {
+            ctx->token_count.erase(ctx->prev.front());
+        }
+    }
     ctx->prev.push_back(token);
+    ctx->token_count[token]++;
 }
 
 static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
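The accept path above is the core of the change: rather than rebuilding a histogram of the last penalty_last_n tokens on every apply call, the sampler updates the count incrementally as each token arrives. Note that llama.cpp's ring_buffer overwrites its oldest element when a full buffer is pushed, which is why the hunk decrements the front token's count before push_back instead of popping. A minimal standalone sketch of the same bookkeeping, using std::deque in place of ring_buffer; the penalty_window type and its members are illustrative names, not the library's API:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <deque>
#include <unordered_map>

using token_id = int32_t; // stand-in for llama_token

struct penalty_window {
    size_t last_n; // window size (penalty_last_n)

    std::deque<token_id>                 prev;        // sliding window of recent tokens
    std::unordered_map<token_id, size_t> token_count; // histogram of the tokens in prev

    void accept(token_id token) {
        // When the window is full, evict the oldest token and decrement its
        // count, erasing the entry once it reaches zero so the map only ever
        // holds tokens actually present in the window.
        if (prev.size() >= last_n) {
            const token_id oldest = prev.front();
            assert(token_count.at(oldest) > 0);
            if (--token_count[oldest] == 0) {
                token_count.erase(oldest);
            }
            prev.pop_front();
        }
        prev.push_back(token);
        token_count[token]++;
    }
};

Each accepted token now costs O(1) amortized map work, where the old code paid an O(penalty_last_n) rebuild on every apply.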
@@ -1450,23 +1460,14 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
         }
     }
 
-    // Create a frequency map to count occurrences of each token in last_tokens
-    // TODO: optimize this by maintaining the token count in the sampler context
-    using llama_token_cnt = std::unordered_map<llama_token, int>;
-    llama_token_cnt token_count;
-
-    for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
-        token_count[ctx->prev.rat(i)]++;
-    }
-
     // Apply frequency and presence penalties to the cur_p
     for (size_t i = 0; i < cur_p->size; ++i) {
-        const auto token_iter = token_count.find(cur_p->data[i].id);
-        if (token_iter == token_count.end()) {
+        const auto token_iter = ctx->token_count.find(cur_p->data[i].id);
+        if (token_iter == ctx->token_count.end()) {
             continue;
         }
 
-        const int count = token_iter->second;
+        const size_t count = token_iter->second;
 
         // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
         // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
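The hunk above deletes the per-call rebuild (resolving the TODO it carried) and reads ctx->token_count directly. The arithmetic that consumes count is elided between the context lines, so the following is only a sketch of the standard formulation the surviving comments describe: multiply a non-positive logit by the repeat penalty instead of dividing, then subtract OpenAI-style frequency and presence terms. The helper name and signature are hypothetical:

#include <cstddef>

// Hypothetical helper: apply repetition, frequency, and presence penalties to
// one logit, given how often its token occurred in the recent window.
inline float penalize_logit(float logit, size_t count,
                            float penalty_repeat, float penalty_freq, float penalty_present) {
    // Dividing a negative logit by penalty_repeat would make the token *more*
    // likely; the common fix is to multiply in that case.
    if (logit <= 0.0f) {
        logit *= penalty_repeat;
    } else {
        logit /= penalty_repeat;
    }
    // Frequency penalty scales with the occurrence count; presence penalty is
    // a flat charge for having occurred at all.
    logit -= float(count) * penalty_freq + (count > 0 ? penalty_present : 0.0f);
    return logit;
}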
@@ -1490,6 +1491,7 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
 static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
     auto * ctx = (llama_sampler_penalties *) smpl->ctx;
     ctx->prev.clear();
+    ctx->token_count.clear();
 }
 
 static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
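Clearing the map alongside the ring buffer in reset() preserves the invariant the accept path relies on: token_count is always an exact histogram of the contents of prev. A hypothetical debug helper, reusing the penalty_window sketch from above, that asserts this:

#include <cassert>

inline void check_histogram_invariant(const penalty_window & w) {
    std::unordered_map<token_id, size_t> expected;
    for (const token_id t : w.prev) {
        expected[t]++;
    }
    // Equality also guarantees there are no stale zero-count entries.
    assert(expected == w.token_count);
}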
@@ -1561,6 +1563,7 @@ struct llama_sampler * llama_sampler_init_penalties(
            /* .penalize_nl = */ penalize_nl,
            /* .ignore_eos  = */ ignore_eos,
            /* .prev        = */ ring_buffer<llama_token>(penalty_last_n),
+           /* .token_count = */ std::unordered_map<llama_token, size_t>(),
        },
    };
}