Standardize top_k sorting

This commit is contained in:
kalomaze 2024-01-22 05:11:50 -06:00
parent 4779d994fc
commit feea528add

View file

@ -7973,23 +7973,26 @@ void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * c
} }
void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) { void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
if (candidates->sorted) {
candidates->size = k;
return;
}
const int64_t t_start_sample_us = ggml_time_us(); const int64_t t_start_sample_us = ggml_time_us();
k = std::max(k, (int) min_keep); k = std::max(k, (int) min_keep);
k = std::min(k, (int) candidates->size); k = std::min(k, (int) candidates->size);
// Sort scores in descending order // Sort scores in descending order
if (!candidates->sorted) { auto comp = [](const llama_token_data & a, const llama_token_data & b) {
auto comp = [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit;
return a.logit > b.logit; };
}; if (k == (int) candidates->size) {
if (k == (int) candidates->size) { std::sort(candidates->data, candidates->data + candidates->size, comp);
std::sort(candidates->data, candidates->data + candidates->size, comp); } else {
} else { std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
}
candidates->sorted = true;
} }
candidates->sorted = true;
candidates->size = k; candidates->size = k;
if (ctx) { if (ctx) {