Simplified algorithm and more tests

This commit is contained in:
MaggotHATE 2024-10-08 18:38:43 +05:00 committed by GitHub
parent 8110f783d1
commit 81a0c2603c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 23 additions and 32 deletions

View file

@ -1081,7 +1081,6 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
if (ctx->probability <= 0.0f
|| ctx->threshold <= 0.0f
|| ctx->threshold >= 1.0f
|| ctx->threshold_max <= 0.0f
|| ctx->threshold_max <= ctx->threshold
@ -1096,43 +1095,27 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
// in case it's not sorted/recalculated yet
llama_sampler_softmax_impl(cur_p);
std::vector<llama_token_data> top_tkns;
int pos = 0;
int pos_first = -1;
int pos_last = 0;
for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].p >= ctx->threshold) {
if (cur_p->data[i].p <= ctx->threshold_max) {
top_tkns.emplace_back(cur_p->data[i]);
// capture position of the first penalizable
if (pos == -1) pos = i;
}
} else break;
if (cur_p->data[i].p - ctx->threshold >= -1e-5) {
if (cur_p->data[i].p - ctx->threshold_max > 1e-3) pos_first = i;
pos_last = i;
} else {
break;
}
}
// check if there are enough penalizable tokens
if (top_tkns.size() >= 2) {
// keep the least probable from top ones
top_tkns.pop_back();
size_t to_remove = pos_last - (1 + pos_first);
// define new size
size_t to_remove = top_tkns.size();
size_t size_new = cur_p->size - to_remove;
if (to_remove < ctx->min_keep || to_remove < 1) return;
// shift tokens starting from pos
for (size_t i = pos; i < size_new - pos; ++i) {
cur_p->data[i] = cur_p->data[i + to_remove];
}
// penalize top tokens and put them at the back
for (size_t i = 0; i < top_tkns.size(); ++i) {
top_tkns[i].logit = -999.0f;
cur_p->data[cur_p->size - (1 + i)] = top_tkns[i];
}
// resize
if (size_new < ctx->min_keep) size_new = ctx->min_keep;
cur_p->size = size_new;
for (size_t i = pos_first + 1; i < cur_p->size - to_remove + 1; ++i) {
cur_p->data[i] = cur_p->data[i + to_remove];
}
cur_p->size = cur_p->size - to_remove;
}
static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {

View file

@ -285,7 +285,7 @@ static void bench(llama_sampler * cnstr, const char * cnstr_name, const std::vec
}
const int64_t t_end = ggml_time_us();
llama_sampler_free(cnstr);
printf("%-42s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter);
printf("%-47s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter);
}
// Convenience wrapper around bench(): the sampler expression is passed through as-is
// and also stringified (#__cnstr) as the second argument, so the printed benchmark
// line is labelled with the exact constructor expression written at the call site.
// NOTE(review): the parameter names begin with a double underscore, which C++ reserves
// for the implementation — consider renaming them if this macro is ever touched.
#define BENCH(__cnstr, __data, __n_iter) bench((__cnstr), #__cnstr, (__data), (__n_iter))
@ -332,10 +332,18 @@ int main(void) {
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f);
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);
printf("XTC should:\n");
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.1f}, 0.99f, 0.10f, 1.00f);
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.1f}, 0.99f, 0.10f, 0.35f);
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.2f, 0.1f}, 0.99f, 0.20f, 1.00f);
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.3f, 0.2f, 0.1f}, 0.99f, 0.30f, 1.00f);
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.1f}, 0.99f, 0.10f, 0.25f);
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.2f, 0.1f}, 0.99f, 0.20f, 0.35f);
printf("XTC should not:\n");
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.10f, 0.15f);
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.20f, 0.25f);
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.30f, 0.35f);
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.40f, 1.00f);
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);