server : rever n_past_se changes
This commit is contained in:
parent
d0e10bf1b2
commit
05350f2862
1 changed files with 47 additions and 33 deletions
|
@ -188,6 +188,8 @@ struct llama_client_slot
|
|||
int32_t ga_n = 1; // group-attention factor
|
||||
int32_t ga_w = 512; // group-attention width
|
||||
|
||||
int32_t n_past_se = 0; // self-extend
|
||||
|
||||
// multimodal
|
||||
std::vector<slot_image> images;
|
||||
|
||||
|
@ -217,6 +219,7 @@ struct llama_client_slot
|
|||
sent_token_probs_index = 0;
|
||||
infill = false;
|
||||
ga_i = 0;
|
||||
n_past_se = 0;
|
||||
|
||||
generated_token_probs.clear();
|
||||
|
||||
|
@ -1293,7 +1296,8 @@ struct llama_server_context
|
|||
for (llama_client_slot &slot : slots)
|
||||
{
|
||||
slot.cache_tokens.clear();
|
||||
slot.n_past = 0;
|
||||
slot.n_past = 0;
|
||||
slot.n_past_se = 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1427,9 +1431,11 @@ struct llama_server_context
|
|||
|
||||
slot.i_batch = batch.n_tokens;
|
||||
|
||||
const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
|
||||
|
||||
// TODO: we always have to take into account the "system_tokens"
|
||||
// this is not great and needs to be improved somehow
|
||||
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.n_past, { slot.id }, true);
|
||||
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
|
||||
slot.n_past += 1;
|
||||
}
|
||||
|
||||
|
@ -1526,6 +1532,7 @@ struct llama_server_context
|
|||
llama_sampling_reset(slot.ctx_sampling);
|
||||
|
||||
slot.n_past = 0;
|
||||
slot.n_past_se = 0;
|
||||
slot.ga_i = 0;
|
||||
slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
|
||||
}
|
||||
|
@ -1540,6 +1547,25 @@ struct llama_server_context
|
|||
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
|
||||
slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
|
||||
|
||||
if (slot.ga_n != 1)
|
||||
{
|
||||
int ga_i = 0;
|
||||
int32_t ga_n = slot.ga_n;
|
||||
int32_t ga_w = slot.ga_w;
|
||||
int32_t slot_npast = 0;
|
||||
for (int k = 0; k < slot.n_past; ++k)
|
||||
{
|
||||
while (slot_npast >= ga_i + ga_w) {
|
||||
const int bd = (ga_w/ga_n)*(ga_n - 1);
|
||||
slot_npast -= bd;
|
||||
ga_i += ga_w/ga_n;
|
||||
}
|
||||
slot_npast++;
|
||||
}
|
||||
slot.n_past_se = slot_npast;
|
||||
slot.ga_i = ga_i;
|
||||
}
|
||||
|
||||
LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
|
||||
}
|
||||
|
||||
|
@ -1554,6 +1580,10 @@ struct llama_server_context
|
|||
// we have to evaluate at least 1 token to generate logits.
|
||||
LOG_TEE("slot %d : we have to evaluate at least 1 token to generate logits\n", slot.id);
|
||||
slot.n_past--;
|
||||
if (slot.ga_i > 0)
|
||||
{
|
||||
slot.n_past_se--;
|
||||
}
|
||||
}
|
||||
|
||||
LOG_VERBOSE("prompt ingested", {
|
||||
|
@ -1562,32 +1592,13 @@ struct llama_server_context
|
|||
{"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())},
|
||||
});
|
||||
|
||||
if (slot.ga_n != 1)
|
||||
{
|
||||
int ga_i = 0;
|
||||
int32_t ga_n = slot.ga_n;
|
||||
int32_t ga_w = slot.ga_w;
|
||||
int32_t slot_npast = 0;
|
||||
for (int k = 0; k < slot.n_past; ++k)
|
||||
{
|
||||
while (slot_npast >= ga_i + ga_w) {
|
||||
const int bd = (ga_w/ga_n)*(ga_n - 1);
|
||||
slot_npast -= bd;
|
||||
ga_i += ga_w/ga_n;
|
||||
}
|
||||
slot_npast++;
|
||||
}
|
||||
slot.n_past = slot_npast;
|
||||
slot.ga_i = ga_i;
|
||||
|
||||
LOG_TEE("slot %d : applied self-extend to prompt: %i tokens\n", slot.id, slot.n_past);
|
||||
}
|
||||
|
||||
const bool has_images = process_images(slot);
|
||||
|
||||
// process the prefix of first image
|
||||
std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
|
||||
|
||||
int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
|
||||
|
||||
int32_t ga_i = slot.ga_i;
|
||||
int32_t ga_n = slot.ga_n;
|
||||
int32_t ga_w = slot.ga_w;
|
||||
|
@ -1596,13 +1607,14 @@ struct llama_server_context
|
|||
{
|
||||
if (slot.ga_n != 1)
|
||||
{
|
||||
while (slot.n_past >= ga_i + ga_w) {
|
||||
while (slot_npast >= ga_i + ga_w) {
|
||||
const int bd = (ga_w/ga_n)*(ga_n - 1);
|
||||
slot.n_past -= bd;
|
||||
slot_npast -= bd;
|
||||
ga_i += ga_w/ga_n;
|
||||
}
|
||||
}
|
||||
llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot.n_past, {slot.id }, false);
|
||||
llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
|
||||
slot_npast++;
|
||||
}
|
||||
|
||||
if (has_images && !ingest_images(slot, n_batch))
|
||||
|
@ -1638,29 +1650,31 @@ struct llama_server_context
|
|||
if (slot.ga_n != 1)
|
||||
{
|
||||
// context extension via Self-Extend
|
||||
while (slot.n_past >= slot.ga_i + slot.ga_w)
|
||||
while (slot.n_past_se >= slot.ga_i + slot.ga_w)
|
||||
{
|
||||
const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
|
||||
const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
|
||||
const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
|
||||
|
||||
LOG_TEE("\n");
|
||||
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past, ib * bd, slot.ga_i + ib * bd, slot.n_past + ib * bd);
|
||||
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
|
||||
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
|
||||
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past + ib * bd + dd);
|
||||
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
|
||||
|
||||
llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past, ib * bd);
|
||||
llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
|
||||
llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
|
||||
llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past + ib * bd, dd);
|
||||
llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
|
||||
|
||||
slot.n_past -= bd;
|
||||
slot.n_past_se -= bd;
|
||||
|
||||
slot.ga_i += slot.ga_w / slot.ga_n;
|
||||
|
||||
LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past + bd, slot.n_past, slot.ga_i);
|
||||
LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
|
||||
}
|
||||
slot.n_past_se += n_tokens;
|
||||
}
|
||||
}
|
||||
|
||||
llama_batch batch_view =
|
||||
{
|
||||
n_tokens,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue