Simplified the kept-token counter by checking filtered_candidates.size() instead of a separate kept_count
+ changed the default for min_p from 1.0 to 0.0
This commit is contained in:
parent
49b68e8226
commit
6f7cdec38a
2 changed files with 4 additions and 7 deletions
|
@ -14,7 +14,7 @@ typedef struct llama_sampling_params {
|
||||||
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
||||||
int32_t top_k = 40; // <= 0 to use vocab size
|
int32_t top_k = 40; // <= 0 to use vocab size
|
||||||
float top_p = 0.95f; // 1.0 = disabled
|
float top_p = 0.95f; // 1.0 = disabled
|
||||||
float min_p = 1.0f; // 1.0 (or 0.0) = disabled
|
float min_p = 0.0f; // 1.0 (or 0.0) = disabled
|
||||||
float tfs_z = 1.00f; // 1.0 = disabled
|
float tfs_z = 1.00f; // 1.0 = disabled
|
||||||
float typical_p = 1.00f; // 1.0 = disabled
|
float typical_p = 1.00f; // 1.0 = disabled
|
||||||
float temp = 0.80f; // 1.0 = disabled
|
float temp = 0.80f; // 1.0 = disabled
|
||||||
|
|
|
@ -7377,18 +7377,15 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
|
||||||
std::vector<llama_token_data> filtered_candidates;
|
std::vector<llama_token_data> filtered_candidates;
|
||||||
filtered_candidates.reserve(candidates->size); // Reserve to avoid multiple reallocations
|
filtered_candidates.reserve(candidates->size); // Reserve to avoid multiple reallocations
|
||||||
|
|
||||||
size_t kept_count = 0; // Counter for how many tokens are kept
|
|
||||||
|
|
||||||
for (size_t i = 0; i < candidates->size; ++i) {
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
// If a token's probability is above the threshold, we keep it.
|
// If a token's probability is above the threshold or if we haven't kept enough tokens yet
|
||||||
if (candidates->data[i].p >= multiplied_min_p) {
|
if (candidates->data[i].p >= multiplied_min_p || filtered_candidates.size() < min_keep) {
|
||||||
filtered_candidates.push_back(candidates->data[i]);
|
filtered_candidates.push_back(candidates->data[i]);
|
||||||
kept_count++; // Increment the counter
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If not enough candidates meet the threshold, take the top 'min_keep' ones
|
// If not enough candidates meet the threshold, take the top 'min_keep' ones
|
||||||
if (kept_count < min_keep) {
|
if (filtered_candidates.size() < min_keep) {
|
||||||
std::sort(candidates->data, candidates->data + candidates->size,
|
std::sort(candidates->data, candidates->data + candidates->size,
|
||||||
[](const llama_token_data & a, const llama_token_data & b) {
|
[](const llama_token_data & a, const llama_token_data & b) {
|
||||||
return a.p > b.p; // Sort by probability in descending order
|
return a.p > b.p; // Sort by probability in descending order
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue