From 0a1f7fb66d255edd94fb0a59da5fb95d66d418bb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 12 Dec 2024 20:39:16 +0200 Subject: [PATCH] sampling : refactor + optimize penalties sampler ggml-ci --- common/arg.cpp | 7 -- common/common.h | 2 +- common/sampling.cpp | 19 ++- examples/batched/batched.cpp | 7 ++ examples/main/README.md | 7 +- examples/server/README.md | 5 - examples/server/public_legacy/index-new.html | 1 - examples/server/public_legacy/index.html | 2 - examples/server/server.cpp | 2 - examples/server/themes/buttons-top/index.html | 2 - examples/server/themes/wild/index.html | 2 - include/llama.h | 14 +-- src/llama-sampling.cpp | 115 ++++-------------- tests/test-sampling.cpp | 2 +- 14 files changed, 47 insertions(+), 140 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 39bc874c8..faac1db99 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -855,13 +855,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.sampling.ignore_eos = true; } ).set_sparam()); - add_opt(common_arg( - {"--penalize-nl"}, - string_format("penalize newline tokens (default: %s)", params.sampling.penalize_nl ? "true" : "false"), - [](common_params & params) { - params.sampling.penalize_nl = true; - } - ).set_sparam()); add_opt(common_arg( {"--temp"}, "N", string_format("temperature (default: %.1f)", (double)params.sampling.temp), diff --git a/common/common.h b/common/common.h index 9e47b70a4..8d1bbeae1 100644 --- a/common/common.h +++ b/common/common.h @@ -95,6 +95,7 @@ enum common_sampler_type { COMMON_SAMPLER_TYPE_TEMPERATURE = 7, COMMON_SAMPLER_TYPE_XTC = 8, COMMON_SAMPLER_TYPE_INFILL = 9, + COMMON_SAMPLER_TYPE_PENALTIES = 10, }; // dimensionality reduction methods, used by cvector-generator @@ -130,7 +131,6 @@ struct common_params_sampling { int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 float mirostat_tau = 5.00f; // target entropy float mirostat_eta = 0.10f; // learning rate - bool penalize_nl = false; // consider newlines as a repeatable token bool ignore_eos = false; bool no_perf = false; // disable performance metrics bool timing_per_token = false; diff --git a/common/sampling.cpp b/common/sampling.cpp index 0c4699a89..6ba57cc20 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -161,18 +161,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co params.logit_bias.size(), params.logit_bias.data())); - llama_sampler_chain_add(result->chain, - llama_sampler_init_penalties( - llama_n_vocab (model), - llama_token_eos(model), - llama_token_nl (model), - params.penalty_last_n, - params.penalty_repeat, - params.penalty_freq, - params.penalty_present, - params.penalize_nl, - params.ignore_eos)); - if (params.mirostat == 0) { for (const auto & cnstr : params.samplers) { switch (cnstr) { @@ -208,6 +196,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co case COMMON_SAMPLER_TYPE_INFILL: llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model)); break; + case COMMON_SAMPLER_TYPE_PENALTIES: + llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); + break; default: GGML_ASSERT(false && "unknown sampler type"); } @@ -415,6 +406,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) { case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't'; case COMMON_SAMPLER_TYPE_XTC: return 'x'; case COMMON_SAMPLER_TYPE_INFILL: return 'i'; + 
case COMMON_SAMPLER_TYPE_PENALTIES: return 'e'; default : return '?'; } } @@ -429,6 +421,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) { case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature"; case COMMON_SAMPLER_TYPE_XTC: return "xtc"; case COMMON_SAMPLER_TYPE_INFILL: return "infill"; + case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties"; default : return ""; } } @@ -443,6 +436,7 @@ std::vector common_sampler_types_from_names(const std::vect { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE }, { "xtc", COMMON_SAMPLER_TYPE_XTC }, { "infill", COMMON_SAMPLER_TYPE_INFILL }, + { "penalties", COMMON_SAMPLER_TYPE_PENALTIES }, }; // since samplers names are written multiple ways @@ -489,6 +483,7 @@ std::vector common_sampler_types_from_chars(const std::stri { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL }, + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES }, }; std::vector samplers; diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index ba219cd4b..a73740ea0 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -65,10 +65,17 @@ int main(int argc, char ** argv) { llama_context * ctx = llama_new_context_with_model(model, ctx_params); auto sparams = llama_sampler_chain_default_params(); + sparams.no_perf = false; llama_sampler * smpl = llama_sampler_chain_init(sparams); llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k)); + llama_sampler_chain_add(smpl, + llama_sampler_init_penalties( + params.sampling.penalty_last_n, + params.sampling.penalty_repeat, + params.sampling.penalty_freq, + params.sampling.penalty_present)); llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep)); llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp)); llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed)); diff --git a/examples/main/README.md b/examples/main/README.md index 7787f7b11..ae3b13b7b 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -163,7 +163,7 @@ If the pause is undesirable, a value of -2 will stop generation immediately when The `--no-context-shift` option allows you to stop the infinite text generation once the finite context window is full. -It is important to note that the generated text may be shorter than the specified number of tokens if an End-of-Sequence (EOS) token or a reverse prompt is encountered. In interactive mode, text generation will pause and control will be returned to the user. In non-interactive mode, the program will end. In both cases, the text generation may stop before reaching the specified `--predict` value. If you want the model to keep going without ever producing End-of-Sequence on its own, you can use the `--ignore-eos` parameter. +It is important to note that the generated text may be shorter than the specified number of tokens if an End-of-Sequence (EOS) token or a reverse prompt is encountered. In interactive mode, text generation will pause and control will be returned to the user. In non-interactive mode, the program will end. In both cases, the text generation may stop before reaching the specified `--predict` value. 
### Temperature @@ -177,16 +177,11 @@ Example usage: `--temp 0` - `--repeat-penalty N`: Control the repetition of token sequences in the generated text default: 1.0, 1.0 = disabled). - `--repeat-last-n N`: Last n tokens to consider for penalizing repetition (default: 64, 0 = disabled, -1 = ctx-size). -- `--no-penalize-nl`: Disable penalization for newline tokens when applying the repeat penalty. The `repeat-penalty` option helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. The default value is 1. The `repeat-last-n` option controls the number of tokens in the history to consider for penalizing repetition. A larger value will look further back in the generated text to prevent repetitions, while a smaller value will only consider recent tokens. A value of 0 disables the penalty, and a value of -1 sets the number of tokens considered equal to the context size (`ctx-size`). -Use the `--no-penalize-nl` option to disable newline penalization when applying the repeat penalty. This option is particularly useful for generating chat conversations, dialogues, code, poetry, or any text where newline tokens play a significant role in structure and formatting. Disabling newline penalization helps maintain the natural flow and intended formatting in these specific use cases. - -Example usage: `--repeat-penalty 1.15 --repeat-last-n 128 --no-penalize-nl` - ### DRY Repetition Penalty DRY (Don't Repeat Yourself) sampling is an effective technique for reducing repetition in generated text even across long contexts by penalizing tokens based on their recent usage patterns (original [PR link](https://github.com/oobabooga/text-generation-webui/pull/5677)). diff --git a/examples/server/README.md b/examples/server/README.md index a9443e56b..63a7bf43a 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -104,7 +104,6 @@ The project is under active development, and we are [looking for feedback and co | `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) | | `--sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: dkypmxt) | | `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) | -| `--penalize-nl` | penalize newline tokens (default: false) | | `--temp N` | temperature (default: 0.8) | | `--top-k N` | top-k sampling (default: 40, 0 = disabled) | | `--top-p N` | top-p sampling (default: 0.9, 1.0 = disabled) | @@ -393,8 +392,6 @@ These words will not be included in the completion, so make sure to add them to `repeat_last_n`: Last n tokens to consider for penalizing repetition. Default: `64`, where `0` is disabled and `-1` is ctx-size. -`penalize_nl`: Penalize newline tokens when applying the repeat penalty. Default: `true` - `presence_penalty`: Repeat alpha presence penalty. Default: `0.0`, which is disabled. `frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled. @@ -655,7 +652,6 @@ This endpoint is public (no API key check). By default, it is read-only. 
To make "mirostat": 0, "mirostat_tau": 5.0, "mirostat_eta": 0.10000000149011612, - "penalize_nl": false, "stop": [], "max_tokens": -1, "n_keep": 0, @@ -845,7 +841,6 @@ Example: "mirostat": 0, "mirostat_tau": 5.0, "mirostat_eta": 0.10000000149011612, - "penalize_nl": false, "stop": [], "max_tokens": -1, "n_keep": 0, diff --git a/examples/server/public_legacy/index-new.html b/examples/server/public_legacy/index-new.html index 8bfa380e5..cbfbbdf28 100644 --- a/examples/server/public_legacy/index-new.html +++ b/examples/server/public_legacy/index-new.html @@ -39,7 +39,6 @@ temperature: 0.8, // adapt all following parameters to optimized min-p requierements. If for non-english, set to 0.6 or lower repeat_last_n: 0, // 0 = disable penalty, -1 = context size repeat_penalty: 1.0, // 1.0 = disabled - penalize_nl: false, // true only useful for infinite completion dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well dry_base: 1.75, // 0.0 = disabled dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well diff --git a/examples/server/public_legacy/index.html b/examples/server/public_legacy/index.html index a95f5c6df..75f39330a 100644 --- a/examples/server/public_legacy/index.html +++ b/examples/server/public_legacy/index.html @@ -303,7 +303,6 @@ temperature: 0.7, repeat_last_n: 256, // 0 = disable penalty, -1 = context size repeat_penalty: 1.18, // 1.0 = disabled - penalize_nl: false, dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well dry_base: 1.75, // 0.0 = disabled dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well @@ -1006,7 +1005,6 @@ ${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })} ${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })} ${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })} - ${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })} ${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })} ${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })} ${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })} diff --git a/examples/server/server.cpp b/examples/server/server.cpp index aa8c31259..879b041b0 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -135,7 +135,6 @@ struct slot_params { {"mirostat", sampling.mirostat}, {"mirostat_tau", sampling.mirostat_tau}, {"mirostat_eta", sampling.mirostat_eta}, - {"penalize_nl", sampling.penalize_nl}, {"stop", antiprompt}, {"max_tokens", n_predict}, // User configured n_predict {"n_keep", n_keep}, @@ -226,7 +225,6 @@ struct server_task { params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); - params.sampling.penalize_nl = json_value(data, "penalize_nl", defaults.sampling.penalize_nl); params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); 
params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); diff --git a/examples/server/themes/buttons-top/index.html b/examples/server/themes/buttons-top/index.html index 2797c37c9..3fb88fcc8 100644 --- a/examples/server/themes/buttons-top/index.html +++ b/examples/server/themes/buttons-top/index.html @@ -222,7 +222,6 @@ temperature: 0.7, repeat_last_n: 256, // 0 = disable penalty, -1 = context size repeat_penalty: 1.18, // 1.0 = disabled - penalize_nl: false, top_k: 40, // <= 0 to use vocab size top_p: 0.95, // 1.0 = disabled min_p: 0.05, // 0 = disabled @@ -779,7 +778,6 @@ ${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })} ${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })} ${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })} - ${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })} ${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })} ${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })} ${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })} diff --git a/examples/server/themes/wild/index.html b/examples/server/themes/wild/index.html index dbe23c402..73f36d4b2 100644 --- a/examples/server/themes/wild/index.html +++ b/examples/server/themes/wild/index.html @@ -225,7 +225,6 @@ temperature: 0.7, repeat_last_n: 256, // 0 = disable penalty, -1 = context size repeat_penalty: 1.18, // 1.0 = disabled - penalize_nl: false, top_k: 40, // <= 0 to use vocab size top_p: 0.95, // 1.0 = disabled min_p: 0.05, // 0 = disabled @@ -782,7 +781,6 @@ ${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })} ${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })} ${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })} - ${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })} ${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })} ${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })} ${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })} diff --git a/include/llama.h b/include/llama.h index c67988a33..efbb27d21 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1139,16 +1139,12 @@ extern "C" { const char * grammar_str, const char * grammar_root); + /// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first. 
LLAMA_API struct llama_sampler * llama_sampler_init_penalties( - int32_t n_vocab, // llama_n_vocab() - llama_token special_eos_id, // llama_token_eos() - llama_token linefeed_id, // llama_token_nl() - int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size) - float penalty_repeat, // 1.0 = disabled - float penalty_freq, // 0.0 = disabled - float penalty_present, // 0.0 = disabled - bool penalize_nl, // consider newlines as a repeatable token - bool ignore_eos); // ignore the end-of-sequence token + int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size) + float penalty_repeat, // 1.0 = disabled + float penalty_freq, // 0.0 = disabled + float penalty_present); // 0.0 = disabled /// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982 LLAMA_API struct llama_sampler * llama_sampler_init_dry( diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index fd8ca8a9e..8bbe751e0 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1396,19 +1396,15 @@ struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab // penalties struct llama_sampler_penalties { - const int32_t n_vocab; - const llama_token special_eos_id; - const llama_token linefeed_id; - const int32_t penalty_last_n; const float penalty_repeat; const float penalty_freq; const float penalty_present; - const bool penalize_nl; - const bool ignore_eos; - ring_buffer prev; + + // a frequency map to count token occurrences + std::unordered_map token_count; }; static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) { @@ -1421,76 +1417,40 @@ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_to return; } + ctx->token_count[token]++; + + // if the ring buffer is full, remove the oldest token + if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) { + const auto pop = ctx->prev.front(); + + ctx->token_count[pop]--; + if (ctx->token_count[pop] == 0) { + ctx->token_count.erase(pop); + } + } + ctx->prev.push_back(token); } static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { auto * ctx = (llama_sampler_penalties *) smpl->ctx; - if (ctx->ignore_eos) { - assert(ctx->special_eos_id >= 0); - - // optimistically check if the candidates are not yet sorted/shuffled/truncated - if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) { - cur_p->data[ctx->special_eos_id].logit = -INFINITY; - } else { - // else, search for the special EOS token - for (size_t i = 0; i < cur_p->size; ++i) { - if (cur_p->data[i].id == ctx->special_eos_id) { - cur_p->data[i].logit = -INFINITY; - break; - } - } - } - } - if ((ctx->penalty_last_n == 0) || (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) { return; } - bool nl_found = false; - size_t nl_idx = 0; - float nl_logit = -INFINITY; - if (!ctx->penalize_nl) { - assert(ctx->linefeed_id >= 0); - - // optimistically check if the candidates are not yet sorted/shuffled/truncated - if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) { - nl_found = true; - nl_idx = ctx->linefeed_id; - nl_logit = cur_p->data[ctx->linefeed_id].logit; - } else { - // else, search for the linefeed token - for (size_t i = 0; 
i < cur_p->size; ++i) { - if (cur_p->data[i].id == ctx->linefeed_id) { - nl_found = true; - nl_idx = i; - nl_logit = cur_p->data[i].logit; - break; - } - } - } - } - - // Create a frequency map to count occurrences of each token in last_tokens - // TODO: optimize this by maintaining the token count in the sampler context - using llama_token_cnt = std::unordered_map; - llama_token_cnt token_count; - - for (int i = 0; i < std::min(ctx->penalty_last_n, ctx->prev.size()); ++i) { - token_count[ctx->prev.rat(i)]++; - } - // Apply frequency and presence penalties to the cur_p for (size_t i = 0; i < cur_p->size; ++i) { - const auto token_iter = token_count.find(cur_p->data[i].id); - if (token_iter == token_count.end()) { + const auto token_iter = ctx->token_count.find(cur_p->data[i].id); + if (token_iter == ctx->token_count.end()) { continue; } const int count = token_iter->second; + assert(count > 0); + // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. // This is common fix for this problem, which is to multiply by the penalty instead of dividing. if (cur_p->data[i].logit <= 0) { @@ -1503,30 +1463,21 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok } cur_p->sorted = false; - - if (!ctx->penalize_nl && nl_found) { - // restore the logit of the newline token if it was penalized - cur_p->data[nl_idx].logit = nl_logit; - } } static void llama_sampler_penalties_reset(struct llama_sampler * smpl) { auto * ctx = (llama_sampler_penalties *) smpl->ctx; ctx->prev.clear(); + ctx->token_count.clear(); } static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_penalties *) smpl->ctx; auto * result = llama_sampler_init_penalties( - ctx->n_vocab, - ctx->special_eos_id, - ctx->linefeed_id, ctx->penalty_last_n, ctx->penalty_repeat, ctx->penalty_freq, - ctx->penalty_present, - ctx->penalize_nl, - ctx->ignore_eos); + ctx->penalty_present); // copy the state { @@ -1552,38 +1503,21 @@ static struct llama_sampler_i llama_sampler_penalties_i = { }; struct llama_sampler * llama_sampler_init_penalties( - int32_t n_vocab, - llama_token special_eos_id, - llama_token linefeed_id, int32_t penalty_last_n, float penalty_repeat, float penalty_freq, - float penalty_present, - bool penalize_nl, - bool ignore_eos) { - if (linefeed_id == LLAMA_TOKEN_NULL) { - penalize_nl = true; - } - - if (special_eos_id == LLAMA_TOKEN_NULL) { - ignore_eos = false; - } - + float penalty_present) { penalty_last_n = std::max(penalty_last_n, 0); return new llama_sampler { /* .iface = */ &llama_sampler_penalties_i, /* .ctx = */ new llama_sampler_penalties { - /* .n_vocab = */ n_vocab, - /* .special_eos_id = */ special_eos_id, - /* .linefeed_id = */ linefeed_id, /* .penalty_last_n = */ penalty_last_n, /* .penalty_repeat = */ penalty_repeat, /* .penalty_freq = */ penalty_freq, /* .penalty_present = */ penalty_present, - /* .penalize_nl = */ penalize_nl, - /* .ignore_eos = */ ignore_eos, /* .prev = */ ring_buffer(penalty_last_n), + /* .token_count = */ {}, }, }; } @@ -1611,7 +1545,8 @@ static void get_overlapping_token_sequences(const llama_vocab & vocab, const std if (word.find(str) != std::string::npos) { token_sequences.emplace(token_id, std::vector()); } else { - size_t word_len = word.size(), str_len = str.size(); + size_t word_len = word.size(); + size_t str_len = str.size(); size_t pos = -1; while ((pos 
= word.find(str[0], pos + 1)) != std::string::npos) { bool match = true; diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index e5c9e75e4..c0dcb4848 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -145,7 +145,7 @@ static void test_penalties( sampler_tester tester(probs, probs_expected); const size_t n_vocab = probs.size(); - auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false); + auto * sampler = llama_sampler_init_penalties(last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence); for (size_t i = 0; i < last_tokens.size(); i++) { llama_sampler_accept(sampler, last_tokens[i]);
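For call sites being ported to the refactored API, a minimal sketch (not part of the patch) of a sampler chain built with the new four-parameter llama_sampler_init_penalties, modeled on the examples/batched/batched.cpp hunk above. The concrete values (top-k 40, penalty_last_n 64, penalty_repeat 1.1, the 0.8 temperature, and the helper name make_sampler_chain) are illustrative assumptions, not values taken from this commit:

// Sketch: chain construction with the refactored penalties sampler.
// All parameter values below are illustrative, not defaults mandated by the patch.
#include "llama.h"

static llama_sampler * make_sampler_chain() {
    auto sparams = llama_sampler_chain_default_params();

    llama_sampler * smpl = llama_sampler_chain_init(sparams);

    // Per the NOTE added to llama.h above, narrow the candidate set (e.g. top-k)
    // before the penalties sampler so it does not scan the full vocabulary.
    llama_sampler_chain_add(smpl, llama_sampler_init_top_k(40));

    // New signature: only the penalty parameters remain; vocab/EOS/newline
    // arguments were removed by this patch.
    llama_sampler_chain_add(smpl, llama_sampler_init_penalties(
        /* penalty_last_n  */ 64,
        /* penalty_repeat  */ 1.1f,
        /* penalty_freq    */ 0.0f,
        /* penalty_present */ 0.0f));

    llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    return smpl;
}

Newline and EOS handling no longer belong to this sampler: `--ignore-eos` is handled via a logit bias on the EOS token, and newline tokens are penalized like any other token in the last-n window.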