From 85dadac483df1ba5071362f8284599fb6172047b Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Mon, 29 Apr 2024 10:20:17 +0900 Subject: [PATCH] added parameter for DRY penalty range, separate from the original repetition penalty range --- common/sampling.cpp | 47 ++++++++++++++++++++++++++------------------- common/sampling.h | 3 ++- 2 files changed, 29 insertions(+), 21 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 92cd76e1d..f400aa7fb 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -278,6 +278,7 @@ static llama_token_data_array llama_sampling_prepare_impl( const float dry_multiplier = params.dry_multiplier; const float dry_base = params.dry_base; const uint32_t dry_allowed_length = params.dry_allowed_length; + const uint32_t dry_penalty_last_n = params.dry_penalty_last_n; auto & prev = ctx_sampling->prev; auto & cur = ctx_sampling->cur; @@ -308,35 +309,41 @@ static llama_token_data_array llama_sampling_prepare_impl( llama_token_data_array cur_p = { cur.data(), cur.size(), false }; - // apply penalties const auto& penalty_tokens = params.use_penalty_prompt_tokens ? 
params.penalty_prompt_tokens : prev; - const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n); - if (penalty_tokens_used_size) { - const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))]; - // repetition penalties - llama_sample_repetition_penalties(ctx_main, &cur_p, - penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size, - penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present); + // apply repetition penalties + { + const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n); + if (penalty_tokens_used_size) { + const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))]; - // DRY penalties (multiplier > 0 means enabled) - if (dry_multiplier > 0.0f) { - llama_sample_dry(ctx_main, &cur_p, - penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size, - penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length, - params.dry_seq_breakers.data(), params.dry_seq_breakers.size()); - } + // repetition penalties + llama_sample_repetition_penalties(ctx_main, &cur_p, + penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size, + penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present); - if (!penalize_nl) { - for (size_t idx = 0; idx < cur_p.size; idx++) { - if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) { - cur_p.data[idx].logit = nl_logit; - break; + if (!penalize_nl) { + for (size_t idx = 0; idx < cur_p.size; idx++) { + if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) { + cur_p.data[idx].logit = nl_logit; + break; + } } } } } + // apply DRY penalties + { + const int penalty_tokens_used_size = std::min(penalty_tokens.size(), (size_t)dry_penalty_last_n); + if (penalty_tokens_used_size) { + llama_sample_dry(ctx_main, &cur_p, + penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size, + penalty_tokens_used_size, dry_base, 
dry_multiplier, dry_allowed_length, + params.dry_seq_breakers.data(), params.dry_seq_breakers.size()); + } + } + // apply grammar checks before sampling logic if (apply_grammar && ctx_sampling->grammar != NULL) { llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar); diff --git a/common/sampling.h b/common/sampling.h index 09df8edc3..4ad726c89 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -41,9 +41,10 @@ typedef struct llama_sampling_params { float mirostat_eta = 0.10f; // learning rate bool penalize_nl = false; // consider newlines as a repeatable token uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context - float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f + float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f float dry_base = 1.75f; uint32_t dry_allowed_length = 2; + int32_t dry_penalty_last_n = -1; // DRY last n tokens to penalize (0 = disable penalty, -1 = context size) std::vector samplers_sequence = { llama_sampler_type::TOP_K,