sampling : refactor + optimize penalties sampler (#10803)
* sampling : refactor + optimize penalties sampler (ggml-ci)
* common : apply ignore_eos as logit bias (ggml-ci)
* batched : remove penalties sampler
* params : allow penalty_last_n == -1 to be equal to context size (ggml-ci)
* common : by default, move the penalties at the end of the sampling chain (ggml-ci)
* common : ignore all EOG tokens
* common : move back the penalties at the front of the sampling chain (ggml-ci)
* readme : restore hint about --ignore-eos flag [no ci]
* llama : minor (ggml-ci)
* webui : update

Co-authored-by: Diego Devesa <slarengh@gmail.com>
This commit is contained in:
parent 4ddd199f6f
commit 644fd71b44
17 changed files with 111 additions and 152 deletions
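One of the changes here is that ignore_eos moves out of the penalties sampler and into common as a plain logit bias on the end-of-generation tokens. Below is a minimal sketch of that idea using the public llama_sampler_init_logit_bias, llama_n_vocab and llama_token_is_eog functions from llama.h; the helper name and the setup around it are illustrative, not the actual common code:

// Sketch: suppress all end-of-generation tokens via a logit-bias sampler.
// make_ignore_eos_bias is a hypothetical helper.
#include "llama.h"

#include <cmath>
#include <vector>

static struct llama_sampler * make_ignore_eos_bias(const struct llama_model * model) {
    const int32_t n_vocab = llama_n_vocab(model);

    std::vector<llama_logit_bias> biases;
    for (llama_token tok = 0; tok < n_vocab; ++tok) {
        if (llama_token_is_eog(model, tok)) {
            biases.push_back({ tok, -INFINITY }); // EOG tokens can never be sampled
        }
    }

    return llama_sampler_init_logit_bias(n_vocab, (int32_t) biases.size(), biases.data());
}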
@@ -1396,19 +1396,15 @@ struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab
 // penalties

 struct llama_sampler_penalties {
-    const int32_t     n_vocab;
-    const llama_token special_eos_id;
-    const llama_token linefeed_id;
-
     const int32_t penalty_last_n;
     const float   penalty_repeat;
     const float   penalty_freq;
     const float   penalty_present;

-    const bool penalize_nl;
-    const bool ignore_eos;
-
     ring_buffer<llama_token> prev;
+
+    // a frequency map to count token occurrences
+    std::unordered_map<llama_token, int> token_count;
 };

 static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
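The new token_count member is the core of the optimization: rather than rebuilding a frequency map of the last penalty_last_n tokens on every apply call, the sampler keeps the map up to date as tokens are accepted. Here is a self-contained sketch of the same bookkeeping, with a std::deque standing in for the library's ring_buffer and all names illustrative:

// Sketch: a sliding-window token frequency map with O(1) updates per token.
#include <cstddef>
#include <cstdint>
#include <deque>
#include <unordered_map>

using token_t = int32_t;

struct penalty_window {
    size_t capacity; // plays the role of penalty_last_n

    std::deque<token_t> window;             // stand-in for ring_buffer<llama_token>
    std::unordered_map<token_t, int> count; // occurrences of each token in the window

    void accept(token_t tok) {
        if (capacity == 0) {
            return; // penalties disabled
        }

        count[tok]++;

        // window full: evict the oldest token, forget it once its count hits zero
        if (window.size() >= capacity) {
            const token_t old = window.front();
            window.pop_front();
            if (--count[old] == 0) {
                count.erase(old);
            }
        }

        window.push_back(tok);
    }
};

With this in place, apply only needs one hash lookup per candidate instead of an O(penalty_last_n) rebuild of the map.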
@@ -1421,76 +1417,50 @@ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_to
         return;
     }

+    ctx->token_count[token]++;
+
+    // if the ring buffer is full, remove the oldest token
+    if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
+        const auto old = ctx->prev.front();
+
+        ctx->token_count[old]--;
+        if (ctx->token_count[old] == 0) {
+            ctx->token_count.erase(old);
+        }
+    }
+
     ctx->prev.push_back(token);
+
+#if 0
+    // sanity check
+    std::unordered_map<llama_token, int> tmp;
+    for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
+        tmp[ctx->prev.rat(i)]++;
+    }
+
+    assert(ctx->token_count == tmp);
+#endif
 }

 static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_penalties *) smpl->ctx;

-    if (ctx->ignore_eos) {
-        assert(ctx->special_eos_id >= 0);
-
-        // optimistically check if the candidates are not yet sorted/shuffled/truncated
-        if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
-            cur_p->data[ctx->special_eos_id].logit = -INFINITY;
-        } else {
-            // else, search for the special EOS token
-            for (size_t i = 0; i < cur_p->size; ++i) {
-                if (cur_p->data[i].id == ctx->special_eos_id) {
-                    cur_p->data[i].logit = -INFINITY;
-                    break;
-                }
-            }
-        }
-    }
-
     if ((ctx->penalty_last_n == 0) ||
         (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
         return;
     }

-    bool nl_found = false;
-    size_t nl_idx = 0;
-    float nl_logit = -INFINITY;
-    if (!ctx->penalize_nl) {
-        assert(ctx->linefeed_id >= 0);
-
-        // optimistically check if the candidates are not yet sorted/shuffled/truncated
-        if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
-            nl_found = true;
-            nl_idx = ctx->linefeed_id;
-            nl_logit = cur_p->data[ctx->linefeed_id].logit;
-        } else {
-            // else, search for the linefeed token
-            for (size_t i = 0; i < cur_p->size; ++i) {
-                if (cur_p->data[i].id == ctx->linefeed_id) {
-                    nl_found = true;
-                    nl_idx = i;
-                    nl_logit = cur_p->data[i].logit;
-                    break;
-                }
-            }
-        }
-    }
-
-    // Create a frequency map to count occurrences of each token in last_tokens
-    // TODO: optimize this by maintaining the token count in the sampler context
-    using llama_token_cnt = std::unordered_map<llama_token, int>;
-    llama_token_cnt token_count;
-
-    for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
-        token_count[ctx->prev.rat(i)]++;
-    }
-
     // Apply frequency and presence penalties to the cur_p
     for (size_t i = 0; i < cur_p->size; ++i) {
-        const auto token_iter = token_count.find(cur_p->data[i].id);
-        if (token_iter == token_count.end()) {
+        const auto token_iter = ctx->token_count.find(cur_p->data[i].id);
+        if (token_iter == ctx->token_count.end()) {
             continue;
        }

         const int count = token_iter->second;

+        assert(count > 0 && count <= ctx->penalty_last_n);
+
         // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
         // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
         if (cur_p->data[i].logit <= 0) {
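The penalty arithmetic itself is untouched; only where the counts come from changed. As a standalone illustration of how a single candidate logit is penalized, assuming count occurrences of the token inside the penalty window (penalize_logit is a hypothetical helper, not part of the library):

// Sketch: the three penalties applied to one candidate logit.
// Multiplying (instead of dividing) when the logit is negative keeps the repeat
// penalty from making an already-unlikely token *more* likely.
static float penalize_logit(float logit, int count,
                            float penalty_repeat, float penalty_freq, float penalty_present) {
    if (count > 0) {
        if (logit <= 0.0f) {
            logit *= penalty_repeat; // e.g. -2.0f * 1.5f = -3.0f  -> less likely
        } else {
            logit /= penalty_repeat; // e.g.  2.0f / 1.5f ~= 1.33f -> less likely
        }

        // the frequency penalty scales with the count, the presence penalty is flat
        logit -= (float) count * penalty_freq + penalty_present;
    }
    return logit;
}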
@@ -1503,30 +1473,21 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok
     }

     cur_p->sorted = false;
-
-    if (!ctx->penalize_nl && nl_found) {
-        // restore the logit of the newline token if it was penalized
-        cur_p->data[nl_idx].logit = nl_logit;
-    }
 }

 static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
     auto * ctx = (llama_sampler_penalties *) smpl->ctx;
     ctx->prev.clear();
+    ctx->token_count.clear();
 }

 static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
     auto * result = llama_sampler_init_penalties(
-            ctx->n_vocab,
-            ctx->special_eos_id,
-            ctx->linefeed_id,
             ctx->penalty_last_n,
             ctx->penalty_repeat,
             ctx->penalty_freq,
-            ctx->penalty_present,
-            ctx->penalize_nl,
-            ctx->ignore_eos);
+            ctx->penalty_present);

     // copy the state
     {
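Since the sampler now carries state in two places (prev and token_count), reset clears both, and clone, via the elided "copy the state" block, copies both. From the caller's perspective nothing changes; a usage sketch through the public API, with arbitrary penalty values:

// Sketch: exercising the stateful sampler through the public llama.h API.
struct llama_sampler * pen = llama_sampler_init_penalties(64, 1.1f, 0.0f, 0.0f);

// ... llama_sampler_accept(pen, tok) for each generated token ...

struct llama_sampler * copy = llama_sampler_clone(pen); // deep-copies prev + token_count
llama_sampler_reset(pen);                               // clears both

llama_sampler_free(copy);
llama_sampler_free(pen);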
@@ -1552,38 +1513,21 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
 };

 struct llama_sampler * llama_sampler_init_penalties(
-        int32_t n_vocab,
-        llama_token special_eos_id,
-        llama_token linefeed_id,
         int32_t penalty_last_n,
         float penalty_repeat,
         float penalty_freq,
-        float penalty_present,
-        bool penalize_nl,
-        bool ignore_eos) {
-    if (linefeed_id == LLAMA_TOKEN_NULL) {
-        penalize_nl = true;
-    }
-
-    if (special_eos_id == LLAMA_TOKEN_NULL) {
-        ignore_eos = false;
-    }
-
+        float penalty_present) {
     penalty_last_n = std::max(penalty_last_n, 0);

     return new llama_sampler {
         /* .iface = */ &llama_sampler_penalties_i,
         /* .ctx   = */ new llama_sampler_penalties {
-            /* .n_vocab         = */ n_vocab,
-            /* .special_eos_id  = */ special_eos_id,
-            /* .linefeed_id     = */ linefeed_id,
             /* .penalty_last_n  = */ penalty_last_n,
             /* .penalty_repeat  = */ penalty_repeat,
             /* .penalty_freq    = */ penalty_freq,
             /* .penalty_present = */ penalty_present,
-            /* .penalize_nl     = */ penalize_nl,
-            /* .ignore_eos      = */ ignore_eos,
             /* .prev            = */ ring_buffer<llama_token>(penalty_last_n),
+            /* .token_count     = */ {},
         },
     };
 }
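With the trimmed-down four-argument constructor, wiring the penalties into a sampling chain looks roughly as follows. The commit message also notes that penalty_last_n == -1 now means "use the whole context"; the explicit resolution shown here is an assumption about how common maps it before construction, and lctx is assumed to be an initialized llama_context:

// Sketch: penalties at the front of a sampling chain, using the new constructor.
int32_t penalty_last_n = -1; // -1 => penalize over the entire context
if (penalty_last_n < 0) {
    penalty_last_n = (int32_t) llama_n_ctx(lctx);
}

struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

llama_sampler_chain_add(chain, llama_sampler_init_penalties(penalty_last_n, 1.1f, 0.0f, 0.0f));
llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
llama_sampler_chain_add(chain, llama_sampler_init_temp(0.8f));
llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));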
@@ -1611,7 +1555,8 @@ static void get_overlapping_token_sequences(const llama_vocab & vocab, const std
 if (word.find(str) != std::string::npos) {
     token_sequences.emplace(token_id, std::vector<llama_token>());
 } else {
-    size_t word_len = word.size(), str_len = str.size();
+    size_t word_len = word.size();
+    size_t str_len  = str.size();
     size_t pos = -1;
     while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
         bool match = true;