diff --git a/llama.cpp b/llama.cpp
index 59ab8f868..e128cea6f 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -7973,23 +7973,29 @@ void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * c
 }
 
 void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
+    // Clamp k BEFORE the sorted fast path below: an unclamped k could exceed
+    // candidates->size (growing the logical size past the buffer) or undercut min_keep.
     k = std::max(k, (int) min_keep);
     k = std::min(k, (int) candidates->size);
 
+    // Fast path: candidates are already sorted in descending order — just truncate.
+    if (candidates->sorted) {
+        candidates->size = k;
+        return;
+    }
+
+    const int64_t t_start_sample_us = ggml_time_us();
+
     // Sort scores in descending order
-    if (!candidates->sorted) {
-        auto comp = [](const llama_token_data & a, const llama_token_data & b) {
-            return a.logit > b.logit;
-        };
-        if (k == (int) candidates->size) {
-            std::sort(candidates->data, candidates->data + candidates->size, comp);
-        } else {
-            std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
-        }
-        candidates->sorted = true;
+    auto comp = [](const llama_token_data & a, const llama_token_data & b) {
+        return a.logit > b.logit;
+    };
+    if (k == (int) candidates->size) {
+        std::sort(candidates->data, candidates->data + candidates->size, comp);
+    } else {
+        std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
     }
+    candidates->sorted = true;
     candidates->size = k;
 
     if (ctx) {