llama : remove Tail-Free sampling (#10071)

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-10-29 10:42:05 +02:00 committed by GitHub
parent 61715d5cc8
commit 8d8ff71536
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 15 additions and 172 deletions

View file

@ -105,16 +105,6 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
tester.check();
}
static void test_tfs(const std::vector<float> & probs, const std::vector<float> & probs_expected, float z) {
sampler_tester tester(probs, probs_expected);
DUMP(&tester.cur_p);
tester.apply(llama_sampler_init_tail_free(z, 1));
DUMP(&tester.cur_p);
tester.check();
}
static void test_min_p(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p) {
sampler_tester tester(probs, probs_expected);
@ -202,7 +192,6 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
for (auto s : samplers_sequence) {
switch (s){
case 'k': tester.apply(llama_sampler_init_top_k(top_k)); break;
case 'f': GGML_ABORT("tail_free test not implemented");
case 'y': GGML_ABORT("typical test not implemented");
case 'p': tester.apply(llama_sampler_init_top_p(top_p, 1)); break;
case 'm': tester.apply(llama_sampler_init_min_p(min_p, 1)); break;
@ -299,12 +288,11 @@ static void test_perf() {
data.emplace_back(llama_token_data{i, logit, 0.0f});
}
BENCH(llama_sampler_init_top_k (40), data, 32);
BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32);
BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32);
BENCH(llama_sampler_init_tail_free(0.5f, 1), data, 32);
BENCH(llama_sampler_init_typical (0.5f, 1), data, 32);
BENCH(llama_sampler_init_xtc (1.0f, 0.1f, 1, 1), data, 32);
BENCH(llama_sampler_init_top_k (40), data, 32);
BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32);
BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32);
BENCH(llama_sampler_init_typical(0.5f, 1), data, 32);
BENCH(llama_sampler_init_xtc (1.0f, 0.1f, 1, 1), data, 32);
}
int main(void) {
@ -343,10 +331,6 @@ int main(void) {
printf("XTC should not:\n");
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.39f);
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f);
test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);