Standardize top_k sorting

This commit is contained in:
kalomaze 2024-01-22 05:11:50 -06:00
parent 4779d994fc
commit feea528add

View file

@ -7973,13 +7973,17 @@ void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * c
} }
void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) { void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
if (candidates->sorted) {
candidates->size = k;
return;
}
const int64_t t_start_sample_us = ggml_time_us(); const int64_t t_start_sample_us = ggml_time_us();
k = std::max(k, (int) min_keep); k = std::max(k, (int) min_keep);
k = std::min(k, (int) candidates->size); k = std::min(k, (int) candidates->size);
// Sort scores in descending order // Sort scores in descending order
if (!candidates->sorted) {
auto comp = [](const llama_token_data & a, const llama_token_data & b) { auto comp = [](const llama_token_data & a, const llama_token_data & b) {
return a.logit > b.logit; return a.logit > b.logit;
}; };
@ -7989,7 +7993,6 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can
std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp); std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
} }
candidates->sorted = true; candidates->sorted = true;
}
candidates->size = k; candidates->size = k;
if (ctx) { if (ctx) {