Fixes and tests

This commit is contained in:
MaggotHATE 2024-10-26 12:21:46 +05:00
parent 070f9546f6
commit 48b715da28
2 changed files with 21 additions and 1 deletions

View file

@ -1202,7 +1202,9 @@ static const char * llama_sampler_k_shift_name(const struct llama_sampler * /*sm
static void llama_sampler_k_shift_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_k_shift *) smpl->ctx;
if (ctx->k <= 0 || ctx->k_set == true) {
if (ctx->k_set == true
|| ctx->k <= 0
|| ctx->k >= (int) cur_p->size) {
return;
}

View file

@ -83,6 +83,17 @@ static void test_temp_ext(const std::vector<float> & probs, const std::vector<fl
tester.check();
}
static void test_k_shift(const std::vector<float> & probs, const std::vector<float> & probs_expected, int k) {
sampler_tester tester(probs, probs_expected);
DUMP(&tester.cur_p);
tester.apply(llama_sampler_init_k_shift(k));
tester.apply(llama_sampler_init_dist (0));
DUMP(&tester.cur_p);
tester.check();
}
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & probs_expected, int k) {
sampler_tester tester(probs, probs_expected);
@ -299,6 +310,7 @@ static void test_perf() {
data.emplace_back(llama_token_data{i, logit, 0.0f});
}
BENCH(llama_sampler_init_k_shift (10), data, 32);
BENCH(llama_sampler_init_top_k (40), data, 32);
BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32);
BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32);
@ -316,6 +328,12 @@ int main(void) {
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f, 0.0f, 1.0f);
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f, 0.0f, 1.0f);
test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 3);
test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {0.66666f, 0.33333f}, 2);
test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {0.5f, 0.33333f, 0.16666f}, 1);
test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1);
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3);
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);