From edc2c08943aa4e1c4e1f3a65a391242eaef996c8 Mon Sep 17 00:00:00 2001 From: Maximilian Winter Date: Fri, 26 Jan 2024 22:53:34 +0100 Subject: [PATCH] Fixed prompt caching without self extend --- examples/server/server.cpp | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 195caf0b2..a606a9273 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1565,7 +1565,7 @@ struct llama_server_context // process the prefix of first image std::vector prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens; - int32_t slot_npast = 0; + int32_t slot_npast = slot.n_past; int ga_i = slot.ga_i; int ga_n = slot.ga_n; int ga_w = slot.ga_w; @@ -1583,7 +1583,7 @@ struct llama_server_context slot_npast += 1; } - slot.n_past = 0; + if (has_images && !ingest_images(slot, n_batch)) { LOG_TEE("failed processing images\n"); @@ -1608,36 +1608,45 @@ struct llama_server_context return true; } + std::vector slot_npasts; + for (auto & slot : slots) + { + slot_npasts.emplace_back(slot.n_past); + } + for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); + int slot_id = 0; for (auto & slot : slots) { + if(slot.ga_n != 1) { // context extension via Self-Extend - while (slot.n_past >= slot.ga_i + slot.ga_w) { + while (slot_npasts[slot_id] >= slot.ga_i + slot.ga_w) { const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w; const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1); const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w; LOG_TEE("\n"); - LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past, ib * bd, slot.ga_i + ib * bd, slot.n_past + ib * bd); + LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot_npasts[slot_id], ib * bd, slot.ga_i + ib * bd, slot_npasts[slot_id] + ib * bd); LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n); - LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past + ib * bd + dd); + LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot_npasts[slot_id] + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot_npasts[slot_id] + ib * bd + dd); - llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past, ib * bd); + llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot_npasts[slot_id], ib * bd); llama_kv_cache_seq_div (ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n); - llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot.n_past + ib * bd, dd); + llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot_npasts[slot_id] + ib * bd, dd); - slot.n_past -= bd; + slot_npasts[slot_id] -= bd; slot.ga_i += slot.ga_w / slot.ga_n; - LOG_TEE("\nn_past_old = %d, 2n_past = %d, ga_i = %d\n\n", slot.n_past + bd, slot.n_past, slot.ga_i); + LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot_npasts[slot_id] + bd, slot_npasts[slot_id], slot.ga_i); } } - slot.n_past += n_tokens; + slot_npasts[slot_id] += n_tokens; + slot_id++; } llama_batch batch_view =