added tests and fixed nsigma impl

This commit is contained in:
VJHack 2025-01-13 17:48:35 -06:00
parent 8fb681bf9a
commit 54ef105c85
2 changed files with 28 additions and 13 deletions

View file

@ -182,6 +182,17 @@ static void test_dry(
tester.check();
}
static void test_top_n_sigma(const std::vector<float> & probs, const std::vector<float> & probs_expected, int n) {
sampler_tester tester(probs, probs_expected);
DUMP(&tester.cur_p);
tester.apply(llama_sampler_init_top_n_sigma(n));
tester.apply(llama_sampler_init_dist (0));
DUMP(&tester.cur_p);
tester.check();
}
static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
) {
sampler_tester tester(n_vocab);
@ -349,6 +360,14 @@ int main(void) {
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.241818f, 0.241818f, 0.032727f}, 2.0f, 1.1f, 2, 5, {});
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {});
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f, 0.0f, 0.0f}, 1);
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0);
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 3);
// test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3);
// test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
// test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);
test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f);