diff --git a/common/common.cpp b/common/common.cpp index b7ed4827d..fe15d03a9 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -685,7 +685,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k); printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p); - printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p); + printf(" --min-p N min-p sampling (default: %.2f, 1.0 = disabled)\n", (double)sparams.min_p); printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z); printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p); printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n); diff --git a/common/sampling.cpp b/common/sampling.cpp index 673d67a6d..f5507dfca 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -190,8 +190,11 @@ llama_token llama_sampling_sample( llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); - llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); - llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); + if (min_p != 1.0 && min_p != 0.0) { + llama_sample_min_p(ctx_main, &cur_p, min_p, min_keep); + } else { + llama_sample_top_p(ctx_main, &cur_p, top_p, min_keep); + } llama_sample_temp (ctx_main, &cur_p, temp); id = llama_sample_token(ctx_main, &cur_p);