Swapped sorting for a custom algorithm
Shifts the surviving tokens forward to overwrite the penalized ones, then puts the penalized tokens at the back. Should keep `min_keep` viable (a standalone sketch of the scheme follows the diff).
parent 59e8e63e68
commit 63e60deda3
1 changed file with 22 additions and 11 deletions
@@ -1096,28 +1096,39 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
     // in case it's not sorted/recalculated yet
     llama_sampler_softmax_impl(cur_p);
 
-    int found = 0;
+    std::vector<llama_token_data> cur;
+    int removed = -1; // to keep one out
+    int pos = 0;
 
     // going through all candidates from back to front, easier to keep the last of probables
     for (int i = (cur_p->size - 1); i >= 0; --i) {
         if (cur_p->data[i].p >= ctx->threshold && cur_p->data[i].p <= ctx->threshold_max) {
-            ++found;
-            if (found > 1) {
+            ++removed;
+            if (removed > 0) {
                 // .logits are used for sorting and calculating .p in llama_sample_softmax_impl
                 cur_p->data[i].logit = -999.0f;
+                cur.emplace_back(cur_p->data[i]);
+                pos = i;
             }
         }
     }
 
-    if (found > 1) {
-        // sorting with new logits, ex-last probable will be the first anyway
-        std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
-            return a.logit > b.logit;
-        });
+    if (removed > 0) {
+        size_t size_new = cur_p->size - removed;
 
-        // resizing now that penalized tokens are at the back
-        cur_p->size = cur_p->size - found + 1;
+        // shift tokens to remove the penalized ones
+        for (size_t i = pos; i < size_new - pos; ++i) {
+            cur_p->data[i] = cur_p->data[i + removed];
+        }
+
+        // put the penalized ones at the back
+        for (size_t i = 0; i < cur.size(); ++i) {
+            cur_p->data[cur_p->size - (1 + i)] = cur[i];
+        }
 
-        if (cur_p->size < ctx->min_keep) cur_p->size = ctx->min_keep;
+        if (size_new < ctx->min_keep) size_new = ctx->min_keep;
+        cur_p->size = size_new;
     }
 }
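For context, the shift-and-append scheme can be read in isolation. Below is a minimal standalone sketch of the same idea, assuming the candidates have already been softmaxed and sorted by probability in descending order (so the in-range tokens form one contiguous run). The types and names here (token_data, xtc_params, xtc_shift_apply) are illustrative stand-ins, not the actual llama.cpp API, and the shift loop is bounded at size_new so that every surviving slot is filled.

#include <cstdio>
#include <vector>

// illustrative stand-ins for llama_token_data and the sampler context;
// not the actual llama.cpp types
struct token_data {
    int   id;
    float p; // probability after softmax
};

struct xtc_params {
    float  threshold     = 0.1f;
    float  threshold_max = 1.0f;
    size_t min_keep      = 1;
};

// Removes every in-range token except the last (least probable) one by
// shifting the survivors forward, then parks the removed tokens behind the
// new end of the list and returns the new effective size. Because the
// removed tokens are still present at the back, raising the size back up
// to min_keep re-admits them instead of exposing garbage.
static size_t xtc_shift_apply(std::vector<token_data> & cands, const xtc_params & params) {
    std::vector<token_data> removed_tokens;
    int removed = -1; // starts at -1 so the last in-range token is kept
    size_t pos = 0;   // lowest index of a removed token

    // back to front: the first in-range hit is the one that survives
    for (int i = (int) cands.size() - 1; i >= 0; --i) {
        if (cands[i].p >= params.threshold && cands[i].p <= params.threshold_max) {
            ++removed;
            if (removed > 0) {
                removed_tokens.push_back(cands[i]);
                pos = (size_t) i;
            }
        }
    }

    if (removed <= 0) {
        return cands.size(); // zero or one token in range: nothing to do
    }

    const size_t size_new = cands.size() - (size_t) removed;

    // shift survivors forward over the gap left by the removed run;
    // bounding at size_new fills every surviving slot
    for (size_t i = pos; i < size_new; ++i) {
        cands[i] = cands[i + removed];
    }

    // park the removed tokens behind the new end, most probable first
    for (size_t i = 0; i < removed_tokens.size(); ++i) {
        cands[cands.size() - 1 - i] = removed_tokens[i];
    }

    return size_new < params.min_keep ? params.min_keep : size_new;
}

int main() {
    // sorted by p, descending; with threshold 0.1 the first four tokens are
    // in range, so ids 0..2 are removed and id 3 survives
    std::vector<token_data> cands = {
        {0, 0.40f}, {1, 0.30f}, {2, 0.15f}, {3, 0.10f}, {4, 0.05f},
    };

    const size_t size = xtc_shift_apply(cands, xtc_params{});
    for (size_t i = 0; i < size; ++i) {
        std::printf("id=%d p=%.2f\n", cands[i].id, cands[i].p);
    }
    return 0;
}

Compared with the previous approach of setting the penalized logits to -999.0f and re-sorting the whole array, this is a single O(n) pass, and because the penalized tokens are parked behind the new end in descending probability order, bumping the size back up to min_keep re-admits the most probable of them first.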