From 0d934ee517a7725a97977885d740c834ae06d49f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 13 Mar 2024 16:16:20 +0200 Subject: [PATCH] server : construct batch with size of llama_n_batch --- examples/server/server.cpp | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index fe18d88ed..d83dc4be0 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -147,7 +147,7 @@ struct server_slot { int32_t n_decoded = 0; int32_t n_remaining = -1; int32_t i_batch = -1; - int32_t n_predict = -1; + int32_t n_predict = -1; // TODO: disambiguate from params.n_predict int32_t n_prompt_tokens = 0; int32_t n_prompt_tokens_processed = 0; @@ -739,7 +739,13 @@ struct server_context { default_generation_settings_for_props = get_formated_generation(slots.front()); default_generation_settings_for_props["seed"] = -1; - batch = llama_batch_init(n_ctx, 0, params.n_parallel); + // the update_slots() logic will always submit a maximum of n_batch tokens + // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) + { + const int32_t n_batch = llama_n_batch(ctx); + + batch = llama_batch_init(n_batch, 0, params.n_parallel); + } metrics.init(); } @@ -1036,8 +1042,10 @@ struct server_context { llama_batch_add(batch, system_tokens[i], i, { 0 }, false); } - for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += params.n_batch) { - const int32_t n_tokens = std::min(params.n_batch, (int32_t) (batch.n_tokens - i)); + const int32_t n_batch = llama_n_batch(ctx); + + for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { + const int32_t n_tokens = std::min(params.n_batch, batch.n_tokens - i); llama_batch batch_view = { n_tokens, batch.token + i, @@ -1226,7 +1234,7 @@ struct server_context { {"mirostat_eta", slot.sparams.mirostat_eta}, {"penalize_nl", slot.sparams.penalize_nl}, {"stop", slot.params.antiprompt}, - {"n_predict", slot.params.n_predict}, + {"n_predict", slot.params.n_predict}, // TODO: fix duplicate key n_predict {"n_keep", params.n_keep}, {"ignore_eos", ignore_eos}, {"stream", slot.params.stream},