parent
61715d5cc8
commit
8d8ff71536
16 changed files with 15 additions and 172 deletions
|
@ -105,16 +105,6 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
|
|||
tester.check();
|
||||
}
|
||||
|
||||
static void test_tfs(const std::vector<float> & probs, const std::vector<float> & probs_expected, float z) {
|
||||
sampler_tester tester(probs, probs_expected);
|
||||
|
||||
DUMP(&tester.cur_p);
|
||||
tester.apply(llama_sampler_init_tail_free(z, 1));
|
||||
DUMP(&tester.cur_p);
|
||||
|
||||
tester.check();
|
||||
}
|
||||
|
||||
static void test_min_p(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p) {
|
||||
sampler_tester tester(probs, probs_expected);
|
||||
|
||||
|
@ -202,7 +192,6 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
|
|||
for (auto s : samplers_sequence) {
|
||||
switch (s){
|
||||
case 'k': tester.apply(llama_sampler_init_top_k(top_k)); break;
|
||||
case 'f': GGML_ABORT("tail_free test not implemented");
|
||||
case 'y': GGML_ABORT("typical test not implemented");
|
||||
case 'p': tester.apply(llama_sampler_init_top_p(top_p, 1)); break;
|
||||
case 'm': tester.apply(llama_sampler_init_min_p(min_p, 1)); break;
|
||||
|
@ -299,12 +288,11 @@ static void test_perf() {
|
|||
data.emplace_back(llama_token_data{i, logit, 0.0f});
|
||||
}
|
||||
|
||||
BENCH(llama_sampler_init_top_k (40), data, 32);
|
||||
BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32);
|
||||
BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32);
|
||||
BENCH(llama_sampler_init_tail_free(0.5f, 1), data, 32);
|
||||
BENCH(llama_sampler_init_typical (0.5f, 1), data, 32);
|
||||
BENCH(llama_sampler_init_xtc (1.0f, 0.1f, 1, 1), data, 32);
|
||||
BENCH(llama_sampler_init_top_k (40), data, 32);
|
||||
BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32);
|
||||
BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32);
|
||||
BENCH(llama_sampler_init_typical(0.5f, 1), data, 32);
|
||||
BENCH(llama_sampler_init_xtc (1.0f, 0.1f, 1, 1), data, 32);
|
||||
}
|
||||
|
||||
int main(void) {
|
||||
|
@ -343,10 +331,6 @@ int main(void) {
|
|||
printf("XTC should not:\n");
|
||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.39f);
|
||||
|
||||
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
|
||||
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
|
||||
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f);
|
||||
|
||||
test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
|
||||
test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue