llama : fix top-p sampling to match the canonical definition (#1953)

* Fix top-p sampling to match the standard definition (smallest set that has probability mass at least p, not largest set with probability mass less than p)

* top-p: correct gt to gte

* add test for correct top-p behavior
This commit is contained in:
Alex Renda 2023-06-24 03:15:01 -07:00 committed by GitHub
parent 527b6fba1d
commit b061ba9e2a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 3 deletions

View file

@ -2015,9 +2015,10 @@ void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * can
for (size_t i = 0; i < candidates->size; ++i) {
cum_sum += candidates->data[i].p;
// Check if the running sum is greater than p or if we have kept at least min_keep tokens
if (cum_sum > p && i >= min_keep) {
last_idx = i;
// Check if the running sum is at least p or if we have kept at least min_keep tokens
// we set the last index to i+1 to indicate that the current iterate should be included in the set
if (cum_sum >= p && i + 1 >= min_keep) {
last_idx = i + 1;
break;
}
}