Added a parameter for the DRY penalty range, kept separate from the original repetition-penalty range
This commit is contained in:
parent
75beda2a84
commit
85dadac483
2 changed files with 29 additions and 21 deletions
|
@ -278,6 +278,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
||||||
const float dry_multiplier = params.dry_multiplier;
|
const float dry_multiplier = params.dry_multiplier;
|
||||||
const float dry_base = params.dry_base;
|
const float dry_base = params.dry_base;
|
||||||
const uint32_t dry_allowed_length = params.dry_allowed_length;
|
const uint32_t dry_allowed_length = params.dry_allowed_length;
|
||||||
|
const uint32_t dry_penalty_last_n = params.dry_penalty_last_n;
|
||||||
|
|
||||||
auto & prev = ctx_sampling->prev;
|
auto & prev = ctx_sampling->prev;
|
||||||
auto & cur = ctx_sampling->cur;
|
auto & cur = ctx_sampling->cur;
|
||||||
|
@ -308,35 +309,41 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
||||||
|
|
||||||
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
|
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
|
||||||
|
|
||||||
// apply penalties
|
|
||||||
const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
|
const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
|
||||||
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
|
|
||||||
if (penalty_tokens_used_size) {
|
|
||||||
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
|
|
||||||
|
|
||||||
// repetition penalties
|
// apply repetition penalties
|
||||||
llama_sample_repetition_penalties(ctx_main, &cur_p,
|
{
|
||||||
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
|
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
|
||||||
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
|
if (penalty_tokens_used_size) {
|
||||||
|
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
|
||||||
|
|
||||||
// DRY penalties (multiplier > 0 means enabled)
|
// repetition penalties
|
||||||
if (dry_multiplier > 0.0f) {
|
llama_sample_repetition_penalties(ctx_main, &cur_p,
|
||||||
llama_sample_dry(ctx_main, &cur_p,
|
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
|
||||||
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
|
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
|
||||||
penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
|
|
||||||
params.dry_seq_breakers.data(), params.dry_seq_breakers.size());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!penalize_nl) {
|
if (!penalize_nl) {
|
||||||
for (size_t idx = 0; idx < cur_p.size; idx++) {
|
for (size_t idx = 0; idx < cur_p.size; idx++) {
|
||||||
if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
|
if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
|
||||||
cur_p.data[idx].logit = nl_logit;
|
cur_p.data[idx].logit = nl_logit;
|
||||||
break;
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// apply DRY penalties
|
||||||
|
{
|
||||||
|
const int penalty_tokens_used_size = std::min(penalty_tokens.size(), (size_t)dry_penalty_last_n);
|
||||||
|
if (penalty_tokens_used_size) {
|
||||||
|
llama_sample_dry(ctx_main, &cur_p,
|
||||||
|
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
|
||||||
|
penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
|
||||||
|
params.dry_seq_breakers.data(), params.dry_seq_breakers.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// apply grammar checks before sampling logic
|
// apply grammar checks before sampling logic
|
||||||
if (apply_grammar && ctx_sampling->grammar != NULL) {
|
if (apply_grammar && ctx_sampling->grammar != NULL) {
|
||||||
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
|
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
|
||||||
|
|
|
@ -41,9 +41,10 @@ typedef struct llama_sampling_params {
|
||||||
float mirostat_eta = 0.10f; // learning rate
|
float mirostat_eta = 0.10f; // learning rate
|
||||||
bool penalize_nl = false; // consider newlines as a repeatable token
|
bool penalize_nl = false; // consider newlines as a repeatable token
|
||||||
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
|
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
|
||||||
float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f
|
float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f
|
||||||
float dry_base = 1.75f;
|
float dry_base = 1.75f;
|
||||||
uint32_t dry_allowed_length = 2;
|
uint32_t dry_allowed_length = 2;
|
||||||
|
uint32_t dry_penalty_last_n = -1; // DRY last n tokens to penalize (0 = disable penalty, -1 = context size)
|
||||||
|
|
||||||
std::vector<llama_sampler_type> samplers_sequence = {
|
std::vector<llama_sampler_type> samplers_sequence = {
|
||||||
llama_sampler_type::TOP_K,
|
llama_sampler_type::TOP_K,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue