llama : simplify infill sampler

This commit is contained in:
Georgi Gerganov 2024-10-10 20:36:25 +03:00
parent 2e8c350a5f
commit 4b1bd81661
No known key found for this signature in database
GPG key ID: BF970631944C16B7
4 changed files with 22 additions and 11 deletions

View file

@ -117,8 +117,6 @@ struct common_sampler_params {
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range = 0.00f; // 0.0 = disabled
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
float infill_p = 0.80f;
float infill_p_eog = 0.01f;
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.00f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled

View file

@ -93,9 +93,7 @@ function! llama#fim(is_auto) abort
"\ 'stop': g:llama_config.stop,
\ 'n_predict': g:llama_config.n_predict,
\ 'penalty_last_n': 0,
\ 'top_k': 5,
\ 'infill_p': 0.20,
\ 'infill_p_eog': 0.001,
\ 'top_k': 100,
\ 'stream': v:false,
\ 'samplers': ["top_k", "infill"],
"\ 'cache_prompt': v:true,
@ -180,7 +178,7 @@ function! s:fim_auto()
call jobstop(s:current_job)
endif
if reltimefloat(reltime(s:t_fim_last)) < 0.001*250
if reltimefloat(reltime(s:t_fim_last)) < 500*0.001
if s:timer_fim != -1
call timer_stop(s:timer_fim)
let s:timer_fim = -1
@ -188,7 +186,7 @@ function! s:fim_auto()
endif
let s:t_fim_last = reltime()
let s:timer_fim = timer_start(250, {-> llama#fim(v:true)})
let s:timer_fim = timer_start(500, {-> llama#fim(v:true)})
endfunction

View file

@ -873,8 +873,6 @@ struct server_context {
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
slot.sparams.infill_p = json_value(data, "infill_p", default_sparams.infill_p);
slot.sparams.infill_p_eog = json_value(data, "infill_p_eog", default_sparams.infill_p_eog);
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
@ -1243,8 +1241,6 @@ struct server_context {
{"xtc_threshold", slot.sparams.xtc_threshold},
{"tfs_z", slot.sparams.tfs_z},
{"typical_p", slot.sparams.typ_p},
{"infill_p", slot.sparams.infill_p},
{"infill_p_eog", slot.sparams.infill_p_eog},
{"repeat_last_n", slot.sparams.penalty_last_n},
{"repeat_penalty", slot.sparams.penalty_repeat},
{"presence_penalty", slot.sparams.penalty_present},

View file

@ -1792,6 +1792,10 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
}
float p_txt_sum = 0.0f;
float p_eog_sum = 0.0f;
@ -1803,12 +1807,20 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
}
}
const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);
LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
if (3*p_eog_sum*cur_p->size > p_txt_sum) {
LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
// keep just the EOG tokens
const auto size_org = cur_p->size;
@ -1879,6 +1891,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
}
}
size_t n_non_eog = 0;
size_t size_org = cur_p->size;
@ -1895,6 +1908,12 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
if (cur_p->data[i].p < thold && !is_eog) {
continue;
}
if (!is_eog) {