diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 02db44be2..ba8eb5c74 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1096,37 +1096,40 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data // in case it's not sorted/recalculated yet llama_sampler_softmax_impl(cur_p); - std::vector cur; - - int removed = -1; // to keep one out + std::vector top_tkns; int pos = 0; - // going through all candidates from back to front, easier to keep the last of probables - for (int i = (cur_p->size - 1); i >= 0; --i) { - if (cur_p->data[i].p >= ctx->threshold && cur_p->data[i].p <= ctx->threshold_max) { - ++removed; - if (removed > 0) { - // .logits are used for sorting and calculating .p in llama_sample_softmax_impl - cur_p->data[i].logit = -999.0f; - cur.emplace_back(cur_p->data[i]); - pos = i; + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].p >= ctx->threshold) { + if (cur_p->data[i].p <= ctx->threshold_max) { + top_tkns.emplace_back(cur_p->data[i]); + // capture position of the first penalizable + if (pos == -1) pos = i; } - } + } else break; } - if (removed > 0) { - size_t size_new = cur_p->size - removed; + // check if there are enough penalizable tokens + if (top_tkns.size() >= 2) { + // keep the least probable from top ones + top_tkns.pop_back(); - // shift tokens to remove the penalized ones + // define new size + size_t to_remove = top_tkns.size(); + size_t size_new = cur_p->size - to_remove; + + // shift tokens starting from pos for (size_t i = pos; i < size_new - pos; ++i) { - cur_p->data[i] = cur_p->data[i + removed]; + cur_p->data[i] = cur_p->data[i + to_remove]; } - // put the prenalized ones at the back - for (size_t i = 0; i < cur.size(); ++i) { - cur_p->data[cur_p->size - (1 + i)] = cur[i]; + // penalize top tokens and put them at the back + for (size_t i = 0; i < top_tkns.size(); ++i) { + top_tkns[i].logit = -999.0f; + cur_p->data[cur_p->size - (1 + i)] = top_tkns[i]; } + // resize if (size_new < ctx->min_keep) size_new = ctx->min_keep; cur_p->size = size_new; }