add tests
This commit is contained in:
parent
3722c729b8
commit
c795d8b82b
2 changed files with 13 additions and 9 deletions
|
@ -1674,7 +1674,7 @@ struct llama_sampler * llama_sampler_init_penalties(
|
|||
/* .penalty_repeat = */ penalty_repeat,
|
||||
/* .penalty_freq = */ penalty_freq,
|
||||
/* .penalty_present = */ penalty_present,
|
||||
/* .penalty_repeat_sigmoid_growth */ penalty_repeat_sigmoid_growth,
|
||||
/* .penalty_repeat_sigmoid_growth = */ penalty_repeat_sigmoid_growth,
|
||||
/* .penalize_nl = */ penalize_nl,
|
||||
/* .ignore_eos = */ ignore_eos,
|
||||
/* .prev = */ ring_buffer<llama_token>(penalty_last_n),
|
||||
|
|
|
@ -134,7 +134,7 @@ static void test_typical(const std::vector<float> & probs, const std::vector<flo
|
|||
|
||||
static void test_penalties(
|
||||
const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
|
||||
const std::vector<float> & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence
|
||||
const std::vector<float> & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence, float penalty_repeat_sigmoid_growth
|
||||
) {
|
||||
GGML_ASSERT(probs.size() == expected_probs.size());
|
||||
|
||||
|
@ -149,7 +149,7 @@ static void test_penalties(
|
|||
|
||||
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
||||
|
||||
auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, 0.0, false, false);
|
||||
auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, penalty_repeat_sigmoid_growth, false, false);
|
||||
|
||||
for (size_t i = 0; i < last_tokens.size(); i++) {
|
||||
llama_sampler_accept(sampler, last_tokens[i]);
|
||||
|
@ -316,13 +316,17 @@ int main(void) {
|
|||
test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
|
||||
test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
|
||||
|
||||
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f);
|
||||
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
|
||||
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
|
||||
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f, 0.0f);
|
||||
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f, 0.0f);
|
||||
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f, 0.0f);
|
||||
|
||||
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f);
|
||||
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
|
||||
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
|
||||
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 0.0f, 0.0f, 10.0f);
|
||||
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.263353f, 0.263353f, 0.201890f, 0.153630f, 0.117775f}, 1.5f, 0.0f, 0.0f, 1.0f);
|
||||
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.290188f, 0.246533f, 0.182452f, 0.140414, 0.140414f}, 0.5f, 0.0f, 0.0f, -0.5f);
|
||||
|
||||
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f, 0.0f);
|
||||
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f, 0.0f);
|
||||
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f, 0.0f);
|
||||
|
||||
test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
|
||||
test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue