sampling : rename penalty params + reduce size of "prev" vector
This commit is contained in:
parent
84ed48b473
commit
b526561583
7 changed files with 86 additions and 79 deletions
|
@ -80,10 +80,10 @@ llama_token llama_sampling_sample(
|
|||
const float top_p = params.top_p;
|
||||
const float tfs_z = params.tfs_z;
|
||||
const float typical_p = params.typical_p;
|
||||
const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_prev : params.repeat_last_n;
|
||||
const float repeat_penalty = params.repeat_penalty;
|
||||
const float alpha_presence = params.presence_penalty;
|
||||
const float alpha_frequency = params.frequency_penalty;
|
||||
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
|
||||
const float penalty_repeat = params.penalty_repeat;
|
||||
const float penalty_freq = params.penalty_freq;
|
||||
const float penalty_present = params.penalty_present;
|
||||
const int mirostat = params.mirostat;
|
||||
const float mirostat_tau = params.mirostat_tau;
|
||||
const float mirostat_eta = params.mirostat_eta;
|
||||
|
@ -118,8 +118,8 @@ llama_token llama_sampling_sample(
|
|||
const float nl_logit = logits[llama_token_nl(ctx_main)];
|
||||
|
||||
llama_sample_repetition_penalties(ctx_main, &cur_p,
|
||||
prev.data() + prev.size() - repeat_last_n,
|
||||
repeat_last_n, repeat_penalty, alpha_frequency, alpha_presence);
|
||||
prev.data() + prev.size() - penalty_last_n,
|
||||
penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
|
||||
|
||||
if (!penalize_nl) {
|
||||
for (size_t idx = 0; idx < cur_p.size; idx++) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue