From fdfa5bc76b52b3551343d606069cb3107433f236 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 25 Feb 2024 11:00:19 +0200
Subject: [PATCH] llama : add llama_kv_cache_compress (EXPERIMENTAL)

---
 examples/passkey/passkey.cpp |   7 +-
 llama.cpp                    | 210 ++++++++++++++++++++++++++++++++++-
 llama.h                      |   5 +
 3 files changed, 215 insertions(+), 7 deletions(-)

diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp
index a3a63977f..e2725aaa6 100644
--- a/examples/passkey/passkey.cpp
+++ b/examples/passkey/passkey.cpp
@@ -146,9 +146,10 @@ int main(int argc, char ** argv) {
             const int ib = i/n_batch - 1;
             const int bd = n_batch_grp*(n_grp - 1);
 
-            llama_kv_cache_seq_add(ctx, 0, n_past - n_batch, n_past, ib*bd);
-            llama_kv_cache_seq_div(ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
-            llama_kv_cache_update (ctx);
+            llama_kv_cache_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd);
+            llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
+            llama_kv_cache_compress(ctx, 0);
+            llama_kv_cache_update  (ctx);
 
             n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
         }
diff --git a/llama.cpp b/llama.cpp
index 0effc6db3..e90609089 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1733,6 +1733,12 @@ struct llama_kv_cache {
     // computed before each graph build
     uint32_t n = 0;
 
+    ggml_type type_k = GGML_TYPE_F16;
+    ggml_type type_v = GGML_TYPE_F16;
+
+    // if non-negative, compress data on next update
+    llama_pos compress_delta = -1;
+
     std::vector<llama_kv_cell> cells;
 
     std::vector<struct ggml_tensor *> k_l; // per layer
@@ -1968,8 +1974,8 @@ struct llama_context {
 static bool llama_kv_cache_init(
         struct llama_kv_cache & cache,
             const llama_model & model,
-                    ggml_type   ktype,
-                    ggml_type   vtype,
+                    ggml_type   type_k,
+                    ggml_type   type_v,
                      uint32_t   n_ctx,
                          bool   offload) {
     const struct llama_hparams & hparams = model.hparams;
@@ -1984,6 +1990,9 @@ static bool llama_kv_cache_init(
     cache.size = n_ctx;
     cache.used = 0;
 
+    cache.type_k = type_k;
+    cache.type_v = type_v;
+
     cache.cells.clear();
     cache.cells.resize(n_ctx);
 
@@ -2024,8 +2033,8 @@ static bool llama_kv_cache_init(
     for (int i = 0; i < (int) n_layer; i++) {
         struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
-        ggml_tensor * k = ggml_new_tensor_1d(ctx, ktype, n_embd_k_gqa*n_ctx);
-        ggml_tensor * v = ggml_new_tensor_1d(ctx, vtype, n_embd_v_gqa*n_ctx);
+        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*n_ctx);
+        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*n_ctx);
         ggml_format_name(k, "cache_k_l%d", i);
         ggml_format_name(v, "cache_v_l%d", i);
         cache.k_l.push_back(k);
@@ -2265,6 +2274,10 @@ static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama_seq_id seq_id) {
     return result;
 }
 
+static void llama_kv_cache_compress(struct llama_kv_cache & cache, llama_pos delta) {
+    cache.compress_delta = delta;
+}
+
 //
 // model loading and saving
 //
@@ -8034,6 +8047,191 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
             }
         }
     }
+
+    // compress the KV cache data if needed:
+    //
+    //  - determine which KV cell pairs (i0, i1) to merge:
+    //
+    //      abs(cell[i0].pos - cell[i1].pos) <= compress_delta
+    //
+    //  - move the KV cache to the Host memory for easier manipulation
+    //  - processing is done layer-by-layer
+    //  - convert the KV data to F32
+    //  - merge the KV data (different ways to merge)
+    //  - convert the KV data back to the original type
+    //  - move the KV cache back to the device memory
+    //  - update the KV cache metadata
+    //
+    // as a side effect, the new KV cache is defragmented
+    //
+    if (lctx.kv_self.compress_delta >= 0) {
+        auto & kv_self = lctx.kv_self;
+
+        const auto & hparams = lctx.model.hparams;
+
+        const uint32_t n_layer      = hparams.n_layer;
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+        const uint32_t kv_size      = kv_self.size;
+
+        std::vector<uint8_t> buf_q;
+
+        std::vector<float> buf_src_f32;
+        std::vector<float> buf_dst_f32;
+
+        const int64_t t_start = ggml_time_us();
+
+        struct c_pair { uint32_t i0, i1; };
+        struct c_info { bool merged; uint32_t id, cnt; };
+
+        std::vector<c_info> infos(kv_size, { false, 0, 0 });
+
+        // the destination cell in the new KV cache
+        uint32_t id = 0;
+
+        // number of pairs merged
+        uint32_t n_merges = 0;
+
+        // determine which KV cells to merge
+        for (uint32_t i0 = 0; i0 < kv_size; ++i0) {
+            const auto & cell0 = kv_self.cells[i0];
+
+            if (!cell0.is_empty() && !infos[i0].merged) {
+                infos[i0] = { true, id, 0 };
+                infos[id].cnt = 1;
+
+                const llama_pos p0 = cell0.pos;
+
+                for (uint32_t i1 = i0 + 1; i1 < kv_size; ++i1) {
+                    const auto & cell1 = kv_self.cells[i1];
+
+                    if (i0 != i1 && cell0.is_same_seq(cell1)) {
+                        const llama_pos p1 = cell1.pos;
+
+                        if (std::abs(p0 - p1) <= kv_self.compress_delta) {
+                            infos[i1] = { true, id, 0 };
+                            infos[id].cnt++;
+                            n_merges++;
+                        }
+                    }
+                }
+
+                if (i0 != id) {
+                    kv_self.cells[id] = cell0;
+                }
+
+                id++;
+            }
+        }
+
+        kv_self.head = id;
+        kv_self.used = id;
+
+        for (uint32_t i = id; i < kv_size; ++i) {
+            kv_self.cells[i] = llama_kv_cell();
+        }
+
+        LLAMA_LOG_INFO("(tmp log) KV compress pairs: %u\n", n_merges);
+
+        ggml_type_traits_t tt_k;
+        ggml_type_traits_t tt_v;
+
+        tt_k = ggml_internal_get_type_traits(kv_self.type_k);
+        tt_v = ggml_internal_get_type_traits(kv_self.type_v);
+
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            // update keys
+            {
+                const int64_t ne = n_embd_k_gqa*kv_size;
+
+                const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, ne);
+
+                buf_q.resize(k_size);
+
+                buf_src_f32.resize(ne);
+                buf_dst_f32.resize(ne);
+
+                ggml_backend_tensor_get(kv_self.k_l[il], buf_q.data(), 0, buf_q.size());
+
+                tt_k.to_float(buf_q.data(), buf_src_f32.data(), ne);
+
+                std::fill(buf_dst_f32.begin(), buf_dst_f32.end(), 0);
+
+                for (uint32_t i = 0; i < kv_size; ++i) {
+                    if (!infos[i].merged) {
+                        continue;
+                    }
+
+                    const uint32_t id = infos[i].id;
+
+                    // merge using averaging
+                    {
+                        const float scale = 1.0f/float(infos[id].cnt);
+
+                        const int64_t os = i*n_embd_k_gqa;
+                        const int64_t od = id*n_embd_k_gqa;
+
+                        for (uint32_t j = 0; j < n_embd_k_gqa; ++j) {
+                            buf_dst_f32[od + j] += buf_src_f32[os + j]*scale;
+                        }
+                    }
+                }
+
+                tt_k.from_float(buf_dst_f32.data(), buf_q.data(), ne);
+
+                ggml_backend_tensor_set(kv_self.k_l[il], buf_q.data(), 0, buf_q.size());
+            }
+
+            // update values (note: they are transposed)
+            {
+                const int64_t ne = n_embd_v_gqa*kv_size;
+
+                const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, ne);
+
+                buf_q.resize(v_size);
+
+                buf_src_f32.resize(ne);
+                buf_dst_f32.resize(ne);
+
+                ggml_backend_tensor_get(kv_self.v_l[il], buf_q.data(), 0, buf_q.size());
+
+                tt_v.to_float(buf_q.data(), buf_src_f32.data(), ne);
+
+                std::fill(buf_dst_f32.begin(), buf_dst_f32.end(), 0);
+
+                for (uint32_t i = 0; i < kv_size; ++i) {
+                    if (!infos[i].merged) {
+                        continue;
+                    }
+
+                    const uint32_t id = infos[i].id;
+
+                    // merge using averaging
+                    {
+                        const float scale = 1.0f/float(infos[id].cnt);
+                        //printf("i: %d -> id: %d, scale: %f\n", i, id, scale);
+
+                        const int64_t os = i;
+                        const int64_t od = id;
+
+                        for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                            buf_dst_f32[od + j*kv_size] += buf_src_f32[os + j*kv_size]*scale;
+                        }
+                    }
+                }
+
+                tt_v.from_float(buf_dst_f32.data(), buf_q.data(), ne);
+
+                ggml_backend_tensor_set(kv_self.v_l[il], buf_q.data(), 0, buf_q.size());
+            }
+        }
+
+        const int64_t t_end = ggml_time_us();
+
+        LLAMA_LOG_INFO("(tmp log) KV compress time: %.3f ms\n", (t_end - t_start)/1000.0);
+
+        kv_self.compress_delta = -1;
+    }
 }
 
 //
@@ -12083,6 +12281,10 @@ llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) {
     return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
 }
 
+void llama_kv_cache_compress(struct llama_context * ctx, llama_pos delta) {
+    llama_kv_cache_compress(ctx->kv_self, delta);
+}
+
 void llama_kv_cache_update(struct llama_context * ctx) {
     llama_kv_cache_update_internal(*ctx);
 }
diff --git a/llama.h b/llama.h
index faea891e4..3fac7b79c 100644
--- a/llama.h
+++ b/llama.h
@@ -552,6 +552,11 @@ extern "C" {
             struct llama_context * ctx,
                     llama_seq_id   seq_id);
 
+    // [EXPERIMENTAL] Compress the data in the KV cache
+    LLAMA_API void llama_kv_cache_compress(
+            struct llama_context * ctx,
+                       llama_pos   delta);
+
     // Apply the KV cache updates (such as K-shifts) to the KV data
     LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
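---

Usage note, distilled from the passkey change above: llama_kv_cache_compress only records the requested delta; the actual merging runs inside the next llama_kv_cache_update. With delta = 0, only cells whose positions coincide exactly (as happens after llama_kv_cache_seq_div during self-extend) are merged. A sketch of the call pattern, where the helper name self_extend_step is illustrative and not part of the API:

#include "llama.h"

// one self-extend step (cf. examples/passkey): shift + divide the positions,
// queue a compression pass, then apply everything in one KV update
static int self_extend_step(struct llama_context * ctx, int n_past, int n_batch,
                            int n_grp, int n_batch_grp, int ib) {
    const int bd = n_batch_grp*(n_grp - 1);

    llama_kv_cache_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd);
    llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);

    // delta = 0: merge only cells that now share the exact same position
    llama_kv_cache_compress(ctx, 0);

    // the compression (and any pending K-shift) is applied here
    llama_kv_cache_update  (ctx);

    return llama_kv_cache_seq_pos_max(ctx, 0) + 1; // new n_past
}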
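The selection pass can be read in isolation. Below is a minimal standalone sketch of the same greedy grouping, assuming a simplified cell type with a single sequence id per cell in place of llama_kv_cell/is_same_seq; Cell, Info and pick_merges are illustrative names, not llama.cpp API:

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <vector>

struct Cell { int32_t pos = -1; int32_t seq = 0; bool is_empty() const { return pos < 0; } };
struct Info { bool merged = false; uint32_t id = 0; uint32_t cnt = 0; };

// greedily assign each live cell to a destination slot `id`; any later cell of
// the same sequence within `delta` positions joins the same slot
static uint32_t pick_merges(std::vector<Cell> & cells, std::vector<Info> & infos, int32_t delta) {
    const uint32_t kv_size = (uint32_t) cells.size();
    infos.assign(kv_size, Info());

    uint32_t id = 0; // next destination slot in the compacted cache

    for (uint32_t i0 = 0; i0 < kv_size; ++i0) {
        const Cell cell0 = cells[i0];

        if (cell0.is_empty() || infos[i0].merged) {
            continue;
        }

        infos[i0]     = { true, id, 0 };
        infos[id].cnt = 1; // group sizes are tracked at the destination index

        for (uint32_t i1 = i0 + 1; i1 < kv_size; ++i1) {
            const Cell & cell1 = cells[i1];

            if (!cell1.is_empty() && cell1.seq == cell0.seq &&
                std::abs(cell1.pos - cell0.pos) <= delta) {
                infos[i1] = { true, id, 0 };
                infos[id].cnt++;
            }
        }

        if (i0 != id) {
            cells[id] = cell0; // compact live cells to the front (defragments)
        }
        id++;
    }

    // the caller would now set head = used = id and clear cells[id..kv_size)
    return id;
}

int main() {
    // one sequence with positions 0, 1, 1, 8; the last slot is empty
    std::vector<Cell> cells = { {0, 0}, {1, 0}, {1, 0}, {8, 0}, {-1, 0} };
    std::vector<Info> infos;

    const uint32_t n = pick_merges(cells, infos, /*delta =*/ 1);
    printf("live cells after compression: %u\n", n); // 2: {0, 1, 1} collapse into one slot
    return 0;
}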
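A companion sketch of the averaging merge on already-dequantized F32 buffers, highlighting the indexing difference between the row-contiguous K cache (cell i occupies [i*n_embd, (i+1)*n_embd)) and the transposed V cache (component j of cell i lives at j*kv_size + i). Note the dual use of the infos array, exactly as in the patch: infos[i].id is the destination slot of source cell i, while infos[d].cnt is the group size of destination slot d. merge_k and merge_v are illustrative names:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

struct Info { bool merged = false; uint32_t id = 0; uint32_t cnt = 0; };

// K cache: each cell's row is contiguous
static void merge_k(const std::vector<float> & src, std::vector<float> & dst,
                    const std::vector<Info> & infos, uint32_t kv_size, uint32_t n_embd) {
    std::fill(dst.begin(), dst.end(), 0.0f);

    for (uint32_t i = 0; i < kv_size; ++i) {
        if (!infos[i].merged) continue;

        const uint32_t id    = infos[i].id;
        const float    scale = 1.0f/float(infos[id].cnt);

        for (uint32_t j = 0; j < n_embd; ++j) {
            dst[id*n_embd + j] += src[i*n_embd + j]*scale;
        }
    }
}

// V cache: transposed, so consecutive components of one cell are kv_size apart
static void merge_v(const std::vector<float> & src, std::vector<float> & dst,
                    const std::vector<Info> & infos, uint32_t kv_size, uint32_t n_embd) {
    std::fill(dst.begin(), dst.end(), 0.0f);

    for (uint32_t i = 0; i < kv_size; ++i) {
        if (!infos[i].merged) continue;

        const uint32_t id    = infos[i].id;
        const float    scale = 1.0f/float(infos[id].cnt);

        for (uint32_t j = 0; j < n_embd; ++j) {
            dst[id + j*kv_size] += src[i + j*kv_size]*scale;
        }
    }
}

int main() {
    const uint32_t kv_size = 4, n_embd = 2;

    // cells 0 and 1 merge into slot 0 (cnt 2, stored at infos[0]);
    // cell 2 stays alone in slot 1 (cnt 1, stored at infos[1]); cell 3 is empty
    const std::vector<Info> infos = { {true, 0, 2}, {true, 0, 1}, {true, 1, 0}, {false, 0, 0} };

    std::vector<float> k_src = { 1,1,  3,3,  5,5,  0,0 }, k_dst(k_src.size());
    std::vector<float> v_src = { 1,3,5,0,  1,3,5,0 },     v_dst(v_src.size());

    merge_k(k_src, k_dst, infos, kv_size, n_embd);
    merge_v(v_src, v_dst, infos, kv_size, n_embd);

    printf("k slot 0: %.1f %.1f\n", k_dst[0], k_dst[1]);           // 2.0 2.0
    printf("v slot 0: %.1f %.1f\n", v_dst[0], v_dst[0 + kv_size]); // 2.0 2.0
    return 0;
}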