From 4b1bd81661142cb8c9f768e465befbd678f64278 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 10 Oct 2024 20:36:25 +0300
Subject: [PATCH] llama : simplify infill sampler

---
 common/common.h            |  2 --
 examples/llama.vim         |  8 +++-----
 examples/server/server.cpp |  4 ----
 src/llama-sampling.cpp     | 26 +++++++++++++++-----------
 4 files changed, 18 insertions(+), 22 deletions(-)

diff --git a/common/common.h b/common/common.h
index 2fb92ae14..5ca8fd391 100644
--- a/common/common.h
+++ b/common/common.h
@@ -117,8 +117,6 @@ struct common_sampler_params {
     float temp              = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
     float dynatemp_range    = 0.00f; // 0.0 = disabled
     float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
-    float infill_p          = 0.80f;
-    float infill_p_eog      = 0.01f;
     int32_t penalty_last_n  = 64;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
     float penalty_repeat    = 1.00f; // 1.0 = disabled
     float penalty_freq      = 0.00f; // 0.0 = disabled
diff --git a/examples/llama.vim b/examples/llama.vim
index 3f747b360..c89ddea65 100644
--- a/examples/llama.vim
+++ b/examples/llama.vim
@@ -93,9 +93,7 @@ function! llama#fim(is_auto) abort
         "\ 'stop':           g:llama_config.stop,
         \ 'n_predict':      g:llama_config.n_predict,
         \ 'penalty_last_n': 0,
-        \ 'top_k':          5,
-        \ 'infill_p':       0.20,
-        \ 'infill_p_eog':   0.001,
+        \ 'top_k':          100,
         \ 'stream':         v:false,
         \ 'samplers':       ["top_k", "infill"],
         "\ 'cache_prompt':   v:true,
@@ -180,7 +178,7 @@ function! s:fim_auto()
         call jobstop(s:current_job)
     endif
 
-    if reltimefloat(reltime(s:t_fim_last)) < 0.001*250
+    if reltimefloat(reltime(s:t_fim_last)) < 500*0.001
         if s:timer_fim != -1
             call timer_stop(s:timer_fim)
             let s:timer_fim = -1
@@ -188,5 +186,5 @@ function! s:fim_auto()
     endif
 
     let s:t_fim_last = reltime()
-    let s:timer_fim = timer_start(250, {-> llama#fim(v:true)})
+    let s:timer_fim = timer_start(500, {-> llama#fim(v:true)})
 endfunction
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index e9621ba93..3992108e7 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -873,8 +873,6 @@ struct server_context {
         slot.sparams.tfs_z             = json_value(data, "tfs_z",             default_sparams.tfs_z);
         slot.sparams.typ_p             = json_value(data, "typical_p",         default_sparams.typ_p);
         slot.sparams.temp              = json_value(data, "temperature",       default_sparams.temp);
-        slot.sparams.infill_p          = json_value(data, "infill_p",          default_sparams.infill_p);
-        slot.sparams.infill_p_eog     = json_value(data, "infill_p_eog",      default_sparams.infill_p_eog);
         slot.sparams.dynatemp_range    = json_value(data, "dynatemp_range",    default_sparams.dynatemp_range);
         slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
         slot.sparams.penalty_last_n    = json_value(data, "repeat_last_n",     default_sparams.penalty_last_n);
@@ -1243,8 +1241,6 @@ struct server_context {
             {"xtc_threshold",     slot.sparams.xtc_threshold},
             {"tfs_z",             slot.sparams.tfs_z},
             {"typical_p",         slot.sparams.typ_p},
-            {"infill_p",          slot.sparams.infill_p},
-            {"infill_p_eog",      slot.sparams.infill_p_eog},
             {"repeat_last_n",     slot.sparams.penalty_last_n},
             {"repeat_penalty",    slot.sparams.penalty_repeat},
             {"presence_penalty",  slot.sparams.penalty_present},
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index d71516153..4a5b922c4 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1792,6 +1792,12 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
         LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
     }
 
+    // probability of the top candidate, used by the confidence check below
+    float p_max = 0.0f;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        p_max = fmaxf(p_max, cur_p->data[i].p);
+    }
+
     float p_txt_sum = 0.0f;
     float p_eog_sum = 0.0f;
 
@@ -1803,12 +1809,11 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
         }
     }
 
-    const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);
-
-    LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
-
-    if (3*p_eog_sum*cur_p->size > p_txt_sum) {
-        LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
+    const float rat = p_eog_sum == 0.0f ? INFINITY : p_txt_sum / p_eog_sum;
+    LLAMA_LOG_DEBUG("infill: p_max = %.2f, p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", p_max, p_txt_sum, p_eog_sum, rat, cur_p->size);
+
+    if (p_max < 0.90f && p_eog_sum*cur_p->size > p_txt_sum) {
+        LLAMA_LOG_DEBUG("infill: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", p_txt_sum/p_eog_sum);
 
         // keep just the EOG tokens
         const auto size_org = cur_p->size;
@@ -1879,6 +1884,3 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
         }
     }
 
-    size_t n_non_eog = 0;
-
-    size_t size_org = cur_p->size;
@@ -1895,6 +1897,8 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
 
-        if (cur_p->data[i].p < thold && !is_eog) {
-            continue;
+    // mask non-EOG tokens with prob < 0.2
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (cur_p->data[i].p < 0.2f && !llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+            cur_p->data[i].logit = -INFINITY;
         }
 
         if (!is_eog) {
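
Note: below is a minimal, self-contained sketch of the decision rule this patch introduces, for readers who want to try it outside llama.cpp. The `candidate` struct and `should_sample_eog` helper are hypothetical names for illustration only, not part of the llama.cpp API; the 0.90 confidence threshold and the scaling of the EOG mass by the candidate count mirror the patched condition `p_max < 0.90f && p_eog_sum*cur_p->size > p_txt_sum`.

    #include <algorithm>
    #include <vector>

    // hypothetical stand-in for one entry of llama_token_data_array
    struct candidate {
        float p;      // softmax probability of the token
        bool  is_eog; // whether the token is an end-of-generation (EOG) token
    };

    // true -> restrict sampling to EOG tokens: no single candidate is
    // confident on its own (p_max < 0.90) and the total EOG probability
    // mass, scaled by the number of candidates, outweighs the text mass
    static bool should_sample_eog(const std::vector<candidate> & cur) {
        float p_max = 0.0f, p_txt_sum = 0.0f, p_eog_sum = 0.0f;
        for (const auto & c : cur) {
            p_max = std::max(p_max, c.p);
            if (c.is_eog) {
                p_eog_sum += c.p;
            } else {
                p_txt_sum += c.p;
            }
        }
        return p_max < 0.90f && p_eog_sum*cur.size() > p_txt_sum;
    }

For example, with candidates { 0.35 text, 0.25 text, 0.25 EOG, 0.15 text }: p_max = 0.35 < 0.90 and 0.25*4 = 1.0 > 0.75, so the rule falls back to EOG even though text tokens hold most of the probability mass.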