Standardize top_k sorting

This commit is contained in:
kalomaze 2024-01-22 05:11:50 -06:00
parent 4779d994fc
commit feea528add

View file

@ -7973,13 +7973,17 @@ void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * c
}
void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
if (candidates->sorted) {
candidates->size = k;
return;
}
const int64_t t_start_sample_us = ggml_time_us();
k = std::max(k, (int) min_keep);
k = std::min(k, (int) candidates->size);
// Sort scores in descending order
if (!candidates->sorted) {
auto comp = [](const llama_token_data & a, const llama_token_data & b) {
return a.logit > b.logit;
};
@ -7989,7 +7993,6 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can
std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
}
candidates->sorted = true;
}
candidates->size = k;
if (ctx) {