Algorithm rework
1. Scan tokens from the top until the first non-penalizable one.
2. Remove the last captured token (the least probable one above the threshold).
3. Shift all tokens to overwrite the remaining penalizable ones.
4. Penalize them and put them at the bottom.
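The four steps above amount to one forward pass over the candidate list, which is kept sorted by probability. Below is a minimal, self-contained sketch of that pass; the Candidate struct, the apply_xtc_rework name, and the plain std::vector stand in for the sampler's real llama_token_data_array and threshold parameters, and the values in main are illustrative only, not taken from the commit:

#include <cstdio>
#include <vector>

struct Candidate {
    int   id;     // token id
    float logit;  // raw score; the real sampler re-sorts and re-softmaxes on this
    float p;      // probability; the list is assumed sorted by p, descending
};

// threshold / threshold_max / min_keep mirror the sampler parameters seen in the diff.
static void apply_xtc_rework(std::vector<Candidate> & cand,
                             float threshold, float threshold_max, size_t min_keep) {
    std::vector<Candidate> top;
    size_t pos = 0; // position of the first penalizable candidate

    // 1. scan from the top until the first non-penalizable candidate
    for (size_t i = 0; i < cand.size(); ++i) {
        if (cand[i].p < threshold) break;
        if (cand[i].p <= threshold_max) {
            if (top.empty()) pos = i;
            top.push_back(cand[i]);
        }
    }

    if (top.size() < 2) return; // not enough penalizable tokens

    // 2. drop the least probable captured token so it survives untouched
    top.pop_back();

    // 3. shift the remaining candidates over the penalizable ones
    size_t to_remove = top.size();
    size_t size_new  = cand.size() - to_remove;
    for (size_t i = pos; i < size_new; ++i) {
        cand[i] = cand[i + to_remove];
    }

    // 4. penalize the captured tokens and move them past the new end of the list
    for (size_t i = 0; i < top.size(); ++i) {
        top[i].logit = -999.0f;
        cand[cand.size() - 1 - i] = top[i];
    }

    // clamp to min_keep, then drop everything past the new size
    if (size_new < min_keep) size_new = min_keep;
    cand.resize(size_new);
}

int main() {
    // toy distribution, sorted by probability
    std::vector<Candidate> cand = {
        {0, 2.0f, 0.40f}, {1, 1.5f, 0.30f}, {2, 1.0f, 0.20f}, {3, 0.5f, 0.10f},
    };
    // with threshold = 0.15 the first three tokens are penalizable;
    // the least probable of them (p = 0.20) is kept as the new top token
    apply_xtc_rework(cand, 0.15f, 1.00f, 1);
    for (const Candidate & c : cand) {
        std::printf("id=%d p=%.2f logit=%.1f\n", c.id, c.p, c.logit);
    }
}

With the toy values shown, the pass keeps the p = 0.20 token as the new head of the list, drops the two more probable tokens, and leaves the tail untouched.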
parent 094caea359
commit 39940e5fa3
1 changed file with 23 additions and 20 deletions
@@ -1096,37 +1096,40 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
     // in case it's not sorted/recalculated yet
     llama_sampler_softmax_impl(cur_p);

-    std::vector<llama_token_data> cur;
-
-    int removed = -1; // to keep one out
+    std::vector<llama_token_data> top_tkns;
+    int pos = 0;

-    // going through all candidates from back to front, easier to keep the last of probables
-    for (int i = (cur_p->size - 1); i >= 0; --i) {
-        if (cur_p->data[i].p >= ctx->threshold && cur_p->data[i].p <= ctx->threshold_max) {
-            ++removed;
-            if (removed > 0) {
-                // .logits are used for sorting and calculating .p in llama_sample_softmax_impl
-                cur_p->data[i].logit = -999.0f;
-                cur.emplace_back(cur_p->data[i]);
-                pos = i;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (cur_p->data[i].p >= ctx->threshold) {
+            if (cur_p->data[i].p <= ctx->threshold_max) {
+                top_tkns.emplace_back(cur_p->data[i]);
+                // capture position of the first penalizable
+                if (pos == -1) pos = i;
             }
-        }
+        } else break;
     }

-    if (removed > 0) {
-        size_t size_new = cur_p->size - removed;
+    // check if there are enough penalizable tokens
+    if (top_tkns.size() >= 2) {
+        // keep the least probable from top ones
+        top_tkns.pop_back();

-        // shift tokens to remove the penalized ones
+        // define new size
+        size_t to_remove = top_tkns.size();
+        size_t size_new = cur_p->size - to_remove;

         // shift tokens starting from pos
         for (size_t i = pos; i < size_new - pos; ++i) {
-            cur_p->data[i] = cur_p->data[i + removed];
+            cur_p->data[i] = cur_p->data[i + to_remove];
         }

-        // put the prenalized ones at the back
-        for (size_t i = 0; i < cur.size(); ++i) {
-            cur_p->data[cur_p->size - (1 + i)] = cur[i];
+        // penalize top tokens and put them at the back
+        for (size_t i = 0; i < top_tkns.size(); ++i) {
+            top_tkns[i].logit = -999.0f;
+            cur_p->data[cur_p->size - (1 + i)] = top_tkns[i];
         }

         // resize
         if (size_new < ctx->min_keep) size_new = ctx->min_keep;
         cur_p->size = size_new;
     }