From 838d58dc3273c22ffa5f75633f4b9be7f09bb2d3 Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Sat, 28 Oct 2023 21:08:26 -0500 Subject: [PATCH] Min P disabled if set to 1.0 or 0, otherwise Top P --- common/common.cpp | 2 +- common/sampling.cpp | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index b7ed4827d..fe15d03a9 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -685,7 +685,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k); printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p); - printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p); + printf(" --min-p N min-p sampling (default: %.2f, 1.0 or 0.0 = disabled)\n", (double)sparams.min_p); printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z); printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p); printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n); diff --git a/common/sampling.cpp b/common/sampling.cpp index 673d67a6d..f5507dfca 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -190,8 +190,11 @@ llama_token llama_sampling_sample( llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); - llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); - llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); + if (min_p != 1.0 && min_p != 0.0) { + llama_sample_min_p(ctx_main, &cur_p, min_p, min_keep); 
+ } else { + llama_sample_top_p(ctx_main, &cur_p, top_p, min_keep); + } llama_sample_temp (ctx_main, &cur_p, temp); id = llama_sample_token(ctx_main, &cur_p);