server : fix context shift + simplify self-extend

commit 51bb7f0eef
parent 7359016c7c
Author: Georgi Gerganov
Date:   2024-01-29 14:58:40 +02:00

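For reference: the self-extend state previously tracked in a separate slot.n_past_se counter is folded into slot.n_past by this change. A minimal standalone sketch of the group-attention arithmetic that the final hunk below applies to the KV cache (plain C++; the ga_n, ga_w and n_past values are illustrative, and printfs stand in for the llama_kv_cache_seq_* calls):

    #include <cstdio>

    int main() {
        const int ga_n = 4;    // group-attention factor: ga_n positions collapse into one
        const int ga_w = 512;  // group-attention width: window compressed per step
        int ga_i   = 0;        // start of the not-yet-compressed region
        int n_past = 1024;     // absolute position reached (pretend a batch just ended)

        while (n_past >= ga_i + ga_w) {
            const int ib = (ga_n * ga_i) / ga_w;            // windows already compressed before ga_i
            const int bd = (ga_w / ga_n) * (ga_n - 1);      // positions saved by compressing one window
            const int dd = (ga_w / ga_n) - ib * bd - ga_w;  // net shift applied to the tail past the window

            // stand-ins for llama_kv_cache_seq_shift / seq_div / seq_shift
            printf("shift [%6d, %6d) by %6d | div [%6d, %6d) by %d | shift tail by %6d\n",
                   ga_i, n_past, ib * bd,
                   ga_i + ib * bd, ga_i + ib * bd + ga_w, ga_n, dd);

            n_past -= bd;           // the cache now ends bd positions earlier
            ga_i   += ga_w / ga_n;  // the next window starts after the compressed block
        }
        printf("final: n_past = %d, ga_i = %d\n", n_past, ga_i);
        return 0;
    }

With ga_n = 4 and ga_w = 512, two iterations compress 1024 cached tokens into 256 positions; once every code path updates the same counter, a single slot.n_past is enough.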

@@ -188,8 +188,6 @@ struct llama_client_slot
     int32_t ga_n = 1;   // group-attention factor
     int32_t ga_w = 512; // group-attention width
 
-    int32_t n_past_se = 0; // self-extend
-
     // multimodal
     std::vector<slot_image> images;
@@ -219,7 +217,7 @@ struct llama_client_slot
        sent_token_probs_index = 0;
        infill                 = false;
        ga_i                   = 0;
-       n_past_se              = 0;
        generated_token_probs.clear();
 
        for (slot_image & img : images)
@@ -1364,18 +1362,18 @@ struct llama_server_context
                kv_cache_clear();
            }
            return true;
-       } else {
-           task_server task;
-           task.type = TASK_TYPE_NEXT_RESPONSE;
-           task.target_id = -1;
-           queue_tasks.post(task);
-       }
+       }
+
+       task_server task;
+       task.type = TASK_TYPE_NEXT_RESPONSE;
+       task.target_id = -1;
+       queue_tasks.post(task);
 
        for (llama_client_slot &slot : slots)
        {
            if (slot.ga_n == 1)
            {
-               if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx)
+               if (slot.is_processing() && system_tokens.size() + slot.cache_tokens.size() >= (size_t) slot.n_ctx)
                {
                    // Shift context
                    const int n_left = slot.n_past - slot.params.n_keep - 1;
@@ -1428,8 +1426,7 @@ struct llama_server_context
                slot.i_batch = batch.n_tokens;
 
-               const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
-               llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
+               llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.n_past, { slot.id }, true);
 
                slot.n_past += 1;
            }
@@ -1527,7 +1524,6 @@ struct llama_server_context
                    llama_sampling_reset(slot.ctx_sampling);
 
                    slot.n_past = 0;
-                   slot.n_past_se = 0;
                    slot.ga_i = 0;
                    slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
                }
@@ -1557,7 +1553,7 @@ struct llama_server_context
                            }
                            slot_npast++;
                        }
-                       slot.n_past_se = slot_npast;
+                       slot.n_past = slot_npast;
                        slot.ga_i = ga_i;
                    }
@@ -1577,7 +1573,7 @@ struct llama_server_context
                    slot.n_past--;
                    if (slot.ga_i > 0)
                    {
-                       slot.n_past_se--;
+                       slot.n_past--;
                    }
                }
@@ -1591,7 +1587,6 @@ struct llama_server_context
                    // process the prefix of first image
                    std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
-                   int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
-                   int ga_i = slot.ga_i;
+                   int32_t ga_i = slot.ga_i;
                    int32_t ga_n = slot.ga_n;
                    int32_t ga_w = slot.ga_w;
@@ -1599,14 +1594,14 @@ struct llama_server_context
                {
                    if (slot.ga_n != 1)
                    {
-                       while (slot_npast >= ga_i + ga_w) {
+                       while (slot.n_past >= ga_i + ga_w) {
                            const int bd = (ga_w/ga_n)*(ga_n - 1);
-                           slot_npast -= bd;
+                           slot.n_past -= bd;
                            ga_i += ga_w/ga_n;
                        }
                    }
-                   llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
-                   slot_npast += 1;
+                   llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot.n_past, {slot.id }, false);
+                   slot.n_past += 1;
                }
 
                if (has_images && !ingest_images(slot, n_batch))
@@ -1642,28 +1637,28 @@ struct llama_server_context
                if (slot.ga_n != 1)
                {
                    // context extension via Self-Extend
-                   while (slot.n_past_se >= slot.ga_i + slot.ga_w)
+                   while (slot.n_past >= slot.ga_i + slot.ga_w)
                    {
                        const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
                        const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
                        const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
 
                        LOG_TEE("\n");
-                       LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
+                       LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past, ib * bd, slot.ga_i + ib * bd, slot.n_past + ib * bd);
                        LOG_TEE("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
-                       LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
+                       LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past + ib * bd + dd);
 
-                       llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
+                       llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past, ib * bd);
                        llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
-                       llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
+                       llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past + ib * bd, dd);
 
-                       slot.n_past_se -= bd;
+                       slot.n_past -= bd;
 
                        slot.ga_i += slot.ga_w / slot.ga_n;
 
-                       LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
+                       LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past + bd, slot.n_past, slot.ga_i);
                    }
 
-                   slot.n_past_se += n_tokens;
+                   slot.n_past += n_tokens;
                }
            }
 
            llama_batch batch_view =
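Note on the context-shift fix in the hunk at -1364: the trigger now adds system_tokens.size() to the slot's cache size, so a slot shifts before the combined occupancy overflows n_ctx. A minimal sketch of the corrected check (plain C++; needs_context_shift is a hypothetical helper, and the halving of n_left into n_discard is an assumption based on the usual llama.cpp shift behavior, since the tail of that hunk is truncated above):

    #include <cstdio>
    #include <cstddef>

    // Hypothetical helper: the fixed condition counts the shared system
    // prompt tokens toward context occupancy, not just the slot's own cache.
    static bool needs_context_shift(size_t n_system, size_t n_cache, size_t n_ctx) {
        return n_system + n_cache >= n_ctx;
    }

    int main() {
        const size_t n_ctx    = 2048;
        const size_t n_system = 64;    // system_tokens.size()
        const size_t n_cache  = 1990;  // slot.cache_tokens.size()

        // Before the fix only n_cache was compared, so this slot (1990 < 2048)
        // would not shift even though 64 + 1990 = 2054 tokens exceed n_ctx.
        if (needs_context_shift(n_system, n_cache, n_ctx)) {
            const int n_past    = (int) n_cache;        // positions used by the slot
            const int n_keep    = 32;                   // slot.params.n_keep (illustrative)
            const int n_left    = n_past - n_keep - 1;  // as in the hunk above
            const int n_discard = n_left / 2;           // assumed: discard half of the shiftable region
            printf("keep [0, %d], discard %d tokens, shift the remainder back\n", n_keep, n_discard);
        }
        return 0;
    }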