server: do not truncate prompt tokens if self-extend through group attention is enabled

Pierrick HYMBERT 2024-03-02 13:52:52 +01:00
parent 60113da241
commit 616d7e9a9b


@@ -421,7 +421,7 @@ struct llama_server_context
         // create slots
         all_slots_are_idle = true;
-        const int32_t n_ctx_slot = n_ctx / params.n_parallel;
+        const int32_t n_ctx_slot = std::min(n_ctx / params.n_parallel, llama_n_ctx_train(model)); // FIXME @ggerganov @phymbert To be discussed
         LOG_INFO("initializing slots", {{"n_slots", params.n_parallel}});
         for (int i = 0; i < params.n_parallel; i++)
@@ -441,8 +441,8 @@ struct llama_server_context
             const int ga_w = params.grp_attn_w;
             if (ga_n != 1) {
-                GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
-                GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
+                GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
+                GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
                 //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
                 //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
@@ -1709,8 +1709,8 @@ struct llama_server_context
             }
             slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
-            // if input prompt is too big, truncate it
-            if (slot.n_prompt_tokens >= slot.n_ctx)
+            // if input prompt is too big, truncate it, if group attention self-extend is disabled
+            if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx)
             {
                 const int n_left = slot.n_ctx - slot.params.n_keep;
                 const int n_block_size = n_left / 2;
@@ -1785,9 +1785,11 @@ struct llama_server_context
             }
             LOG_INFO("slot progression", {
-                { "slot_id", slot.id },
-                { "task_id", slot.task_id },
-                { "n_past", slot.n_past },
+                { "slot_id",    slot.id },
+                { "task_id",    slot.task_id },
+                { "n_past",     slot.n_past },
+                { "n_past_se",  slot.n_past_se },
+                { "ga_i",       slot.ga_i },
                 { "n_prompt_tokens_processed", slot.n_prompt_tokens_processed }
             });
         }
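
For readers of this commit, a minimal standalone sketch of the behavioral change in the third hunk, not the server code itself: slot_state and would_truncate are hypothetical names used only for illustration. The point is that prompt truncation now applies only when group-attention self-extend is disabled (ga_n == 1); with ga_n > 1 the full prompt is kept and self-extend is expected to cover contexts beyond the slot's window.

    // Minimal sketch (illustrative names, not the actual server code).
    #include <cstdio>

    struct slot_state {
        int ga_n            = 1;    // group-attention factor; > 1 enables self-extend
        int n_ctx           = 2048; // per-slot context size
        int n_prompt_tokens = 0;
    };

    // New rule from this commit: truncate only when self-extend is disabled
    // (ga_n == 1) and the prompt does not fit the slot context.
    static bool would_truncate(const slot_state & slot) {
        return slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx;
    }

    int main() {
        slot_state slot;
        slot.n_prompt_tokens = 4096;

        slot.ga_n = 1;
        std::printf("ga_n=1 -> truncate=%d\n", would_truncate(slot)); // 1: prompt is truncated
        slot.ga_n = 4;
        std::printf("ga_n=4 -> truncate=%d\n", would_truncate(slot)); // 0: prompt kept, self-extend handles it
        return 0;
    }

Note that the first hunk additionally caps each slot's context at llama_n_ctx_train(model); that change is still marked FIXME / to be discussed in the diff.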