Fixes and tests
This commit is contained in:
parent
070f9546f6
commit
48b715da28
2 changed files with 21 additions and 1 deletions
|
@ -1202,7 +1202,9 @@ static const char * llama_sampler_k_shift_name(const struct llama_sampler * /*sm
|
|||
static void llama_sampler_k_shift_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
auto * ctx = (llama_sampler_k_shift *) smpl->ctx;
|
||||
|
||||
if (ctx->k <= 0 || ctx->k_set == true) {
|
||||
if (ctx->k_set == true
|
||||
|| ctx->k <= 0
|
||||
|| ctx->k >= (int) cur_p->size) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -83,6 +83,17 @@ static void test_temp_ext(const std::vector<float> & probs, const std::vector<fl
|
|||
tester.check();
|
||||
}
|
||||
|
||||
static void test_k_shift(const std::vector<float> & probs, const std::vector<float> & probs_expected, int k) {
|
||||
sampler_tester tester(probs, probs_expected);
|
||||
|
||||
DUMP(&tester.cur_p);
|
||||
tester.apply(llama_sampler_init_k_shift(k));
|
||||
tester.apply(llama_sampler_init_dist (0));
|
||||
DUMP(&tester.cur_p);
|
||||
|
||||
tester.check();
|
||||
}
|
||||
|
||||
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & probs_expected, int k) {
|
||||
sampler_tester tester(probs, probs_expected);
|
||||
|
||||
|
@ -299,6 +310,7 @@ static void test_perf() {
|
|||
data.emplace_back(llama_token_data{i, logit, 0.0f});
|
||||
}
|
||||
|
||||
BENCH(llama_sampler_init_k_shift (10), data, 32);
|
||||
BENCH(llama_sampler_init_top_k (40), data, 32);
|
||||
BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32);
|
||||
BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32);
|
||||
|
@ -316,6 +328,12 @@ int main(void) {
|
|||
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f, 0.0f, 1.0f);
|
||||
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f, 0.0f, 1.0f);
|
||||
|
||||
test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
|
||||
test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 3);
|
||||
test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {0.66666f, 0.33333f}, 2);
|
||||
test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {0.5f, 0.33333f, 0.16666f}, 1);
|
||||
test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);
|
||||
|
||||
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1);
|
||||
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3);
|
||||
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue