sampling : refactor + optimize penalties sampler

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-12-12 20:39:16 +02:00
parent 4ddd199f6f
commit 0a1f7fb66d
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
14 changed files with 47 additions and 140 deletions

View file

@ -855,13 +855,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.sampling.ignore_eos = true; params.sampling.ignore_eos = true;
} }
).set_sparam()); ).set_sparam());
add_opt(common_arg(
{"--penalize-nl"},
string_format("penalize newline tokens (default: %s)", params.sampling.penalize_nl ? "true" : "false"),
[](common_params & params) {
params.sampling.penalize_nl = true;
}
).set_sparam());
add_opt(common_arg( add_opt(common_arg(
{"--temp"}, "N", {"--temp"}, "N",
string_format("temperature (default: %.1f)", (double)params.sampling.temp), string_format("temperature (default: %.1f)", (double)params.sampling.temp),

View file

@ -95,6 +95,7 @@ enum common_sampler_type {
COMMON_SAMPLER_TYPE_TEMPERATURE = 7, COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
COMMON_SAMPLER_TYPE_XTC = 8, COMMON_SAMPLER_TYPE_XTC = 8,
COMMON_SAMPLER_TYPE_INFILL = 9, COMMON_SAMPLER_TYPE_INFILL = 9,
COMMON_SAMPLER_TYPE_PENALTIES = 10,
}; };
// dimensionality reduction methods, used by cvector-generator // dimensionality reduction methods, used by cvector-generator
@ -130,7 +131,6 @@ struct common_params_sampling {
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau = 5.00f; // target entropy float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate float mirostat_eta = 0.10f; // learning rate
bool penalize_nl = false; // consider newlines as a repeatable token
bool ignore_eos = false; bool ignore_eos = false;
bool no_perf = false; // disable performance metrics bool no_perf = false; // disable performance metrics
bool timing_per_token = false; bool timing_per_token = false;

View file

@ -161,18 +161,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
params.logit_bias.size(), params.logit_bias.size(),
params.logit_bias.data())); params.logit_bias.data()));
llama_sampler_chain_add(result->chain,
llama_sampler_init_penalties(
llama_n_vocab (model),
llama_token_eos(model),
llama_token_nl (model),
params.penalty_last_n,
params.penalty_repeat,
params.penalty_freq,
params.penalty_present,
params.penalize_nl,
params.ignore_eos));
if (params.mirostat == 0) { if (params.mirostat == 0) {
for (const auto & cnstr : params.samplers) { for (const auto & cnstr : params.samplers) {
switch (cnstr) { switch (cnstr) {
@ -208,6 +196,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
case COMMON_SAMPLER_TYPE_INFILL: case COMMON_SAMPLER_TYPE_INFILL:
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model)); llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
break; break;
case COMMON_SAMPLER_TYPE_PENALTIES:
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
break;
default: default:
GGML_ASSERT(false && "unknown sampler type"); GGML_ASSERT(false && "unknown sampler type");
} }
@ -415,6 +406,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't'; case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
case COMMON_SAMPLER_TYPE_XTC: return 'x'; case COMMON_SAMPLER_TYPE_XTC: return 'x';
case COMMON_SAMPLER_TYPE_INFILL: return 'i'; case COMMON_SAMPLER_TYPE_INFILL: return 'i';
case COMMON_SAMPLER_TYPE_PENALTIES: return 'e';
default : return '?'; default : return '?';
} }
} }
@ -429,6 +421,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature"; case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
case COMMON_SAMPLER_TYPE_XTC: return "xtc"; case COMMON_SAMPLER_TYPE_XTC: return "xtc";
case COMMON_SAMPLER_TYPE_INFILL: return "infill"; case COMMON_SAMPLER_TYPE_INFILL: return "infill";
case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties";
default : return ""; default : return "";
} }
} }
@ -443,6 +436,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE }, { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
{ "xtc", COMMON_SAMPLER_TYPE_XTC }, { "xtc", COMMON_SAMPLER_TYPE_XTC },
{ "infill", COMMON_SAMPLER_TYPE_INFILL }, { "infill", COMMON_SAMPLER_TYPE_INFILL },
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
}; };
// since samplers names are written multiple ways // since samplers names are written multiple ways
@ -489,6 +483,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES },
}; };
std::vector<common_sampler_type> samplers; std::vector<common_sampler_type> samplers;

View file

@ -65,10 +65,17 @@ int main(int argc, char ** argv) {
llama_context * ctx = llama_new_context_with_model(model, ctx_params); llama_context * ctx = llama_new_context_with_model(model, ctx_params);
auto sparams = llama_sampler_chain_default_params(); auto sparams = llama_sampler_chain_default_params();
sparams.no_perf = false;
llama_sampler * smpl = llama_sampler_chain_init(sparams); llama_sampler * smpl = llama_sampler_chain_init(sparams);
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k)); llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
llama_sampler_chain_add(smpl,
llama_sampler_init_penalties(
params.sampling.penalty_last_n,
params.sampling.penalty_repeat,
params.sampling.penalty_freq,
params.sampling.penalty_present));
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep)); llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep));
llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp)); llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed)); llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));

View file

@ -163,7 +163,7 @@ If the pause is undesirable, a value of -2 will stop generation immediately when
The `--no-context-shift` option allows you to stop the infinite text generation once the finite context window is full. The `--no-context-shift` option allows you to stop the infinite text generation once the finite context window is full.
It is important to note that the generated text may be shorter than the specified number of tokens if an End-of-Sequence (EOS) token or a reverse prompt is encountered. In interactive mode, text generation will pause and control will be returned to the user. In non-interactive mode, the program will end. In both cases, the text generation may stop before reaching the specified `--predict` value. If you want the model to keep going without ever producing End-of-Sequence on its own, you can use the `--ignore-eos` parameter. It is important to note that the generated text may be shorter than the specified number of tokens if an End-of-Sequence (EOS) token or a reverse prompt is encountered. In interactive mode, text generation will pause and control will be returned to the user. In non-interactive mode, the program will end. In both cases, the text generation may stop before reaching the specified `--predict` value.
### Temperature ### Temperature
@ -177,16 +177,11 @@ Example usage: `--temp 0`
- `--repeat-penalty N`: Control the repetition of token sequences in the generated text default: 1.0, 1.0 = disabled). - `--repeat-penalty N`: Control the repetition of token sequences in the generated text default: 1.0, 1.0 = disabled).
- `--repeat-last-n N`: Last n tokens to consider for penalizing repetition (default: 64, 0 = disabled, -1 = ctx-size). - `--repeat-last-n N`: Last n tokens to consider for penalizing repetition (default: 64, 0 = disabled, -1 = ctx-size).
- `--no-penalize-nl`: Disable penalization for newline tokens when applying the repeat penalty.
The `repeat-penalty` option helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. The default value is 1. The `repeat-penalty` option helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. The default value is 1.
The `repeat-last-n` option controls the number of tokens in the history to consider for penalizing repetition. A larger value will look further back in the generated text to prevent repetitions, while a smaller value will only consider recent tokens. A value of 0 disables the penalty, and a value of -1 sets the number of tokens considered equal to the context size (`ctx-size`). The `repeat-last-n` option controls the number of tokens in the history to consider for penalizing repetition. A larger value will look further back in the generated text to prevent repetitions, while a smaller value will only consider recent tokens. A value of 0 disables the penalty, and a value of -1 sets the number of tokens considered equal to the context size (`ctx-size`).
Use the `--no-penalize-nl` option to disable newline penalization when applying the repeat penalty. This option is particularly useful for generating chat conversations, dialogues, code, poetry, or any text where newline tokens play a significant role in structure and formatting. Disabling newline penalization helps maintain the natural flow and intended formatting in these specific use cases.
Example usage: `--repeat-penalty 1.15 --repeat-last-n 128 --no-penalize-nl`
### DRY Repetition Penalty ### DRY Repetition Penalty
DRY (Don't Repeat Yourself) sampling is an effective technique for reducing repetition in generated text even across long contexts by penalizing tokens based on their recent usage patterns (original [PR link](https://github.com/oobabooga/text-generation-webui/pull/5677)). DRY (Don't Repeat Yourself) sampling is an effective technique for reducing repetition in generated text even across long contexts by penalizing tokens based on their recent usage patterns (original [PR link](https://github.com/oobabooga/text-generation-webui/pull/5677)).

View file

@ -104,7 +104,6 @@ The project is under active development, and we are [looking for feedback and co
| `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) | | `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) |
| `--sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: dkypmxt) | | `--sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: dkypmxt) |
| `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) | | `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) |
| `--penalize-nl` | penalize newline tokens (default: false) |
| `--temp N` | temperature (default: 0.8) | | `--temp N` | temperature (default: 0.8) |
| `--top-k N` | top-k sampling (default: 40, 0 = disabled) | | `--top-k N` | top-k sampling (default: 40, 0 = disabled) |
| `--top-p N` | top-p sampling (default: 0.9, 1.0 = disabled) | | `--top-p N` | top-p sampling (default: 0.9, 1.0 = disabled) |
@ -393,8 +392,6 @@ These words will not be included in the completion, so make sure to add them to
`repeat_last_n`: Last n tokens to consider for penalizing repetition. Default: `64`, where `0` is disabled and `-1` is ctx-size. `repeat_last_n`: Last n tokens to consider for penalizing repetition. Default: `64`, where `0` is disabled and `-1` is ctx-size.
`penalize_nl`: Penalize newline tokens when applying the repeat penalty. Default: `true`
`presence_penalty`: Repeat alpha presence penalty. Default: `0.0`, which is disabled. `presence_penalty`: Repeat alpha presence penalty. Default: `0.0`, which is disabled.
`frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled. `frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled.
@ -655,7 +652,6 @@ This endpoint is public (no API key check). By default, it is read-only. To make
"mirostat": 0, "mirostat": 0,
"mirostat_tau": 5.0, "mirostat_tau": 5.0,
"mirostat_eta": 0.10000000149011612, "mirostat_eta": 0.10000000149011612,
"penalize_nl": false,
"stop": [], "stop": [],
"max_tokens": -1, "max_tokens": -1,
"n_keep": 0, "n_keep": 0,
@ -845,7 +841,6 @@ Example:
"mirostat": 0, "mirostat": 0,
"mirostat_tau": 5.0, "mirostat_tau": 5.0,
"mirostat_eta": 0.10000000149011612, "mirostat_eta": 0.10000000149011612,
"penalize_nl": false,
"stop": [], "stop": [],
"max_tokens": -1, "max_tokens": -1,
"n_keep": 0, "n_keep": 0,

View file

@ -39,7 +39,6 @@
temperature: 0.8, // adapt all following parameters to optimized min-p requierements. If for non-english, set to 0.6 or lower temperature: 0.8, // adapt all following parameters to optimized min-p requierements. If for non-english, set to 0.6 or lower
repeat_last_n: 0, // 0 = disable penalty, -1 = context size repeat_last_n: 0, // 0 = disable penalty, -1 = context size
repeat_penalty: 1.0, // 1.0 = disabled repeat_penalty: 1.0, // 1.0 = disabled
penalize_nl: false, // true only useful for infinite completion
dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well
dry_base: 1.75, // 0.0 = disabled dry_base: 1.75, // 0.0 = disabled
dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well

View file

@ -303,7 +303,6 @@
temperature: 0.7, temperature: 0.7,
repeat_last_n: 256, // 0 = disable penalty, -1 = context size repeat_last_n: 256, // 0 = disable penalty, -1 = context size
repeat_penalty: 1.18, // 1.0 = disabled repeat_penalty: 1.18, // 1.0 = disabled
penalize_nl: false,
dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well
dry_base: 1.75, // 0.0 = disabled dry_base: 1.75, // 0.0 = disabled
dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well
@ -1006,7 +1005,6 @@
${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })} ${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })}
${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })} ${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })} ${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })}
${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })} ${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })} ${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })} ${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}

