server: ensure client request cannot override n_predict if set

This commit is contained in:
Pierrick HYMBERT 2024-02-17 14:36:47 +01:00
parent cf7137e8d6
commit 8852de34be

View file

@ -158,6 +158,7 @@ struct llama_client_slot
int32_t n_decoded = 0; int32_t n_decoded = 0;
int32_t n_remaining = -1; int32_t n_remaining = -1;
int32_t i_batch = -1; int32_t i_batch = -1;
int32_t n_predict = -1;
int32_t num_prompt_tokens = 0; int32_t num_prompt_tokens = 0;
int32_t num_prompt_tokens_processed = 0; int32_t num_prompt_tokens_processed = 0;
@ -409,6 +410,7 @@ struct llama_server_context
slot.id = i; slot.id = i;
slot.n_ctx = n_ctx_slot; slot.n_ctx = n_ctx_slot;
slot.n_predict = params.n_predict;
LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, n_ctx_slot); LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, n_ctx_slot);
@ -545,6 +547,15 @@ struct llama_server_context
slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar); slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
// Enforce the slot's prediction cap: when the slot has a positive n_predict
// limit and the request asks for more tokens than that limit, log a warning
// and lower the request's n_predict to the slot's value.
// NOTE(review): a negative request value (e.g. -1) never satisfies the
// comparison below, so it passes through unchanged - confirm that this case
// is bounded elsewhere.
if (slot->n_predict > 0 && slot->params.n_predict > slot->n_predict) {
// Open question: might be better to reject the request with a 400 instead of clamping?
LOG_WARNING("Max tokens to predict exceeds server configuration", {
{"params.n_predict", slot->params.n_predict},
{"slot.n_predict", slot->n_predict},
});
// Clamp rather than fail: the request continues with the slot's limit.
slot->params.n_predict = slot->n_predict;
}
// infill // infill
if (data.count("input_prefix") != 0) if (data.count("input_prefix") != 0)
{ {
@ -1052,6 +1063,7 @@ struct llama_server_context
return json { return json {
{"n_ctx", slot.n_ctx}, {"n_ctx", slot.n_ctx},
{"n_predict", slot.n_predict},
{"model", params.model_alias}, {"model", params.model_alias},
{"seed", slot.params.seed}, {"seed", slot.params.seed},
{"temperature", slot.sparams.temp}, {"temperature", slot.sparams.temp},