Swapped sorting for a custom algorithm

Shifts tokens to remove the penalized ones, then puts the penalized at the back. Should make `min_keep` still viable.
This commit is contained in:
MaggotHATE 2024-10-05 23:27:36 +05:00 committed by GitHub
parent 59e8e63e68
commit 63e60deda3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1096,28 +1096,39 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
// in case it's not sorted/recalculated yet
llama_sampler_softmax_impl(cur_p);
int found = 0;
std::vector<llama_token_data> cur;
int removed = -1; // to keep one out
int pos = 0;
// going through all candidates from back to front, easier to keep the last of probables
for (int i = (cur_p->size - 1); i >= 0; --i) {
if (cur_p->data[i].p >= ctx->threshold && cur_p->data[i].p <= ctx->threshold_max) {
++found;
if (found > 1) {
++removed;
if (removed > 0) {
// .logits are used for sorting and calculating .p in llama_sample_softmax_impl
cur_p->data[i].logit = -999.0f;
cur.emplace_back(cur_p->data[i]);
pos = i;
}
}
}
if (found > 1) {
// sorting with new logits, ex-last probable will be the first anyway
std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
return a.logit > b.logit;
});
if (removed > 0) {
size_t size_new = cur_p->size - removed;
// resizing now that penalized tokens are at the back
cur_p->size = cur_p->size - found + 1;
// shift tokens to remove the penalized ones
for (size_t i = pos; i < size_new - pos; ++i) {
cur_p->data[i] = cur_p->data[i + removed];
}
if (cur_p->size < ctx->min_keep) cur_p->size = ctx->min_keep;
// put the prenalized ones at the back
for (size_t i = 0; i < cur.size(); ++i) {
cur_p->data[cur_p->size - (1 + i)] = cur[i];
}
if (size_new < ctx->min_keep) size_new = ctx->min_keep;
cur_p->size = size_new;
}
}