View file

@ -135,7 +135,6 @@ struct slot_params {
{"mirostat", sampling.mirostat}, {"mirostat", sampling.mirostat},
{"mirostat_tau", sampling.mirostat_tau}, {"mirostat_tau", sampling.mirostat_tau},
{"mirostat_eta", sampling.mirostat_eta}, {"mirostat_eta", sampling.mirostat_eta},
{"penalize_nl", sampling.penalize_nl},
{"stop", antiprompt}, {"stop", antiprompt},
{"max_tokens", n_predict}, // User configured n_predict {"max_tokens", n_predict}, // User configured n_predict
{"n_keep", n_keep}, {"n_keep", n_keep},
@ -226,7 +225,6 @@ struct server_task {
params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
params.sampling.penalize_nl = json_value(data, "penalize_nl", defaults.sampling.penalize_nl);
params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);

View file

@ -222,7 +222,6 @@
temperature: 0.7, temperature: 0.7,
repeat_last_n: 256, // 0 = disable penalty, -1 = context size repeat_last_n: 256, // 0 = disable penalty, -1 = context size
repeat_penalty: 1.18, // 1.0 = disabled repeat_penalty: 1.18, // 1.0 = disabled
penalize_nl: false,
top_k: 40, // <= 0 to use vocab size top_k: 40, // <= 0 to use vocab size
top_p: 0.95, // 1.0 = disabled top_p: 0.95, // 1.0 = disabled
min_p: 0.05, // 0 = disabled min_p: 0.05, // 0 = disabled
@ -779,7 +778,6 @@
${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })} ${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })}
${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })} ${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })} ${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })}
${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })} ${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })} ${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })} ${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}

View file

@ -225,7 +225,6 @@
temperature: 0.7, temperature: 0.7,
repeat_last_n: 256, // 0 = disable penalty, -1 = context size repeat_last_n: 256, // 0 = disable penalty, -1 = context size
repeat_penalty: 1.18, // 1.0 = disabled repeat_penalty: 1.18, // 1.0 = disabled
penalize_nl: false,
top_k: 40, // <= 0 to use vocab size top_k: 40, // <= 0 to use vocab size
top_p: 0.95, // 1.0 = disabled top_p: 0.95, // 1.0 = disabled
min_p: 0.05, // 0 = disabled min_p: 0.05, // 0 = disabled
@ -782,7 +781,6 @@
${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })} ${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })}
${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })} ${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })} ${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })}
${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })} ${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })} ${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })} ${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}

View file

