sampling: fix top_k <= 0

This commit is contained in:
JohannesGaessler 2024-02-07 13:19:00 +01:00
parent 213d1439fa
commit 2f8e6078b0
3 changed files with 7 additions and 1 deletion

View file

@@ -132,7 +132,7 @@ static void sampler_queue(
const float temp = params.temp; const float temp = params.temp;
const float dynatemp_range = params.dynatemp_range; const float dynatemp_range = params.dynatemp_range;
const float dynatemp_exponent = params.dynatemp_exponent; const float dynatemp_exponent = params.dynatemp_exponent;
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; const int32_t top_k = params.top_k;
const float top_p = params.top_p; const float top_p = params.top_p;
const float min_p = params.min_p; const float min_p = params.min_p;
const float tfs_z = params.tfs_z; const float tfs_z = params.tfs_z;

View file

@@ -8371,6 +8371,10 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can
// return; // return;
// } // }
if (k <= 0) {
k = candidates->size;
}
const int64_t t_start_sample_us = ggml_time_us(); const int64_t t_start_sample_us = ggml_time_us();
k = std::max(k, (int) min_keep); k = std::max(k, (int) min_keep);

View file

@@ -235,6 +235,8 @@ int main(void) {
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 1); test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 1);
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 3); test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 3);
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0); test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0);
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f); test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f);