diff --git a/common/sampling.cpp b/common/sampling.cpp
index f24665501..e9fdd10a4 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -267,13 +267,18 @@ static llama_token_data_array llama_sampling_prepare_impl(

     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));

+    // repetition penalties
     const int32_t penalty_last_n     = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
     const float   penalty_repeat     = params.penalty_repeat;
     const float   penalty_freq       = params.penalty_freq;
     const float   penalty_present    = params.penalty_present;
     const bool    penalize_nl        = params.penalize_nl;
+
+    // DRY sampler parameters
+    const float   dry_multiplier     = params.dry_multiplier;
+    const float   dry_base           = params.dry_base;
+    const int     dry_allowed_length = params.dry_allowed_length;

     auto & prev = ctx_sampling->prev;
     auto & cur  = ctx_sampling->cur;
@@ -309,10 +314,20 @@ static llama_token_data_array llama_sampling_prepare_impl(
     if (penalty_tokens_used_size) {
         const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];

+        // repetition penalties
         llama_sample_repetition_penalties(ctx_main, &cur_p,
                 penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
                 penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);

+        // DRY penalties (multiplier > 0 means enabled)
+        if (dry_multiplier > 0.0f) {
+            llama_sample_dry(ctx_main, &cur_p,
+                    penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
+                    penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
+                    params.dry_sequence_breakers.data(), params.dry_sequence_breakers.size());
+        }
+
         if (!penalize_nl) {
             for (size_t idx = 0; idx < cur_p.size; idx++) {
                 if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
diff --git a/common/sampling.h b/common/sampling.h
index cf7081e36..bfc338ef7 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -41,6 +41,9 @@ typedef struct llama_sampling_params {
     float       mirostat_eta       = 0.10f; // learning rate
     bool        penalize_nl        = false; // consider newlines as a repeatable token
     uint32_t    seed               = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
+    float       dry_multiplier     = 0.0f;  // 0.0f = disabled, recommended value: 0.8f
+    float       dry_base           = 1.75f;
+    int         dry_allowed_length = 2;

     std::vector<llama_sampler_type> samplers_sequence = {
         llama_sampler_type::TOP_K,
@@ -61,6 +64,7 @@ typedef struct llama_sampling_params {

     std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens

     std::vector<llama_token> penalty_prompt_tokens;
+    std::vector<llama_token> dry_sequence_breakers; // sequence breakers for the DRY sampler
     bool                     use_penalty_prompt_tokens = false;
 } llama_sampling_params;
diff --git a/llama.h b/llama.h
index 0eb2a1e9a..0c6b86c16 100644
--- a/llama.h
+++ b/llama.h
@@ -924,6 +924,18 @@ extern "C" {
                           float   p,
                          size_t   min_keep);

+    /// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
+    LLAMA_API void llama_sample_dry(
+            struct llama_context * ctx,
+          llama_token_data_array * candidates,
+               const llama_token * last_tokens,
+                             int   last_tokens_size,
+                           float   dry_base,
+                           float   dry_multiplier,
+                             int   dry_allowed_length,
+               const llama_token * seq_breakers,
+                             int   seq_breakers_size);
+
     /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
     LLAMA_API void llama_sample_tail_free(
             struct llama_context * ctx,
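
Note: the llama.h hunk above only declares `llama_sample_dry`; the body that implements it in llama.cpp is not part of this excerpt. For reviewers unfamiliar with DRY, below is a rough, self-contained sketch of the algorithm described in the linked text-generation-webui PR. It is illustrative only: the types, names, and parameters here are made up for the example and are not the code added by this patch.

```cpp
// Minimal sketch of the DRY penalty (illustrative, not the patch's implementation):
// scan the context for earlier occurrences of the last token, measure how long the
// repeated suffix ending there is, and subtract
//     dry_multiplier * dry_base^(match_len - dry_allowed_length)
// from the logit of the token that would extend that repetition. Sequence breakers
// stop a match from growing across them.

#include <cmath>
#include <cstdint>
#include <unordered_map>
#include <unordered_set>
#include <vector>

using token_t = int32_t;

struct candidate_t {
    token_t id;
    float   logit;
};

void dry_penalty_sketch(
        std::vector<candidate_t>          & candidates,   // logits for the next token
        const std::vector<token_t>        & last_tokens,  // recent context, oldest first
        float dry_base, float dry_multiplier, int dry_allowed_length,
        const std::unordered_set<token_t> & seq_breakers) {
    const int n = (int) last_tokens.size();
    if (dry_multiplier <= 0.0f || n < 2) {
        return;
    }
    const token_t last = last_tokens[n - 1];
    if (seq_breakers.count(last)) {
        return; // a breaker ends any repeated sequence
    }

    // longest repeated suffix found for each possible continuation token
    std::unordered_map<token_t, int> match_len;

    for (int i = 0; i < n - 1; ++i) {
        if (last_tokens[i] != last) {
            continue;
        }
        const token_t next = last_tokens[i + 1]; // token that extended the sequence last time
        if (seq_breakers.count(next)) {
            continue;
        }
        // walk backwards while the suffix ending at i matches the suffix ending at n-1
        int len = 1;
        while (i - len >= 0 &&
               last_tokens[i - len] == last_tokens[n - 1 - len] &&
               !seq_breakers.count(last_tokens[i - len])) {
            ++len;
        }
        auto it = match_len.find(next);
        if (it == match_len.end() || it->second < len) {
            match_len[next] = len;
        }
    }

    // apply the exponential penalty to candidates that would repeat a long sequence
    for (auto & cand : candidates) {
        auto it = match_len.find(cand.id);
        if (it != match_len.end() && it->second >= dry_allowed_length) {
            cand.logit -= dry_multiplier * std::pow(dry_base, (float)(it->second - dry_allowed_length));
        }
    }
}
```

With the defaults added to `llama_sampling_params` (`dry_base = 1.75f`, `dry_allowed_length = 2`), a repeated sequence of length 2 is penalized by exactly `dry_multiplier`, and each additional repeated token multiplies the penalty by 1.75.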