Simplified algorithm and more tests
This commit is contained in:
parent
8110f783d1
commit
81a0c2603c
2 changed files with 23 additions and 32 deletions
|
@ -1081,7 +1081,6 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
|
||||||
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
|
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
|
||||||
|
|
||||||
if (ctx->probability <= 0.0f
|
if (ctx->probability <= 0.0f
|
||||||
|| ctx->threshold <= 0.0f
|
|
||||||
|| ctx->threshold >= 1.0f
|
|| ctx->threshold >= 1.0f
|
||||||
|| ctx->threshold_max <= 0.0f
|
|| ctx->threshold_max <= 0.0f
|
||||||
|| ctx->threshold_max <= ctx->threshold
|
|| ctx->threshold_max <= ctx->threshold
|
||||||
|
@ -1096,43 +1095,27 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
|
||||||
// in case it's not sorted/recalculated yet
|
// in case it's not sorted/recalculated yet
|
||||||
llama_sampler_softmax_impl(cur_p);
|
llama_sampler_softmax_impl(cur_p);
|
||||||
|
|
||||||
std::vector<llama_token_data> top_tkns;
|
int pos_first = -1;
|
||||||
int pos = 0;
|
int pos_last = 0;
|
||||||
|
|
||||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||||
if (cur_p->data[i].p >= ctx->threshold) {
|
if (cur_p->data[i].p - ctx->threshold >= -1e-5) {
|
||||||
if (cur_p->data[i].p <= ctx->threshold_max) {
|
if (cur_p->data[i].p - ctx->threshold_max > 1e-3) pos_first = i;
|
||||||
top_tkns.emplace_back(cur_p->data[i]);
|
pos_last = i;
|
||||||
// capture position of the first penalizable
|
} else {
|
||||||
if (pos == -1) pos = i;
|
break;
|
||||||
}
|
}
|
||||||
} else break;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if there are enough penalizable tokens
|
size_t to_remove = pos_last - (1 + pos_first);
|
||||||
if (top_tkns.size() >= 2) {
|
|
||||||
// keep the least probable from top ones
|
|
||||||
top_tkns.pop_back();
|
|
||||||
|
|
||||||
// define new size
|
if (to_remove < ctx->min_keep || to_remove < 1) return;
|
||||||
size_t to_remove = top_tkns.size();
|
|
||||||
size_t size_new = cur_p->size - to_remove;
|
|
||||||
|
|
||||||
// shift tokens starting from pos
|
for (size_t i = pos_first + 1; i < cur_p->size - to_remove + 1; ++i) {
|
||||||
for (size_t i = pos; i < size_new - pos; ++i) {
|
cur_p->data[i] = cur_p->data[i + to_remove];
|
||||||
cur_p->data[i] = cur_p->data[i + to_remove];
|
|
||||||
}
|
|
||||||
|
|
||||||
// penalize top tokens and put them at the back
|
|
||||||
for (size_t i = 0; i < top_tkns.size(); ++i) {
|
|
||||||
top_tkns[i].logit = -999.0f;
|
|
||||||
cur_p->data[cur_p->size - (1 + i)] = top_tkns[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
// resize
|
|
||||||
if (size_new < ctx->min_keep) size_new = ctx->min_keep;
|
|
||||||
cur_p->size = size_new;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cur_p->size = cur_p->size - to_remove;
|
||||||
}
|
}
|
||||||
|
|
||||||
static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
|
static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
|
||||||
|
|
|
@ -285,7 +285,7 @@ static void bench(llama_sampler * cnstr, const char * cnstr_name, const std::vec
|
||||||
}
|
}
|
||||||
const int64_t t_end = ggml_time_us();
|
const int64_t t_end = ggml_time_us();
|
||||||
llama_sampler_free(cnstr);
|
llama_sampler_free(cnstr);
|
||||||
printf("%-42s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter);
|
printf("%-47s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define BENCH(__cnstr, __data, __n_iter) bench((__cnstr), #__cnstr, (__data), (__n_iter))
|
#define BENCH(__cnstr, __data, __n_iter) bench((__cnstr), #__cnstr, (__data), (__n_iter))
|
||||||
|
@ -332,10 +332,18 @@ int main(void) {
|
||||||
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f);
|
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f);
|
||||||
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);
|
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);
|
||||||
|
|
||||||
|
printf("XTC should:\n");
|
||||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.1f}, 0.99f, 0.10f, 1.00f);
|
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.1f}, 0.99f, 0.10f, 1.00f);
|
||||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.1f}, 0.99f, 0.10f, 0.35f);
|
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.1f}, 0.99f, 0.10f, 0.35f);
|
||||||
|
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.2f, 0.1f}, 0.99f, 0.20f, 1.00f);
|
||||||
|
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.3f, 0.2f, 0.1f}, 0.99f, 0.30f, 1.00f);
|
||||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.1f}, 0.99f, 0.10f, 0.25f);
|
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.1f}, 0.99f, 0.10f, 0.25f);
|
||||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.2f, 0.1f}, 0.99f, 0.20f, 0.35f);
|
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.2f, 0.1f}, 0.99f, 0.20f, 0.35f);
|
||||||
|
printf("XTC should not:\n");
|
||||||
|
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.10f, 0.15f);
|
||||||
|
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.20f, 0.25f);
|
||||||
|
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.30f, 0.35f);
|
||||||
|
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.40f, 1.00f);
|
||||||
|
|
||||||
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
|
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
|
||||||
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
|
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue