added parameter for DRY penalty range, separate from the original repetition penalty range
This commit is contained in:
parent
75beda2a84
commit
85dadac483
2 changed files with 29 additions and 21 deletions
|
@ -278,6 +278,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
|||
const float dry_multiplier = params.dry_multiplier;
|
||||
const float dry_base = params.dry_base;
|
||||
const uint32_t dry_allowed_length = params.dry_allowed_length;
|
||||
const uint32_t dry_penalty_last_n = params.dry_penalty_last_n;
|
||||
|
||||
auto & prev = ctx_sampling->prev;
|
||||
auto & cur = ctx_sampling->cur;
|
||||
|
@ -308,35 +309,41 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
|||
|
||||
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
|
||||
|
||||
// apply penalties
|
||||
const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
|
||||
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
|
||||
if (penalty_tokens_used_size) {
|
||||
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
|
||||
|
||||
// repetition penalties
|
||||
llama_sample_repetition_penalties(ctx_main, &cur_p,
|
||||
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
|
||||
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
|
||||
// apply repetition penalties
|
||||
{
|
||||
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
|
||||
if (penalty_tokens_used_size) {
|
||||
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
|
||||
|
||||
// DRY penalties (multiplier > 0 means enabled)
|
||||
if (dry_multiplier > 0.0f) {
|
||||
llama_sample_dry(ctx_main, &cur_p,
|
||||
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
|
||||
penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
|
||||
params.dry_seq_breakers.data(), params.dry_seq_breakers.size());
|
||||
}
|
||||
// repetition penalties
|
||||
llama_sample_repetition_penalties(ctx_main, &cur_p,
|
||||
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
|
||||
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
|
||||
|
||||
if (!penalize_nl) {
|
||||
for (size_t idx = 0; idx < cur_p.size; idx++) {
|
||||
if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
|
||||
cur_p.data[idx].logit = nl_logit;
|
||||
break;
|
||||
if (!penalize_nl) {
|
||||
for (size_t idx = 0; idx < cur_p.size; idx++) {
|
||||
if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
|
||||
cur_p.data[idx].logit = nl_logit;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// apply DRY penalties
|
||||
{
|
||||
const int penalty_tokens_used_size = std::min(penalty_tokens.size(), (size_t)dry_penalty_last_n);
|
||||
if (penalty_tokens_used_size) {
|
||||
llama_sample_dry(ctx_main, &cur_p,
|
||||
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
|
||||
penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
|
||||
params.dry_seq_breakers.data(), params.dry_seq_breakers.size());
|
||||
}
|
||||
}
|
||||
|
||||
// apply grammar checks before sampling logic
|
||||
if (apply_grammar && ctx_sampling->grammar != NULL) {
|
||||
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
|
||||
|
|
|
@ -41,9 +41,10 @@ typedef struct llama_sampling_params {
|
|||
float mirostat_eta = 0.10f; // learning rate
|
||||
bool penalize_nl = false; // consider newlines as a repeatable token
|
||||
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
|
||||
float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f
|
||||
float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f
|
||||
float dry_base = 1.75f;
|
||||
uint32_t dry_allowed_length = 2;
|
||||
uint32_t dry_penalty_last_n = -1; // DRY last n tokens to penalize (0 = disable penalty, -1 = context size)
|
||||
|
||||
std::vector<llama_sampler_type> samplers_sequence = {
|
||||
llama_sampler_type::TOP_K,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue