Algorithm rework

1. Scan tokens from the top until the first non-penalizable one
2. Remove the last captured token (the least probable one above the threshold) from the penalty set, so that it is kept
3. Shift the following tokens to overwrite the remaining penalizable ones
4. Penalize the captured tokens and put them at the bottom
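Sketched below is a minimal, self-contained version of these four steps, assuming the candidate list is already sorted by probability in descending order. Candidate, xtc_rework and the toy fields are hypothetical stand-ins, not identifiers from this codebase; the actual change (diffed below) works in place on cur_p->data.

// A sketch of the reworked pass; Candidate and xtc_rework are hypothetical
// names, not taken from the repository.
#include <cstdio>
#include <vector>

struct Candidate {
    int   token;
    float logit;
    float p;     // probability, assumed precomputed and sorted descending
};

// assumes min_keep <= cands.size()
static void xtc_rework(std::vector<Candidate> & cands,
                       float threshold, float threshold_max, size_t min_keep) {
    // 1. scan from the top until the first non-penalizable token
    std::vector<Candidate> top;
    int pos = -1;
    for (size_t i = 0; i < cands.size(); ++i) {
        if (cands[i].p < threshold) break;          // first non-penalizable
        if (cands[i].p <= threshold_max) {
            top.push_back(cands[i]);
            if (pos == -1) pos = (int) i;           // first penalizable index
        }
    }
    if (top.size() < 2) return;                     // nothing to penalize

    // 2. drop the last captured token (the least probable one above the
    //    threshold) from the penalty set, so it survives in the list
    top.pop_back();

    // 3. shift the tokens after the penalized block down, overwriting it
    size_t to_remove = top.size();
    size_t size_new  = cands.size() - to_remove;
    for (size_t i = (size_t) pos; i < size_new; ++i) {
        cands[i] = cands[i + to_remove];
    }

    // 4. penalize the captured tokens and park them at the bottom, past the
    //    new logical size
    for (size_t i = 0; i < top.size(); ++i) {
        top[i].logit = -999.0f;
        cands[cands.size() - 1 - i] = top[i];
    }
    cands.resize(size_new < min_keep ? min_keep : size_new);
}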
MaggotHATE 2024-10-06 16:15:12 +05:00 committed by GitHub
parent 094caea359
commit 39940e5fa3


@@ -1096,37 +1096,40 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
     // in case it's not sorted/recalculated yet
     llama_sampler_softmax_impl(cur_p);
-    std::vector<llama_token_data> cur;
-    int removed = -1; // to keep one out
-    int pos = 0;
-    // going through all candidates from back to front, easier to keep the last of probables
-    for (int i = (cur_p->size - 1); i >= 0; --i) {
-        if (cur_p->data[i].p >= ctx->threshold && cur_p->data[i].p <= ctx->threshold_max) {
-            ++removed;
-            if (removed > 0) {
-                // .logits are used for sorting and calculating .p in llama_sample_softmax_impl
-                cur_p->data[i].logit = -999.0f;
-                cur.emplace_back(cur_p->data[i]);
-                pos = i;
+    std::vector<llama_token_data> top_tkns;
+    int pos = -1;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (cur_p->data[i].p >= ctx->threshold) {
+            if (cur_p->data[i].p <= ctx->threshold_max) {
+                top_tkns.emplace_back(cur_p->data[i]);
+                // capture position of the first penalizable
+                if (pos == -1) pos = i;
             }
         } else break;
     }
-    if (removed > 0) {
-        size_t size_new = cur_p->size - removed;
+    // check if there are enough penalizable tokens
+    if (top_tkns.size() >= 2) {
+        // keep the least probable from top ones
+        top_tkns.pop_back();
         // shift tokens to remove the penalized ones
-        for (size_t i = pos; i < size_new - pos; ++i) {
-            cur_p->data[i] = cur_p->data[i + removed];
+        // define new size
+        size_t to_remove = top_tkns.size();
+        size_t size_new = cur_p->size - to_remove;
+        // shift tokens starting from pos
+        for (size_t i = pos; i < size_new; ++i) {
+            cur_p->data[i] = cur_p->data[i + to_remove];
         }
-        // put the penalized ones at the back
-        for (size_t i = 0; i < cur.size(); ++i) {
-            cur_p->data[cur_p->size - (1 + i)] = cur[i];
+        // penalize top tokens and put them at the back
+        for (size_t i = 0; i < top_tkns.size(); ++i) {
+            top_tkns[i].logit = -999.0f;
+            cur_p->data[cur_p->size - (1 + i)] = top_tkns[i];
         }
         // resize
         if (size_new < ctx->min_keep) size_new = ctx->min_keep;
         cur_p->size = size_new;
     }
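Note that the penalized tokens are not erased: they end up past the new cur_p->size with a -999.0f logit, so a later softmax over the truncated array ignores them, and a re-sort keeps them at the bottom (the old comment notes that .logit drives both sorting and the .p recalculation). Continuing the hypothetical sketch above, a toy driver illustrates the net effect:

int main() {
    // sorted candidates: three tokens at or above the 0.10 threshold
    std::vector<Candidate> cands = {
        {0, 0.0f, 0.40f}, {1, 0.0f, 0.30f}, {2, 0.0f, 0.20f},
        {3, 0.0f, 0.06f}, {4, 0.0f, 0.04f},
    };
    xtc_rework(cands, /*threshold=*/0.10f, /*threshold_max=*/1.00f, /*min_keep=*/1);
    // prints tokens 2, 3, 4: the two most probable tokens were penalized and
    // truncated away; token 2, the least probable above the threshold, is kept
    for (const Candidate & c : cands) {
        printf("token %d p=%.2f logit=%.1f\n", c.token, c.p, c.logit);
    }
    return 0;
}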