Simplified algorithm and more tests
This commit is contained in:
parent
8110f783d1
commit
81a0c2603c
2 changed files with 23 additions and 32 deletions
|
@ -1081,7 +1081,6 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
|
|||
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
|
||||
|
||||
if (ctx->probability <= 0.0f
|
||||
|| ctx->threshold <= 0.0f
|
||||
|| ctx->threshold >= 1.0f
|
||||
|| ctx->threshold_max <= 0.0f
|
||||
|| ctx->threshold_max <= ctx->threshold
|
||||
|
@ -1096,43 +1095,27 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
|
|||
// in case it's not sorted/recalculated yet
|
||||
llama_sampler_softmax_impl(cur_p);
|
||||
|
||||
std::vector<llama_token_data> top_tkns;
|
||||
int pos = 0;
|
||||
int pos_first = -1;
|
||||
int pos_last = 0;
|
||||
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
if (cur_p->data[i].p >= ctx->threshold) {
|
||||
if (cur_p->data[i].p <= ctx->threshold_max) {
|
||||
top_tkns.emplace_back(cur_p->data[i]);
|
||||
// capture position of the first penalizable
|
||||
if (pos == -1) pos = i;
|
||||
}
|
||||
} else break;
|
||||
if (cur_p->data[i].p - ctx->threshold >= -1e-5) {
|
||||
if (cur_p->data[i].p - ctx->threshold_max > 1e-3) pos_first = i;
|
||||
pos_last = i;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// check if there are enough penalizable tokens
|
||||
if (top_tkns.size() >= 2) {
|
||||
// keep the least probable from top ones
|
||||
top_tkns.pop_back();
|
||||
size_t to_remove = pos_last - (1 + pos_first);
|
||||
|
||||
// define new size
|
||||
size_t to_remove = top_tkns.size();
|
||||
size_t size_new = cur_p->size - to_remove;
|
||||
if (to_remove < ctx->min_keep || to_remove < 1) return;
|
||||
|
||||
// shift tokens starting from pos
|
||||
for (size_t i = pos; i < size_new - pos; ++i) {
|
||||
cur_p->data[i] = cur_p->data[i + to_remove];
|
||||
}
|
||||
|
||||
// penalize top tokens and put them at the back
|
||||
for (size_t i = 0; i < top_tkns.size(); ++i) {
|
||||
top_tkns[i].logit = -999.0f;
|
||||
cur_p->data[cur_p->size - (1 + i)] = top_tkns[i];
|
||||
}
|
||||
|
||||
// resize
|
||||
if (size_new < ctx->min_keep) size_new = ctx->min_keep;
|
||||
cur_p->size = size_new;
|
||||
for (size_t i = pos_first + 1; i < cur_p->size - to_remove + 1; ++i) {
|
||||
cur_p->data[i] = cur_p->data[i + to_remove];
|
||||
}
|
||||
|
||||
cur_p->size = cur_p->size - to_remove;
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
|
||||
|
|
|
@ -285,7 +285,7 @@ static void bench(llama_sampler * cnstr, const char * cnstr_name, const std::vec
|
|||
}
|
||||
const int64_t t_end = ggml_time_us();
|
||||
llama_sampler_free(cnstr);
|
||||
printf("%-42s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter);
|
||||
printf("%-47s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter);
|
||||
}
|
||||
|
||||
#define BENCH(__cnstr, __data, __n_iter) bench((__cnstr), #__cnstr, (__data), (__n_iter))
|
||||
|
@ -332,10 +332,18 @@ int main(void) {
|
|||
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f);
|
||||
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);
|
||||
|
||||
printf("XTC should:\n");
|
||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.1f}, 0.99f, 0.10f, 1.00f);
|
||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.1f}, 0.99f, 0.10f, 0.35f);
|
||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.2f, 0.1f}, 0.99f, 0.20f, 1.00f);
|
||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.3f, 0.2f, 0.1f}, 0.99f, 0.30f, 1.00f);
|
||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.1f}, 0.99f, 0.10f, 0.25f);
|
||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.2f, 0.1f}, 0.99f, 0.20f, 0.35f);
|
||||
printf("XTC should not:\n");
|
||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.10f, 0.15f);
|
||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.20f, 0.25f);
|
||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.30f, 0.35f);
|
||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.40f, 1.00f);
|
||||
|
||||
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
|
||||
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue