diff --git a/llama.cpp b/llama.cpp index 3fef34d18..1ce209fa6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8001,20 +8001,17 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can auto comp = [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit; }; - if (k < 200) { + if (k <= 128) { std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp); } else { - constexpr int nbuckets = 128; - constexpr float bucket_low = -10.0f; - constexpr float bucket_high = 10.0f; + constexpr int nbuckets = 128; + constexpr float bucket_low = -10.0f; + constexpr float bucket_high = 10.0f; constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low); constexpr float bucker_inter = -bucket_low * bucket_scale; - std::vector tmp_tokens(candidates->size); std::vector bucket_idx(candidates->size); std::vector histo(nbuckets, 0); - std::vector bucket_ptrs; - bucket_ptrs.reserve(nbuckets); for (int i = 0; i < (int)candidates->size; ++i) { const float val = candidates->data[i].logit; @@ -8023,14 +8020,19 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can bucket_idx[i] = ib; ++histo[ib]; } - auto ptr = tmp_tokens.data(); int nhave = 0; int ib = nbuckets - 1; for ( ; ib >= 0; --ib) { nhave += histo[ib]; - bucket_ptrs.push_back(ptr); if (nhave >= k) break; - ptr += histo[ib]; + } + std::vector tmp_tokens(nhave); + auto ptr = tmp_tokens.data(); + std::vector bucket_ptrs; + bucket_ptrs.reserve(nbuckets - ib); + for (int j = nbuckets - 1; j >= ib; --j) { + bucket_ptrs.push_back(ptr); + ptr += histo[j]; } for (int i = 0; i < (int)candidates->size; ++i) { int j = bucket_idx[i];