server : construct batch with size of llama_n_batch

Georgi Gerganov 2024-03-13 16:16:20 +02:00
parent 015e1bfe64
commit 0d934ee517

examples/server/server.cpp

@@ -147,7 +147,7 @@ struct server_slot {
     int32_t n_decoded   = 0;
     int32_t n_remaining = -1;
     int32_t i_batch     = -1;
-    int32_t n_predict   = -1;
+    int32_t n_predict   = -1; // TODO: disambiguate from params.n_predict
 
     int32_t n_prompt_tokens           = 0;
     int32_t n_prompt_tokens_processed = 0;
@@ -739,7 +739,13 @@ struct server_context {
         default_generation_settings_for_props = get_formated_generation(slots.front());
         default_generation_settings_for_props["seed"] = -1;
 
-        batch = llama_batch_init(n_ctx, 0, params.n_parallel);
+        // the update_slots() logic will always submit a maximum of n_batch tokens
+        // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
+        {
+            const int32_t n_batch = llama_n_batch(ctx);
+
+            batch = llama_batch_init(n_batch, 0, params.n_parallel);
+        }
 
         metrics.init();
     }
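Note: in isolation, the allocation pattern introduced above looks roughly as follows. This is a minimal sketch, not the server code itself: it assumes an already-created llama_context * ctx, and n_parallel stands in for params.n_parallel. All calls used (llama_n_batch, llama_batch_init, llama_batch_free) are public llama.h API.

    #include "llama.h"

    // size the batch by the context's configured batch size rather than by
    // n_ctx: update_slots() never submits more than n_batch tokens per
    // llama_decode() call, and for non-causal models (e.g. BERT) n_batch
    // may legitimately exceed n_ctx
    const int32_t n_batch = llama_n_batch(ctx);

    llama_batch batch = llama_batch_init(
            n_batch,     // token capacity of the batch
            0,           // embd = 0: the batch carries token ids, not embeddings
            n_parallel); // max number of sequence ids a single token can carry

    // ... fill the batch and decode ...

    llama_batch_free(batch);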
@@ -1036,8 +1042,10 @@ struct server_context {
                 llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
             }
 
-            for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += params.n_batch) {
-                const int32_t n_tokens = std::min(params.n_batch, (int32_t) (batch.n_tokens - i));
+            const int32_t n_batch = llama_n_batch(ctx);
+
+            for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
+                const int32_t n_tokens = std::min(params.n_batch, batch.n_tokens - i);
                 llama_batch batch_view = {
                     n_tokens,
                     batch.token + i,
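Note: the loop above is the chunking pattern that makes the n_batch-sized allocation safe: whatever is in batch gets submitted to llama_decode() as successive views of at most n_batch tokens. A self-contained sketch of the same pattern (the batch_view field order assumes the llama_batch struct layout as of this commit; error handling is simplified; requires <algorithm> for std::min):

    const int32_t n_batch = llama_n_batch(ctx);

    for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

        // non-owning view into the i-th chunk of the full batch
        llama_batch batch_view = {
            n_tokens,
            batch.token    + i,
            nullptr,             // embd: unused, tokens are given by id
            batch.pos      + i,
            batch.n_seq_id + i,
            batch.seq_id   + i,
            batch.logits   + i,
            0, 0, 0,             // all_pos_0, all_pos_1, all_seq_id (legacy fields)
        };

        if (llama_decode(ctx, batch_view) != 0) {
            // decode failed (e.g. no KV cache slot found) -- abort or retry
            break;
        }
    }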
@@ -1226,7 +1234,7 @@ struct server_context {
             {"mirostat_eta",      slot.sparams.mirostat_eta},
             {"penalize_nl",       slot.sparams.penalize_nl},
             {"stop",              slot.params.antiprompt},
-            {"n_predict",         slot.params.n_predict},
+            {"n_predict",         slot.params.n_predict}, // TODO: fix duplicate key n_predict
             {"n_keep",            params.n_keep},
             {"ignore_eos",        ignore_eos},
             {"stream",            slot.params.stream},