From 3bf3a968b6f70013ea94163243080202ddd9c66e Mon Sep 17 00:00:00 2001 From: Ivan Stepanov Date: Fri, 28 Apr 2023 20:36:53 +0300 Subject: [PATCH] Tests --- tests/test-sampling.cpp | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 3f3d5d174..c89b569fe 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -7,6 +7,9 @@ #include #include +#undef assert +#define assert(__expr) do { if (!(__expr)) { printf("%s:%d (%s) %s\n", __FILE__, __LINE__, __func__, #__expr); exit(1); } } while(0) + void dump(const llama_token_data_array * candidates) { for (size_t i = 0; i < candidates->size; i++) { printf("%d: %f (%f)\n", candidates->data[i].id, candidates->data[i].p, candidates->data[i].logit); @@ -53,13 +56,14 @@ void test_top_p(const std::vector & probs, } llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + llama_sample_softmax(nullptr, &candidates_p); // DUMP(&candidates_p); llama_sample_top_p(nullptr, &candidates_p, p); // DUMP(&candidates_p); assert(candidates_p.size == expected_probs.size()); for (size_t i = 0; i < candidates_p.size; i++) { - assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5); + assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3); } } @@ -82,7 +86,7 @@ void test_tfs(const std::vector & probs, assert(candidates_p.size == expected_probs.size()); for (size_t i = 0; i < candidates_p.size; i++) { - assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6); + assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3); } } @@ -105,7 +109,7 @@ void test_typical(const std::vector & probs, assert(candidates_p.size == expected_probs.size()); for (size_t i = 0; i < candidates_p.size; i++) { - assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6); + assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3); } } @@ -163,7 +167,7 @@ void test_frequency_presence_penalty( assert(candidates_p.size == expected_probs.size()); for (size_t i = 0; i < candidates_p.size; i++) { - assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6); + assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3); } } @@ -182,9 +186,9 @@ int main(void) { test_typical({0.97, 0.01, 0.01, 0.01}, {0.97}, 0.5); test_typical({0.4, 0.2, 0.2, 0.2}, {0.2, 0.2, 0.2}, 0.5); - test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0}, {0, 0.25, 0.25, 0.25, 0.25}, 50.0); - test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2}, {0, 0, 0, 0.5, 0.5}, 50.0); - test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2, 0, 0}, {0, 0, 0, 0.5, 0.5}, 50.0); + test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0}, {0.25, 0.25, 0.25, 0.25, 0}, 50.0); + test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2}, {0.5, 0.5, 0, 0, 0}, 50.0); + test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2, 0, 0}, {0.5, 0.5, 0, 0, 0}, 50.0); test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0}, {0.249997, 0.249997, 0.249997, 0.249997, 0.000011}, 5.0, 5.0); test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2}, {0.499966, 0.499966, 0.000023, 0.000023, 0.000023}, 5.0, 5.0);