keep the minimum min_keep value to 1 in sampling

This commit is contained in:
zhenweijin 2024-09-23 16:30:42 +08:00
parent 37f8c7b4c9
commit 3578d09729

View file

@ -173,22 +173,23 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
if (params.temp > 0.0f) { if (params.temp > 0.0f) {
if (params.mirostat == 0) { if (params.mirostat == 0) {
size_t min_keep = std::max(1, params.min_keep);
for (const auto & cnstr : params.samplers) { for (const auto & cnstr : params.samplers) {
switch (cnstr) { switch (cnstr) {
case GPT_SAMPLER_TYPE_TOP_K: case GPT_SAMPLER_TYPE_TOP_K:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
break; break;
case GPT_SAMPLER_TYPE_TOP_P: case GPT_SAMPLER_TYPE_TOP_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, min_keep));
break; break;
case GPT_SAMPLER_TYPE_MIN_P: case GPT_SAMPLER_TYPE_MIN_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, min_keep));
break; break;
case GPT_SAMPLER_TYPE_TFS_Z: case GPT_SAMPLER_TYPE_TFS_Z:
llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep)); llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, min_keep));
break; break;
case GPT_SAMPLER_TYPE_TYPICAL_P: case GPT_SAMPLER_TYPE_TYPICAL_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, min_keep));
break; break;
case GPT_SAMPLER_TYPE_TEMPERATURE: case GPT_SAMPLER_TYPE_TEMPERATURE:
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));