From 63e60deda33bb06accddbe5cb2806e6d299c7f91 Mon Sep 17 00:00:00 2001
From: MaggotHATE
Date: Sat, 5 Oct 2024 23:27:36 +0500
Subject: [PATCH] Swapped sorting for a custom algorithm

Shifts tokens to remove the penalized ones, then puts the penalized at the
back. Should make `min_keep` still viable.
---
 src/llama-sampling.cpp | 33 ++++++++++++++++++++++-----------
 1 file changed, 22 insertions(+), 11 deletions(-)

diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 8349b3162..02db44be2 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1096,28 +1096,39 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
     // in case it's not sorted/recalculated yet
     llama_sampler_softmax_impl(cur_p);
 
-    int found = 0;
+    std::vector<llama_token_data> cur; // penalized tokens, collected back to front
+
+    int removed = -1; // starts at -1 so the first match seen from the back is kept, not counted
+    int pos = 0;      // lowest index among the penalized tokens
+
     // going through all candidates from back to front, easier to keep the last of probables
     for (int i = (cur_p->size - 1); i >= 0; --i) {
         if (cur_p->data[i].p >= ctx->threshold && cur_p->data[i].p <= ctx->threshold_max) {
-            ++found;
-            if (found > 1) {
+            ++removed;
+            if (removed > 0) {
                 // .logits are used for sorting and calculating .p in llama_sample_softmax_impl
                 cur_p->data[i].logit = -999.0f;
+                cur.emplace_back(cur_p->data[i]);
+                pos = i;
             }
         }
     }
 
-    if (found > 1) {
-        // sorting with new logits, ex-last probable will be the first anyway
-        std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
-            return a.logit > b.logit;
-        });
+    if (removed > 0) {
+        size_t size_new = cur_p->size - removed;
 
-        // resizing now that penalized tokens are at the back
-        cur_p->size = cur_p->size - found + 1;
+        // shift the tail left over the penalized ones (assumes they form one contiguous run starting at pos)
+        for (size_t i = pos; i < size_new; ++i) { // up to size_new: every surviving slot from pos on is refilled
+            cur_p->data[i] = cur_p->data[i + removed];
+        }
 
-        if (cur_p->size < ctx->min_keep) cur_p->size = ctx->min_keep;
+        // put the penalized ones at the back, in the slots freed by the shift
+        for (size_t i = 0; i < cur.size(); ++i) {
+            cur_p->data[cur_p->size - (1 + i)] = cur[i];
+        }
+
+        if (size_new < ctx->min_keep) size_new = ctx->min_keep; // may re-expose penalized tokens parked at the back
+        cur_p->size = size_new;
     }
 }
 