@ -1139,16 +1139,12 @@ extern "C" {
const char * grammar_str, const char * grammar_str,
const char * grammar_root); const char * grammar_root);
/// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
LLAMA_API struct llama_sampler * llama_sampler_init_penalties( LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
int32_t n_vocab, // llama_n_vocab()
llama_token special_eos_id, // llama_token_eos()
llama_token linefeed_id, // llama_token_nl()
int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size) int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat, // 1.0 = disabled float penalty_repeat, // 1.0 = disabled
float penalty_freq, // 0.0 = disabled float penalty_freq, // 0.0 = disabled
float penalty_present, // 0.0 = disabled float penalty_present); // 0.0 = disabled
bool penalize_nl, // consider newlines as a repeatable token
bool ignore_eos); // ignore the end-of-sequence token
/// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982 /// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
LLAMA_API struct llama_sampler * llama_sampler_init_dry( LLAMA_API struct llama_sampler * llama_sampler_init_dry(

View file

@ -1396,19 +1396,15 @@ struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab
// penalties // penalties
struct llama_sampler_penalties { struct llama_sampler_penalties {
const int32_t n_vocab;
const llama_token special_eos_id;
const llama_token linefeed_id;
const int32_t penalty_last_n; const int32_t penalty_last_n;
const float penalty_repeat; const float penalty_repeat;
const float penalty_freq; const float penalty_freq;
const float penalty_present; const float penalty_present;
const bool penalize_nl;
const bool ignore_eos;
ring_buffer<llama_token> prev; ring_buffer<llama_token> prev;
// a frequency map to count token occurrences
std::unordered_map<llama_token, int> token_count;
}; };
static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) { static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
@ -1421,76 +1417,40 @@ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_to
return; return;
} }
ctx->token_count[token]++;
// if the ring buffer is full, remove the oldest token
if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
const auto pop = ctx->prev.front();
ctx->token_count[pop]--;
if (ctx->token_count[pop] == 0) {
ctx->token_count.erase(pop);
}
}
ctx->prev.push_back(token); ctx->prev.push_back(token);
} }
static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_penalties *) smpl->ctx; auto * ctx = (llama_sampler_penalties *) smpl->ctx;
if (ctx->ignore_eos) {
assert(ctx->special_eos_id >= 0);
// optimistically check if the candidates are not yet sorted/shuffled/truncated
if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
cur_p->data[ctx->special_eos_id].logit = -INFINITY;
} else {
// else, search for the special EOS token
for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].id == ctx->special_eos_id) {
cur_p->data[i].logit = -INFINITY;
break;
}
}
}
}
if ((ctx->penalty_last_n == 0) || if ((ctx->penalty_last_n == 0) ||
(ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) { (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
return; return;
} }
bool nl_found = false;
size_t nl_idx = 0;
float nl_logit = -INFINITY;
if (!ctx->penalize_nl) {
assert(ctx->linefeed_id >= 0);
// optimistically check if the candidates are not yet sorted/shuffled/truncated
if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
nl_found = true;
nl_idx = ctx->linefeed_id;
nl_logit = cur_p->data[ctx->linefeed_id].logit;
} else {
// else, search for the linefeed token
for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].id == ctx->linefeed_id) {
nl_found = true;
nl_idx = i;
nl_logit = cur_p->data[i].logit;
break;
}
}
}
}
// Create a frequency map to count occurrences of each token in last_tokens
// TODO: optimize this by maintaining the token count in the sampler context
using llama_token_cnt = std::unordered_map<llama_token, int>;
llama_token_cnt token_count;
for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
token_count[ctx->prev.rat(i)]++;
}
// Apply frequency and presence penalties to the cur_p // Apply frequency and presence penalties to the cur_p
for (size_t i = 0; i < cur_p->size; ++i) { for (size_t i = 0; i < cur_p->size; ++i) {
const auto token_iter = token_count.find(cur_p->data[i].id); const auto token_iter = ctx->token_count.find(cur_p->data[i].id);
if (token_iter == token_count.end()) { if (token_iter == ctx->token_count.end()) {
continue; continue;
} }
const int count = token_iter->second; const int count = token_iter->second;
assert(count > 0);
// The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
// This is common fix for this problem, which is to multiply by the penalty instead of dividing. // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
if (cur_p->data[i].logit <= 0) { if (cur_p->data[i].logit <= 0) {
@ -1503,30 +1463,21 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok
} }
cur_p->sorted = false; cur_p->sorted = false;
if (!ctx->penalize_nl && nl_found) {
// restore the logit of the newline token if it was penalized
cur_p->data[nl_idx].logit = nl_logit;
}
} }
static void llama_sampler_penalties_reset(struct llama_sampler * smpl) { static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_penalties *) smpl->ctx; auto * ctx = (llama_sampler_penalties *) smpl->ctx;
ctx->prev.clear(); ctx->prev.clear();
ctx->token_count.clear();
} }
static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) { static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_penalties *) smpl->ctx; const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
auto * result = llama_sampler_init_penalties( auto * result = llama_sampler_init_penalties(
ctx->n_vocab,
ctx->special_eos_id,
ctx->linefeed_id,
ctx->penalty_last_n, ctx->penalty_last_n,
ctx->penalty_repeat, ctx->penalty_repeat,
ctx->penalty_freq, ctx->penalty_freq,
ctx->penalty_present, ctx->penalty_present);
ctx->penalize_nl,
ctx->ignore_eos);
// copy the state // copy the state
{ {
@ -1552,38 +1503,21 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
}; };
struct llama_sampler * llama_sampler_init_penalties( struct llama_sampler * llama_sampler_init_penalties(
int32_t n_vocab,
llama_token special_eos_id,
llama_token linefeed_id,
int32_t penalty_last_n, int32_t penalty_last_n,
float penalty_repeat, float penalty_repeat,
float penalty_freq, float penalty_freq,
float penalty_present, float penalty_present) {
bool penalize_nl,
bool ignore_eos) {
if (linefeed_id == LLAMA_TOKEN_NULL) {
penalize_nl = true;
}
if (special_eos_id == LLAMA_TOKEN_NULL) {
ignore_eos = false;
}
penalty_last_n = std::max(penalty_last_n, 0); penalty_last_n = std::max(penalty_last_n, 0);
return new llama_sampler { return new llama_sampler {
/* .iface = */ &llama_sampler_penalties_i, /* .iface = */ &llama_sampler_penalties_i,
/* .ctx = */ new llama_sampler_penalties { /* .ctx = */ new llama_sampler_penalties {
/* .n_vocab = */ n_vocab,
/* .special_eos_id = */ special_eos_id,
/* .linefeed_id = */ linefeed_id,
/* .penalty_last_n = */ penalty_last_n, /* .penalty_last_n = */ penalty_last_n,
/* .penalty_repeat = */ penalty_repeat, /* .penalty_repeat = */ penalty_repeat,
/* .penalty_freq = */ penalty_freq, /* .penalty_freq = */ penalty_freq,
/* .penalty_present = */ penalty_present, /* .penalty_present = */ penalty_present,
/* .penalize_nl = */ penalize_nl,
/* .ignore_eos = */ ignore_eos,
/* .prev = */ ring_buffer<llama_token>(penalty_last_n), /* .prev = */ ring_buffer<llama_token>(penalty_last_n),
/* .token_count = */ {},
}, },
}; };
} }
@ -1611,7 +1545,8 @@ static void get_overlapping_token_sequences(const llama_vocab & vocab, const std
if (word.find(str) != std::string::npos) { if (word.find(str) != std::string::npos) {
token_sequences.emplace(token_id, std::vector<llama_token>()); token_sequences.emplace(token_id, std::vector<llama_token>());
} else { } else {
size_t word_len = word.size(), str_len = str.size(); size_t word_len = word.size();
size_t str_len = str.size();
size_t pos = -1; size_t pos = -1;
while ((pos = word.find(str[0], pos + 1)) != std::string::npos) { while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
bool match = true; bool match = true;

View file

@ -145,7 +145,7 @@ static void test_penalties(
sampler_tester tester(probs, probs_expected); sampler_tester tester(probs, probs_expected);
const size_t n_vocab = probs.size(); const size_t n_vocab = probs.size();
auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false); auto * sampler = llama_sampler_init_penalties(last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
for (size_t i = 0; i < last_tokens.size(); i++) { for (size_t i = 0; i < last_tokens.size(); i++) {
llama_sampler_accept(sampler, last_tokens[i]); llama_sampler_accept(sampler, last_tokens[i]);