From 00006da253ab2d9535b7e829a7022907ad8dd9a3 Mon Sep 17 00:00:00 2001 From: jhen Date: Fri, 6 Oct 2023 07:11:18 +0800 Subject: [PATCH] common : use n_probs for temperature sampling --- common/common.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 6b9b4695c..186f5b268 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1020,10 +1020,11 @@ llama_token llama_sample_token( id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu); } else { // Temperature sampling - llama_sample_top_k (ctx, &cur_p, top_k, 1); - llama_sample_tail_free (ctx, &cur_p, tfs_z, 1); - llama_sample_typical (ctx, &cur_p, typical_p, 1); - llama_sample_top_p (ctx, &cur_p, top_p, 1); + size_t min_keep = std::max(1, params.n_probs); + llama_sample_top_k (ctx, &cur_p, top_k, min_keep); + llama_sample_tail_free (ctx, &cur_p, tfs_z, min_keep); + llama_sample_typical (ctx, &cur_p, typical_p, min_keep); + llama_sample_top_p (ctx, &cur_p, top_p, min_keep); llama_sample_temp(ctx, &cur_p, temp); {