Simplified counter by checking candidates size

+ fixed 0.0 default for min_p
This commit is contained in:
kalomaze 2023-10-28 23:37:18 -05:00
parent 49b68e8226
commit 6f7cdec38a
2 changed files with 4 additions and 7 deletions

View file

@ -14,7 +14,7 @@ typedef struct llama_sampling_params {
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 1.0f; // 1.0 (or 0.0) = disabled
float min_p = 0.0f; // 1.0 (or 0.0) = disabled
float tfs_z = 1.00f; // 1.0 = disabled
float typical_p = 1.00f; // 1.0 = disabled
float temp = 0.80f; // 1.0 = disabled

View file

@ -7377,18 +7377,15 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
std::vector<llama_token_data> filtered_candidates;
filtered_candidates.reserve(candidates->size); // Reserve to avoid multiple reallocations
size_t kept_count = 0; // Counter for how many tokens are kept
for (size_t i = 0; i < candidates->size; ++i) {
// If a token's probability is above the threshold, we keep it.
if (candidates->data[i].p >= multiplied_min_p) {
// If a token's probability is above the threshold or if we haven't kept enough tokens yet
if (candidates->data[i].p >= multiplied_min_p || filtered_candidates.size() < min_keep) {
filtered_candidates.push_back(candidates->data[i]);
kept_count++; // Increment the counter
}
}
// If not enough candidates meet the threshold, take the top 'min_keep' ones
if (kept_count < min_keep) {
if (filtered_candidates.size() < min_keep) {
std::sort(candidates->data, candidates->data + candidates->size,
[](const llama_token_data & a, const llama_token_data & b) {
return a.p > b.p; // Sort by probability in descending order