add tests

This commit is contained in:
ZXED 2024-09-25 19:57:52 +03:00
parent 3722c729b8
commit c795d8b82b
No known key found for this signature in database
GPG key ID: 637FB44813DCFD66
2 changed files with 13 additions and 9 deletions

View file

@ -1674,7 +1674,7 @@ struct llama_sampler * llama_sampler_init_penalties(
/* .penalty_repeat = */ penalty_repeat,
/* .penalty_freq = */ penalty_freq,
/* .penalty_present = */ penalty_present,
/* .penalty_repeat_sigmoid_growth */ penalty_repeat_sigmoid_growth,
/* .penalty_repeat_sigmoid_growth = */ penalty_repeat_sigmoid_growth,
/* .penalize_nl = */ penalize_nl,
/* .ignore_eos = */ ignore_eos,
/* .prev = */ ring_buffer<llama_token>(penalty_last_n),

View file

@ -134,7 +134,7 @@ static void test_typical(const std::vector<float> & probs, const std::vector<flo
static void test_penalties(
const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
const std::vector<float> & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence
const std::vector<float> & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence, float penalty_repeat_sigmoid_growth
) {
GGML_ASSERT(probs.size() == expected_probs.size());
@ -149,7 +149,7 @@ static void test_penalties(
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, 0.0, false, false);
auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, penalty_repeat_sigmoid_growth, false, false);
for (size_t i = 0; i < last_tokens.size(); i++) {
llama_sampler_accept(sampler, last_tokens[i]);
@ -316,13 +316,17 @@ int main(void) {
test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f, 0.0f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f, 0.0f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f, 0.0f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 0.0f, 0.0f, 10.0f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.263353f, 0.263353f, 0.201890f, 0.153630f, 0.117775f}, 1.5f, 0.0f, 0.0f, 1.0f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.290188f, 0.246533f, 0.182452f, 0.140414, 0.140414f}, 0.5f, 0.0f, 0.0f, -0.5f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f, 0.0f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f, 0.0f);
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f, 0.0f);
test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);