Standardize top_k sorting
This commit is contained in:
parent
4779d994fc
commit
feea528add
1 changed files with 14 additions and 11 deletions
23
llama.cpp
23
llama.cpp
|
@ -7973,23 +7973,26 @@ void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * c
|
|||
}
|
||||
|
||||
void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
|
||||
if (candidates->sorted) {
|
||||
candidates->size = k;
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
k = std::max(k, (int) min_keep);
|
||||
k = std::min(k, (int) candidates->size);
|
||||
|
||||
// Sort scores in descending order
|
||||
if (!candidates->sorted) {
|
||||
auto comp = [](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit > b.logit;
|
||||
};
|
||||
if (k == (int) candidates->size) {
|
||||
std::sort(candidates->data, candidates->data + candidates->size, comp);
|
||||
} else {
|
||||
std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
|
||||
}
|
||||
candidates->sorted = true;
|
||||
auto comp = [](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit > b.logit;
|
||||
};
|
||||
if (k == (int) candidates->size) {
|
||||
std::sort(candidates->data, candidates->data + candidates->size, comp);
|
||||
} else {
|
||||
std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
|
||||
}
|
||||
candidates->sorted = true;
|
||||
candidates->size = k;
|
||||
|
||||
if (ctx) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue