sampling : rename penalty params + reduce size of "prev" vector

This commit is contained in:
Georgi Gerganov 2023-10-20 17:47:13 +03:00
parent 84ed48b473
commit b526561583
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
7 changed files with 86 additions and 79 deletions

View file

@ -80,10 +80,10 @@ llama_token llama_sampling_sample(
const float top_p = params.top_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_prev : params.repeat_last_n;
const float repeat_penalty = params.repeat_penalty;
const float alpha_presence = params.presence_penalty;
const float alpha_frequency = params.frequency_penalty;
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
const float penalty_repeat = params.penalty_repeat;
const float penalty_freq = params.penalty_freq;
const float penalty_present = params.penalty_present;
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
@ -118,8 +118,8 @@ llama_token llama_sampling_sample(
const float nl_logit = logits[llama_token_nl(ctx_main)];
llama_sample_repetition_penalties(ctx_main, &cur_p,
prev.data() + prev.size() - repeat_last_n,
repeat_last_n, repeat_penalty, alpha_frequency, alpha_presence);
prev.data() + prev.size() - penalty_last_n,
penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {