Swapped sorting for a custom algorithm
Shifts tokens to remove the penalized ones, then puts the penalized at the back. Should make `min_keep` still viable.
This commit is contained in:
parent
59e8e63e68
commit
63e60deda3
1 changed files with 22 additions and 11 deletions
|
@ -1096,28 +1096,39 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
|
|||
// in case it's not sorted/recalculated yet
|
||||
llama_sampler_softmax_impl(cur_p);
|
||||
|
||||
int found = 0;
|
||||
std::vector<llama_token_data> cur;
|
||||
|
||||
int removed = -1; // to keep one out
|
||||
int pos = 0;
|
||||
|
||||
// going through all candidates from back to front, easier to keep the last of probables
|
||||
for (int i = (cur_p->size - 1); i >= 0; --i) {
|
||||
if (cur_p->data[i].p >= ctx->threshold && cur_p->data[i].p <= ctx->threshold_max) {
|
||||
++found;
|
||||
if (found > 1) {
|
||||
++removed;
|
||||
if (removed > 0) {
|
||||
// .logits are used for sorting and calculating .p in llama_sample_softmax_impl
|
||||
cur_p->data[i].logit = -999.0f;
|
||||
cur.emplace_back(cur_p->data[i]);
|
||||
pos = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (found > 1) {
|
||||
// sorting with new logits, ex-last probable will be the first anyway
|
||||
std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit > b.logit;
|
||||
});
|
||||
if (removed > 0) {
|
||||
size_t size_new = cur_p->size - removed;
|
||||
|
||||
// resizing now that penalized tokens are at the back
|
||||
cur_p->size = cur_p->size - found + 1;
|
||||
// shift tokens to remove the penalized ones
|
||||
for (size_t i = pos; i < size_new - pos; ++i) {
|
||||
cur_p->data[i] = cur_p->data[i + removed];
|
||||
}
|
||||
|
||||
if (cur_p->size < ctx->min_keep) cur_p->size = ctx->min_keep;
|
||||
// put the prenalized ones at the back
|
||||
for (size_t i = 0; i < cur.size(); ++i) {
|
||||
cur_p->data[cur_p->size - (1 + i)] = cur[i];
|
||||
}
|
||||
|
||||
if (size_new < ctx->min_keep) size_new = ctx->min_keep;
|
||||
cur_p->size = size_new;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue