sampling : refactor + optimize penalties sampler (#10803)
* sampling : refactor + optimize penalties sampler ggml-ci * common : apply ignore_eos as logit bias ggml-ci * batched : remove penalties sampler * params : allow penalty_last_n == -1 to be equal to context size ggml-ci * common : by default, move the penalties at the end of the sampling chain ggml-ci * common : ignore all EOG tokens Co-authored-by: Diego Devesa <slarengh@gmail.com> * common : move back the penalties at the front of the sampling chain ggml-ci * readme : restore hint about --ignore-eos flag [no ci] * llama : minor ggml-ci * webui : update --------- Co-authored-by: Diego Devesa <slarengh@gmail.com>
This commit is contained in:
parent
4ddd199f6f
commit
644fd71b44
17 changed files with 111 additions and 152 deletions
|
@ -161,32 +161,20 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||
params.logit_bias.size(),
|
||||
params.logit_bias.data()));
|
||||
|
||||
llama_sampler_chain_add(result->chain,
|
||||
llama_sampler_init_penalties(
|
||||
llama_n_vocab (model),
|
||||
llama_token_eos(model),
|
||||
llama_token_nl (model),
|
||||
params.penalty_last_n,
|
||||
params.penalty_repeat,
|
||||
params.penalty_freq,
|
||||
params.penalty_present,
|
||||
params.penalize_nl,
|
||||
params.ignore_eos));
|
||||
|
||||
if (params.mirostat == 0) {
|
||||
for (const auto & cnstr : params.samplers) {
|
||||
switch (cnstr) {
|
||||
case COMMON_SAMPLER_TYPE_DRY:
|
||||
case COMMON_SAMPLER_TYPE_DRY:
|
||||
{
|
||||
std::vector<const char*> c_breakers;
|
||||
std::vector<const char *> c_breakers;
|
||||
c_breakers.reserve(params.dry_sequence_breakers.size());
|
||||
for (const auto& str : params.dry_sequence_breakers) {
|
||||
for (const auto & str : params.dry_sequence_breakers) {
|
||||
c_breakers.push_back(str.c_str());
|
||||
}
|
||||
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
||||
}
|
||||
break;
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
|
||||
break;
|
||||
|
@ -208,6 +196,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||
case COMMON_SAMPLER_TYPE_INFILL:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_PENALTIES:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "unknown sampler type");
|
||||
}
|
||||
|
@ -415,6 +406,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
|
|||
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
|
||||
case COMMON_SAMPLER_TYPE_XTC: return 'x';
|
||||
case COMMON_SAMPLER_TYPE_INFILL: return 'i';
|
||||
case COMMON_SAMPLER_TYPE_PENALTIES: return 'e';
|
||||
default : return '?';
|
||||
}
|
||||
}
|
||||
|
@ -429,6 +421,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
|
|||
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
|
||||
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
|
||||
case COMMON_SAMPLER_TYPE_INFILL: return "infill";
|
||||
case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties";
|
||||
default : return "";
|
||||
}
|
||||
}
|
||||
|
@ -443,6 +436,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
|
|||
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
|
||||
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
|
||||
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
|
||||
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
|
||||
};
|
||||
|
||||
// since samplers names are written multiple ways
|
||||
|
@ -489,6 +483,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
|
|||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES },
|
||||
};
|
||||
|
||||
std::vector<common_sampler_type> samplers;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue