From c795d8b82bf2d6d61e49b3179ec9562641902a53 Mon Sep 17 00:00:00 2001 From: ZXED Date: Wed, 25 Sep 2024 19:57:52 +0300 Subject: [PATCH] add tests --- src/llama-sampling.cpp | 2 +- tests/test-sampling.cpp | 20 ++++++++++++-------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 08ee2110b..fd2e3dbef 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1674,7 +1674,7 @@ struct llama_sampler * llama_sampler_init_penalties( /* .penalty_repeat = */ penalty_repeat, /* .penalty_freq = */ penalty_freq, /* .penalty_present = */ penalty_present, - /* .penalty_repeat_sigmoid_growth */ penalty_repeat_sigmoid_growth, + /* .penalty_repeat_sigmoid_growth = */ penalty_repeat_sigmoid_growth, /* .penalize_nl = */ penalize_nl, /* .ignore_eos = */ ignore_eos, /* .prev = */ ring_buffer(penalty_last_n), diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index ef458bc64..c4574bf59 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -134,7 +134,7 @@ static void test_typical(const std::vector & probs, const std::vector & probs, const std::vector & last_tokens, - const std::vector & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence + const std::vector & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence, float penalty_repeat_sigmoid_growth ) { GGML_ASSERT(probs.size() == expected_probs.size()); @@ -149,7 +149,7 @@ static void test_penalties( llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; - auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, 0.0, false, false); + auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, penalty_repeat_sigmoid_growth, false, false); for (size_t i = 0; i < last_tokens.size(); i++) { llama_sampler_accept(sampler, last_tokens[i]); @@ -316,13 +316,17 @@ int main(void) { test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f); test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f); - test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f); - test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f); - test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f, 0.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f, 0.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f, 0.0f); - test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f); - test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f); - test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 0.0f, 0.0f, 10.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.263353f, 0.263353f, 0.201890f, 0.153630f, 0.117775f}, 1.5f, 0.0f, 0.0f, 1.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.290188f, 0.246533f, 0.182452f, 0.140414, 0.140414f}, 0.5f, 0.0f, 0.0f, -0.5f); + + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f, 0.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f, 0.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f, 0.0f); test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f); test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);