Fixed prompt caching without self extend
This commit is contained in:
parent 1f32360659
commit edc2c08943
1 changed file with 19 additions and 10 deletions
@@ -1565,7 +1565,7 @@ struct llama_server_context
             // process the prefix of first image
             std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
-            int32_t slot_npast = 0;
+            int32_t slot_npast = slot.n_past;
             int ga_i = slot.ga_i;
             int ga_n = slot.ga_n;
             int ga_w = slot.ga_w;
@@ -1583,7 +1583,7 @@ struct llama_server_context
                 slot_npast += 1;
             }

-            slot.n_past = 0;

             if (has_images && !ingest_images(slot, n_batch))
             {
                 LOG_TEE("failed processing images\n");
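Read together, these first two hunks are the prompt-caching fix named in the commit title: slot_npast, the position counter used while batching prompt tokens, now starts from slot.n_past (the number of tokens already held in the slot's KV cache) instead of 0, and the unconditional slot.n_past = 0; reset, which discarded that count on every request, is removed. Only the uncached suffix of the prompt needs re-evaluation, at positions that line up with the cached prefix. A standalone sketch of that idea, not the server code itself; the names cached, prompt, and common_prefix are illustrative:

#include <cstdio>
#include <cstdint>
#include <vector>

using llama_token = int32_t;

// Length of the shared prefix between the cached tokens and the new prompt.
static int32_t common_prefix(const std::vector<llama_token> & cached,
                             const std::vector<llama_token> & prompt)
{
    int32_t n = 0;
    while (n < (int32_t) cached.size() && n < (int32_t) prompt.size() && cached[n] == prompt[n]) {
        n++;
    }
    return n;
}

int main()
{
    const std::vector<llama_token> cached = {1, 15, 27, 42, 8};          // already in the KV cache
    const std::vector<llama_token> prompt = {1, 15, 27, 42, 8, 99, 100}; // new request

    // n_past plays the role of slot.n_past: how many cached positions are reusable.
    const int32_t n_past = common_prefix(cached, prompt);

    // The fix: start the position counter at n_past rather than 0, so the
    // uncached suffix lands at positions n_past, n_past + 1, ...
    for (int32_t slot_npast = n_past; slot_npast < (int32_t) prompt.size(); slot_npast++) {
        printf("eval token %d at pos %d\n", prompt[slot_npast], slot_npast);
    }
    return 0;
}

With the old behavior (counter starting at 0 and slot.n_past zeroed), the suffix tokens would have been assigned positions overlapping the cached prefix, so the whole prompt had to be re-processed on every request.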
@@ -1608,36 +1608,45 @@ struct llama_server_context
            return true;
        }

+       std::vector<int32_t> slot_npasts;
+       for (auto & slot : slots)
+       {
+           slot_npasts.emplace_back(slot.n_past);
+       }

        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
        {
            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+           int slot_id = 0;
            for (auto & slot : slots)
            {
                if (slot.ga_n != 1)
                {
                    // context extension via Self-Extend
-                   while (slot.n_past >= slot.ga_i + slot.ga_w) {
+                   while (slot_npasts[slot_id] >= slot.ga_i + slot.ga_w) {
                        const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
                        const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
                        const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;

                        LOG_TEE("\n");
-                       LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past, ib * bd, slot.ga_i + ib * bd, slot.n_past + ib * bd);
+                       LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot_npasts[slot_id], ib * bd, slot.ga_i + ib * bd, slot_npasts[slot_id] + ib * bd);
                        LOG_TEE("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
-                       LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past + ib * bd + dd);
+                       LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot_npasts[slot_id] + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot_npasts[slot_id] + ib * bd + dd);

-                       llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past, ib * bd);
+                       llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot_npasts[slot_id], ib * bd);
                        llama_kv_cache_seq_div  (ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
-                       llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot.n_past + ib * bd, dd);
+                       llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot_npasts[slot_id] + ib * bd, dd);

-                       slot.n_past -= bd;
+                       slot_npasts[slot_id] -= bd;

                        slot.ga_i += slot.ga_w / slot.ga_n;

-                       LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past + bd, slot.n_past, slot.ga_i);
+                       LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot_npasts[slot_id] + bd, slot_npasts[slot_id], slot.ga_i);
                    }
                }
-               slot.n_past += n_tokens;
+               slot_npasts[slot_id] += n_tokens;
+               slot_id++;
            }

            llama_batch batch_view =
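The third hunk keeps slot.n_past stable while a batch is decoded: instead of advancing and shifting slot.n_past in place, which corrupted the cached-token count used by the next request, each slot's position is snapshotted into slot_npasts before the loop and all Self-Extend bookkeeping operates on that copy. The shift arithmetic itself is unchanged: each pass compresses one window of ga_w positions down to ga_w / ga_n, reclaiming bd = (ga_w / ga_n) * (ga_n - 1) positions. A self-contained simulation of that loop on plain integers, assuming ga_n = 4, ga_w = 512, and 1024 past tokens (the values are illustrative; it prints the same ranges the LOG_TEE calls above would):

#include <cstdio>

int main()
{
    const int ga_n = 4;   // group-attention factor
    const int ga_w = 512; // group-attention window
    int ga_i   = 0;       // group boundary, mirrors slot.ga_i
    int n_past = 1024;    // mirrors slot_npasts[slot_id]

    // Same loop as the server's Self-Extend block, minus the KV-cache calls.
    while (n_past >= ga_i + ga_w) {
        const int ib = (ga_n * ga_i) / ga_w;
        const int bd = (ga_w / ga_n) * (ga_n - 1);
        const int dd = (ga_w / ga_n) - ib * bd - ga_w;

        printf("shift: [%6d, %6d] + %6d\n", ga_i, n_past, ib * bd);
        printf("div:   [%6d, %6d] / %6d\n", ga_i + ib * bd, ga_i + ib * bd + ga_w, ga_n);
        printf("shift: [%6d, %6d] + %6d\n", ga_i + ib * bd + ga_w, n_past + ib * bd, dd);

        n_past -= bd;          // bd positions reclaimed
        ga_i   += ga_w / ga_n; // boundary advances by the compressed window

        printf("n_past_old = %d, n_past = %d, ga_i = %d\n\n", n_past + bd, n_past, ga_i);
    }
    return 0;
}

With these values the first pass divides positions [0, 512) by 4 and shifts [512, 1024) down by 384, leaving n_past = 640 and ga_i = 128; the second pass compresses the next window, ending at n_past = 256 and ga_i = 256. Crucially, only the slot_npasts copy moves, so slot.n_past still reports how many tokens the cache holds for the next prompt.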