From 8852de34befa85b233aa4e29811a68a42be0673b Mon Sep 17 00:00:00 2001 From: Pierrick HYMBERT Date: Sat, 17 Feb 2024 14:36:47 +0100 Subject: [PATCH] server: ensure client request cannot override n_predict if set --- examples/server/server.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 5cf1044d9..7798b8af5 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -158,6 +158,7 @@ struct llama_client_slot int32_t n_decoded = 0; int32_t n_remaining = -1; int32_t i_batch = -1; + int32_t n_predict = -1; int32_t num_prompt_tokens = 0; int32_t num_prompt_tokens_processed = 0; @@ -409,6 +410,7 @@ struct llama_server_context slot.id = i; slot.n_ctx = n_ctx_slot; + slot.n_predict = params.n_predict; LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, n_ctx_slot); @@ -545,6 +547,15 @@ struct llama_server_context slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar); slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); + if (slot->n_predict > 0 && slot->params.n_predict > slot->n_predict) { + // Might be better to reject the request with an HTTP 400 instead of clamping? + LOG_WARNING("Max tokens to predict exceeds server configuration", { + {"params.n_predict", slot->params.n_predict}, + {"slot.n_predict", slot->n_predict}, + }); + slot->params.n_predict = slot->n_predict; + } + // infill if (data.count("input_prefix") != 0) { @@ -1052,6 +1063,7 @@ struct llama_server_context return json { {"n_ctx", slot.n_ctx}, + {"n_predict", slot.n_predict}, {"model", params.model_alias}, {"seed", slot.params.seed}, {"temperature", slot.sparams.temp},