added parameter for DRY penalty range, separate from the original repetition penalty range

l3utterfly 2024-04-29 10:20:17 +09:00
parent 75beda2a84
commit 85dadac483
2 changed files with 29 additions and 21 deletions
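As a rough caller-side sketch (not part of this commit), the new field sits alongside the existing repetition-penalty range in llama_sampling_params, so the two windows can be sized independently. Field names come from the diff below; the values are purely illustrative:

// hypothetical configuration, assuming the llama_sampling_params fields shown in this diff
llama_sampling_params sparams;
sparams.penalty_last_n     = 64;    // classic repetition penalty scans only the last 64 tokens
sparams.dry_multiplier     = 0.8f;  // > 0.0f enables DRY
sparams.dry_base           = 1.75f;
sparams.dry_allowed_length = 2;
sparams.dry_penalty_last_n = -1;    // DRY scans as far back as possible (-1 = context size)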

common/sampling.cpp

@@ -278,6 +278,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
const float dry_multiplier = params.dry_multiplier;
const float dry_base = params.dry_base;
const uint32_t dry_allowed_length = params.dry_allowed_length;
const uint32_t dry_penalty_last_n = params.dry_penalty_last_n;
auto & prev = ctx_sampling->prev;
auto & cur = ctx_sampling->cur;
@@ -308,8 +309,10 @@ static llama_token_data_array llama_sampling_prepare_impl(
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
// apply penalties
const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
// apply repetition penalties
{
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
if (penalty_tokens_used_size) {
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
@@ -319,14 +322,6 @@ static llama_token_data_array llama_sampling_prepare_impl(
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
// DRY penalties (multiplier > 0 means enabled)
if (dry_multiplier > 0.0f) {
llama_sample_dry(ctx_main, &cur_p,
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
params.dry_seq_breakers.data(), params.dry_seq_breakers.size());
}
if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
@@ -336,6 +331,18 @@ static llama_token_data_array llama_sampling_prepare_impl(
}
}
}
}
// apply DRY penalties
{
const int penalty_tokens_used_size = std::min(penalty_tokens.size(), (size_t)dry_penalty_last_n);
if (penalty_tokens_used_size) {
llama_sample_dry(ctx_main, &cur_p,
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
params.dry_seq_breakers.data(), params.dry_seq_breakers.size());
}
}
// apply grammar checks before sampling logic
if (apply_grammar && ctx_sampling->grammar != NULL) {
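Both penalty blocks share the same tail-window arithmetic: penalty_tokens_used_size clamps the configured range to the history that actually exists, and the pointer offset hands the sampler only the last N tokens. A standalone sketch of that pattern (illustrative values, not code from the commit):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    std::vector<int32_t> penalty_tokens = {11, 12, 13, 14, 15}; // pretend token history
    const uint32_t dry_penalty_last_n = 3;                      // pretend DRY range

    // same clamp as the DRY block above
    const int used = (int) std::min(penalty_tokens.size(), (size_t) dry_penalty_last_n);
    if (used) {
        // same pointer arithmetic: points at the first of the last `used` tokens
        const int32_t * window = penalty_tokens.data() + penalty_tokens.size() - used;
        for (int i = 0; i < used; ++i) {
            printf("%d ", (int) window[i]); // prints: 13 14 15
        }
        printf("\n");
    }
    return 0;
}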

common/sampling.h

@@ -44,6 +44,7 @@ typedef struct llama_sampling_params {
float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f
float dry_base = 1.75f;
uint32_t dry_allowed_length = 2;
uint32_t dry_penalty_last_n = -1; // DRY last n tokens to penalize (0 = disable penalty, -1 = context size)
std::vector<llama_sampler_type> samplers_sequence = {
llama_sampler_type::TOP_K,
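Note on the -1 default: dry_penalty_last_n is declared uint32_t here, so assigning -1 wraps it to UINT32_MAX; the std::min against penalty_tokens.size() in the sampling code then clamps it to however much history is actually available, which is what makes "-1 = context size" behave as "use everything" in practice. A minimal check of that wrap-and-clamp (standalone sketch, assuming nothing beyond the declarations above):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const uint32_t dry_penalty_last_n = -1;   // wraps to UINT32_MAX
    std::vector<int32_t> penalty_tokens(512); // pretend 512 tokens of history

    const int used = (int) std::min(penalty_tokens.size(), (size_t) dry_penalty_last_n);
    printf("used = %d\n", used); // prints: used = 512
    return 0;
}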