llama : simplify infill sampler
This commit is contained in:
parent
2e8c350a5f
commit
4b1bd81661
4 changed files with 22 additions and 11 deletions
|
@ -117,8 +117,6 @@ struct common_sampler_params {
|
||||||
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
|
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
|
||||||
float dynatemp_range = 0.00f; // 0.0 = disabled
|
float dynatemp_range = 0.00f; // 0.0 = disabled
|
||||||
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
|
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
|
||||||
float infill_p = 0.80f;
|
|
||||||
float infill_p_eog = 0.01f;
|
|
||||||
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
||||||
float penalty_repeat = 1.00f; // 1.0 = disabled
|
float penalty_repeat = 1.00f; // 1.0 = disabled
|
||||||
float penalty_freq = 0.00f; // 0.0 = disabled
|
float penalty_freq = 0.00f; // 0.0 = disabled
|
||||||
|
|
|
@ -93,9 +93,7 @@ function! llama#fim(is_auto) abort
|
||||||
"\ 'stop': g:llama_config.stop,
|
"\ 'stop': g:llama_config.stop,
|
||||||
\ 'n_predict': g:llama_config.n_predict,
|
\ 'n_predict': g:llama_config.n_predict,
|
||||||
\ 'penalty_last_n': 0,
|
\ 'penalty_last_n': 0,
|
||||||
\ 'top_k': 5,
|
\ 'top_k': 100,
|
||||||
\ 'infill_p': 0.20,
|
|
||||||
\ 'infill_p_eog': 0.001,
|
|
||||||
\ 'stream': v:false,
|
\ 'stream': v:false,
|
||||||
\ 'samplers': ["top_k", "infill"],
|
\ 'samplers': ["top_k", "infill"],
|
||||||
"\ 'cache_prompt': v:true,
|
"\ 'cache_prompt': v:true,
|
||||||
|
@ -180,7 +178,7 @@ function! s:fim_auto()
|
||||||
call jobstop(s:current_job)
|
call jobstop(s:current_job)
|
||||||
endif
|
endif
|
||||||
|
|
||||||
if reltimefloat(reltime(s:t_fim_last)) < 0.001*250
|
if reltimefloat(reltime(s:t_fim_last)) < 500*0.001
|
||||||
if s:timer_fim != -1
|
if s:timer_fim != -1
|
||||||
call timer_stop(s:timer_fim)
|
call timer_stop(s:timer_fim)
|
||||||
let s:timer_fim = -1
|
let s:timer_fim = -1
|
||||||
|
@ -188,7 +186,7 @@ function! s:fim_auto()
|
||||||
endif
|
endif
|
||||||
|
|
||||||
let s:t_fim_last = reltime()
|
let s:t_fim_last = reltime()
|
||||||
let s:timer_fim = timer_start(250, {-> llama#fim(v:true)})
|
let s:timer_fim = timer_start(500, {-> llama#fim(v:true)})
|
||||||
endfunction
|
endfunction
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -873,8 +873,6 @@ struct server_context {
|
||||||
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
|
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
|
||||||
slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
|
slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
|
||||||
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
|
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
|
||||||
slot.sparams.infill_p = json_value(data, "infill_p", default_sparams.infill_p);
|
|
||||||
slot.sparams.infill_p_eog = json_value(data, "infill_p_eog", default_sparams.infill_p_eog);
|
|
||||||
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
|
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
|
||||||
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
|
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
|
||||||
slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
|
slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
|
||||||
|
@ -1243,8 +1241,6 @@ struct server_context {
|
||||||
{"xtc_threshold", slot.sparams.xtc_threshold},
|
{"xtc_threshold", slot.sparams.xtc_threshold},
|
||||||
{"tfs_z", slot.sparams.tfs_z},
|
{"tfs_z", slot.sparams.tfs_z},
|
||||||
{"typical_p", slot.sparams.typ_p},
|
{"typical_p", slot.sparams.typ_p},
|
||||||
{"infill_p", slot.sparams.infill_p},
|
|
||||||
{"infill_p_eog", slot.sparams.infill_p_eog},
|
|
||||||
{"repeat_last_n", slot.sparams.penalty_last_n},
|
{"repeat_last_n", slot.sparams.penalty_last_n},
|
||||||
{"repeat_penalty", slot.sparams.penalty_repeat},
|
{"repeat_penalty", slot.sparams.penalty_repeat},
|
||||||
{"presence_penalty", slot.sparams.penalty_present},
|
{"presence_penalty", slot.sparams.penalty_present},
|
||||||
|
|
|
@ -1792,6 +1792,10 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
|
||||||
LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
|
LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
<<<<<<< HEAD
|
||||||
|
=======
|
||||||
|
float p_max = 0.0f;
|
||||||
|
>>>>>>> af919ec1 (llama : simplify infill sampler)
|
||||||
float p_txt_sum = 0.0f;
|
float p_txt_sum = 0.0f;
|
||||||
float p_eog_sum = 0.0f;
|
float p_eog_sum = 0.0f;
|
||||||
|
|
||||||
|
@ -1803,12 +1807,20 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
<<<<<<< HEAD
|
||||||
const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);
|
const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);
|
||||||
|
|
||||||
LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
|
LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
|
||||||
|
|
||||||
if (3*p_eog_sum*cur_p->size > p_txt_sum) {
|
if (3*p_eog_sum*cur_p->size > p_txt_sum) {
|
||||||
LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
|
LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
|
||||||
|
=======
|
||||||
|
const float rat = p_txt_sum / p_eog_sum;
|
||||||
|
LLAMA_LOG_DEBUG("infill: p_max = %.2f, p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", p_max, p_txt_sum, p_eog_sum, rat, cur_p->size);
|
||||||
|
|
||||||
|
if (p_max < 0.90f && p_eog_sum*cur_p->size > p_txt_sum) {
|
||||||
|
LLAMA_LOG_DEBUG("infill: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", p_txt_sum/p_eog_sum);
|
||||||
|
>>>>>>> af919ec1 (llama : simplify infill sampler)
|
||||||
|
|
||||||
// keep just the EOG tokens
|
// keep just the EOG tokens
|
||||||
const auto size_org = cur_p->size;
|
const auto size_org = cur_p->size;
|
||||||
|
@ -1879,6 +1891,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
<<<<<<< HEAD
|
||||||
size_t n_non_eog = 0;
|
size_t n_non_eog = 0;
|
||||||
|
|
||||||
size_t size_org = cur_p->size;
|
size_t size_org = cur_p->size;
|
||||||
|
@ -1895,6 +1908,12 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
|
||||||
|
|
||||||
if (cur_p->data[i].p < thold && !is_eog) {
|
if (cur_p->data[i].p < thold && !is_eog) {
|
||||||
continue;
|
continue;
|
||||||
|
=======
|
||||||
|
// mask non-EOG tokens with prob < 0.2
|
||||||
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||||
|
if (cur_p->data[i].p < 0.2 && !llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
|
||||||
|
cur_p->data[i].logit = -INFINITY;
|
||||||
|
>>>>>>> af919ec1 (llama : simplify infill sampler)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!is_eog) {
|
if (!is_eog) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue