diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index ba8eb5c74..b152c329a 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1081,7 +1081,6 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data auto * ctx = (llama_sampler_xtc *) smpl->ctx; if (ctx->probability <= 0.0f - || ctx->threshold <= 0.0f || ctx->threshold >= 1.0f || ctx->threshold_max <= 0.0f || ctx->threshold_max <= ctx->threshold @@ -1096,43 +1095,27 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data // in case it's not sorted/recalculated yet llama_sampler_softmax_impl(cur_p); - std::vector top_tkns; - int pos = 0; + int pos_first = -1; + int pos_last = 0; for (size_t i = 0; i < cur_p->size; ++i) { - if (cur_p->data[i].p >= ctx->threshold) { - if (cur_p->data[i].p <= ctx->threshold_max) { - top_tkns.emplace_back(cur_p->data[i]); - // capture position of the first penalizable - if (pos == -1) pos = i; - } - } else break; + if (cur_p->data[i].p - ctx->threshold >= -1e-5) { + if (cur_p->data[i].p - ctx->threshold_max > 1e-3) pos_first = i; + pos_last = i; + } else { + break; + } } - // check if there are enough penalizable tokens - if (top_tkns.size() >= 2) { - // keep the least probable from top ones - top_tkns.pop_back(); + size_t to_remove = pos_last - (1 + pos_first); - // define new size - size_t to_remove = top_tkns.size(); - size_t size_new = cur_p->size - to_remove; + if (to_remove < ctx->min_keep || to_remove < 1) return; - // shift tokens starting from pos - for (size_t i = pos; i < size_new - pos; ++i) { - cur_p->data[i] = cur_p->data[i + to_remove]; - } - - // penalize top tokens and put them at the back - for (size_t i = 0; i < top_tkns.size(); ++i) { - top_tkns[i].logit = -999.0f; - cur_p->data[cur_p->size - (1 + i)] = top_tkns[i]; - } - - // resize - if (size_new < ctx->min_keep) size_new = ctx->min_keep; - cur_p->size = size_new; + for (size_t i = pos_first + 1; i < cur_p->size - to_remove + 1; ++i) { + cur_p->data[i] = cur_p->data[i + to_remove]; } + + cur_p->size = cur_p->size - to_remove; } static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) { diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index cd6c61ba0..5716f7393 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -285,7 +285,7 @@ static void bench(llama_sampler * cnstr, const char * cnstr_name, const std::vec } const int64_t t_end = ggml_time_us(); llama_sampler_free(cnstr); - printf("%-42s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter); + printf("%-47s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter); } #define BENCH(__cnstr, __data, __n_iter) bench((__cnstr), #__cnstr, (__data), (__n_iter)) @@ -332,10 +332,18 @@ int main(void) { test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f); test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f); + printf("XTC should:\n"); test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.1f}, 0.99f, 0.10f, 1.00f); test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.1f}, 0.99f, 0.10f, 0.35f); + test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.2f, 0.1f}, 0.99f, 0.20f, 1.00f); + test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.3f, 0.2f, 0.1f}, 0.99f, 0.30f, 1.00f); test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.1f}, 0.99f, 0.10f, 0.25f); test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.2f, 0.1f}, 0.99f, 0.20f, 0.35f); + printf("XTC should not:\n"); + test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.10f, 0.15f); + test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.20f, 0.25f); + test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.30f, 0.35f); + test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.40f, 1.00f); test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f); test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);