sampling : rename penalty params + reduce size of "prev" vector

This commit is contained in:
Georgi Gerganov 2023-10-20 17:47:13 +03:00
parent 84ed48b473
commit b526561583
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
7 changed files with 86 additions and 79 deletions

View file

@@ -1010,30 +1010,30 @@ static json format_generation_settings(llama_server_context &llama)
eos_bias->second < 0.0f && std::isinf(eos_bias->second);
return json{
{"n_ctx", llama.n_ctx},
{"model", llama.params.model_alias},
{"seed", llama.params.seed},
{"temp", sparams.temp},
{"top_k", sparams.top_k},
{"top_p", sparams.top_p},
{"tfs_z", sparams.tfs_z},
{"typical_p", sparams.typical_p},
{"repeat_last_n", sparams.repeat_last_n},
{"repeat_penalty", sparams.repeat_penalty},
{"presence_penalty", sparams.presence_penalty},
{"frequency_penalty", sparams.frequency_penalty},
{"mirostat", sparams.mirostat},
{"mirostat_tau", sparams.mirostat_tau},
{"mirostat_eta", sparams.mirostat_eta},
{"penalize_nl", sparams.penalize_nl},
{"stop", llama.params.antiprompt},
{"n_predict", llama.params.n_predict},
{"n_keep", llama.params.n_keep},
{"ignore_eos", ignore_eos},
{"stream", llama.stream},
{"logit_bias", sparams.logit_bias},
{"n_probs", sparams.n_probs},
{"grammar", llama.params.sparams.grammar},
{"n_ctx", llama.n_ctx},
{"model", llama.params.model_alias},
{"seed", llama.params.seed},
{"temp", sparams.temp},
{"top_k", sparams.top_k},
{"top_p", sparams.top_p},
{"tfs_z", sparams.tfs_z},
{"typical_p", sparams.typical_p},
{"repeat_last_n", sparams.penalty_last_n},
{"repeat_penalty", sparams.penalty_repeat},
{"frequency_penalty", sparams.penalty_freq},
{"presence_penalty", sparams.penalty_present},
{"mirostat", sparams.mirostat},
{"mirostat_tau", sparams.mirostat_tau},
{"mirostat_eta", sparams.mirostat_eta},
{"penalize_nl", sparams.penalize_nl},
{"stop", llama.params.antiprompt},
{"n_predict", llama.params.n_predict},
{"n_keep", llama.params.n_keep},
{"ignore_eos", ignore_eos},
{"stream", llama.stream},
{"logit_bias", sparams.logit_bias},
{"n_probs", sparams.n_probs},
{"grammar", llama.params.sparams.grammar},
};
}
@@ -1134,25 +1134,25 @@ static void parse_options_completion(const json &body, llama_server_context &lla
auto & params = llama.params;
auto & sparams = llama.params.sparams;
llama.stream = json_value(body, "stream", false);
params.n_predict = json_value(body, "n_predict", default_params.n_predict);
sparams.top_k = json_value(body, "top_k", default_sparams.top_k);
sparams.top_p = json_value(body, "top_p", default_sparams.top_p);
sparams.tfs_z = json_value(body, "tfs_z", default_sparams.tfs_z);
sparams.typical_p = json_value(body, "typical_p", default_sparams.typical_p);
sparams.repeat_last_n = json_value(body, "repeat_last_n", default_sparams.repeat_last_n);
sparams.temp = json_value(body, "temperature", default_sparams.temp);
sparams.repeat_penalty = json_value(body, "repeat_penalty", default_sparams.repeat_penalty);
sparams.presence_penalty = json_value(body, "presence_penalty", default_sparams.presence_penalty);
sparams.frequency_penalty = json_value(body, "frequency_penalty", default_sparams.frequency_penalty);
sparams.mirostat = json_value(body, "mirostat", default_sparams.mirostat);
sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl);
params.n_keep = json_value(body, "n_keep", default_params.n_keep);
params.seed = json_value(body, "seed", default_params.seed);
sparams.grammar = json_value(body, "grammar", default_sparams.grammar);
sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs);
llama.stream = json_value(body, "stream", false);
params.n_predict = json_value(body, "n_predict", default_params.n_predict);
sparams.top_k = json_value(body, "top_k", default_sparams.top_k);
sparams.top_p = json_value(body, "top_p", default_sparams.top_p);
sparams.tfs_z = json_value(body, "tfs_z", default_sparams.tfs_z);
sparams.typical_p = json_value(body, "typical_p", default_sparams.typical_p);
sparams.temp = json_value(body, "temperature", default_sparams.temp);
sparams.penalty_last_n = json_value(body, "repeat_last_n", default_sparams.penalty_last_n);
sparams.penalty_repeat = json_value(body, "repeat_penalty", default_sparams.penalty_repeat);
sparams.penalty_freq = json_value(body, "frequency_penalty", default_sparams.penalty_freq);
sparams.penalty_present = json_value(body, "presence_penalty", default_sparams.penalty_present);
sparams.mirostat = json_value(body, "mirostat", default_sparams.mirostat);
sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl);
params.n_keep = json_value(body, "n_keep", default_params.n_keep);
params.seed = json_value(body, "seed", default_params.seed);
sparams.grammar = json_value(body, "grammar", default_sparams.grammar);
sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs);
if (body.count("prompt") != 0)
{