From 9ec749df59982d84f7a5bafb8d08a2f4ca08f00f Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 25 Feb 2024 13:57:43 +0200
Subject: [PATCH] llama : add alternative KV cache merging (EXPERIMENTAL)

---
 llama.cpp | 67 ++++++++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 59 insertions(+), 8 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 1fb53f3db..2c05921bb 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -8072,10 +8072,13 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
 
     const auto & hparams = lctx.model.hparams;
 
-    const uint32_t n_layer      = hparams.n_layer;
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
-    const uint32_t kv_size      = kv_self.size;
+    const uint32_t n_layer       = hparams.n_layer;
+    const uint32_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
+    const uint32_t n_embd_v_gqa  = hparams.n_embd_v_gqa();
+    const uint32_t n_embd_head_k = hparams.n_embd_head_k; GGML_UNUSED(n_embd_head_k);
+    const uint32_t n_embd_head_v = hparams.n_embd_head_v; GGML_UNUSED(n_embd_head_v);
+    const uint32_t n_head_kv     = hparams.n_head_kv;     GGML_UNUSED(n_head_kv);
+    const uint32_t kv_size       = kv_self.size;
 
     std::vector<uint8_t> buf_q;
 
@@ -8085,9 +8088,9 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
         const int64_t t_start = ggml_time_us();
 
         struct c_pair { uint32_t i0, i1; };
-        struct c_info { bool merged; uint32_t id, cnt;};
+        struct c_info { bool merged; uint32_t id, cnt, r; };
 
-        std::vector<c_info> infos(kv_size, { false, 0, 0 });
+        std::vector<c_info> infos(kv_size, { false, 0, 0, 0 });
 
         // the destination cell in the new KV cache
        uint32_t id = 0;
@@ -8100,7 +8103,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
             const auto & cell0 = kv_self.cells[i0];
 
             if (!cell0.is_empty() && !infos[i0].merged) {
-                infos[i0] = { true, id, 0 };
+                infos[i0] = { true, id, 0, 0 };
                 infos[id].cnt = 1;
 
                 const llama_pos p0 = cell0.pos;
@@ -8112,7 +8115,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
                     const llama_pos p1 = cell1.pos;
 
                     if (std::abs(p0 - p1) <= kv_self.compress_delta) {
-                        infos[i1] = { true, id, 0 };
+                        infos[i1] = { true, id, 0, 0 };
                         infos[id].cnt++;
                         n_merges++;
                     }
@@ -8143,6 +8146,10 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
         tt_v = ggml_internal_get_type_traits(kv_self.type_v);
 
         for (uint32_t il = 0; il < n_layer; ++il) {
+            for (uint32_t i = 0; i < kv_size; ++i) {
+                infos[i].r = 0;
+            }
+
             // update keys
             {
                 const int64_t ne = n_embd_k_gqa*kv_size;
@@ -8167,6 +8174,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
 
                     const uint32_t id = infos[i].id;
 
+#if 1
                     // merge using averaging
                     {
                         const float scale = 1.0f/float(infos[id].cnt);
@@ -8178,6 +8186,25 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
                             buf_dst_f32[od + j] += buf_src_f32[os + j]*scale;
                         }
                     }
+#else
+                    // merge separate heads
+                    {
+                        for (uint32_t h = 0; h < n_head_kv; ++h) {
+                            if ((h + il) % infos[id].cnt != infos[id].r) {
+                                continue;
+                            }
+
+                            const int64_t os =  i*n_embd_k_gqa + h*n_embd_head_k;
+                            const int64_t od = id*n_embd_k_gqa + h*n_embd_head_k;
+
+                            for (uint32_t j = 0; j < n_embd_head_k; ++j) {
+                                buf_dst_f32[od + j] = buf_src_f32[os + j];
+                            }
+                        }
+                    }
+
+                    infos[id].r++;
+#endif
                 }
 
                 tt_k.from_float(buf_dst_f32.data(), buf_q.data(), ne);
@@ -8185,6 +8212,10 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
                 ggml_backend_tensor_set(kv_self.k_l[il], buf_q.data(), 0, buf_q.size());
             }
 
+            for (uint32_t i = 0; i < kv_size; ++i) {
+                infos[i].r = 0;
+            }
+
             // update values (note: they are transposed)
             {
                 const int64_t ne = n_embd_v_gqa*kv_size;
@@ -8209,6 +8240,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
 
                     const uint32_t id = infos[i].id;
 
+#if 1
                     // merge using averaging
                     {
                         const float scale = 1.0f/float(infos[id].cnt);
@@ -8221,6 +8253,25 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
                             buf_dst_f32[od + j*kv_size] += buf_src_f32[os + j*kv_size]*scale;
                         }
                     }
+#else
+                    // merge separate heads
+                    {
+                        for (uint32_t h = 0; h < n_head_kv; ++h) {
+                            if ((h + il) % infos[id].cnt != infos[id].r) {
+                                continue;
+                            }
+
+                            const int64_t os =  i;
+                            const int64_t od = id;
+
+                            for (uint32_t j = h*n_embd_head_v; j < (h + 1)*n_embd_head_v; ++j) {
+                                buf_dst_f32[od + j*kv_size] = buf_src_f32[os + j*kv_size];
+                            }
+                        }
+                    }
+
+                    infos[id].r++;
+#endif
                 }
 
                 tt_v.from_float(buf_dst_f32.data(), buf_q.data(), ne);
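
Below is a minimal standalone sketch (not part of the patch) of the averaging path selected by the "#if 1" branches above: cells whose positions lie within compress_delta of a group head are mapped to a single destination cell, and the destination row becomes the arithmetic mean of the merged rows. The buffer layout, the cell_info/merge_cells_avg names and the use of plain float vectors are simplifying assumptions for illustration only; the patch itself operates on the per-layer K/V tensors through the ggml type traits (to_float/from_float) and handles the transposed V layout separately.

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <vector>

// one entry per source cell: whether it was assigned to a destination cell, and which one
struct cell_info {
    bool     merged = false;
    uint32_t id     = 0;
};

// k   : kv_size rows of n_embd floats (row-major)
// pos : per-cell positions, -1 marks an empty cell
// returns the number of cells after merging
static uint32_t merge_cells_avg(std::vector<float> & k, const std::vector<int32_t> & pos,
                                uint32_t kv_size, uint32_t n_embd, int32_t compress_delta) {
    std::vector<cell_info> infos(kv_size);
    std::vector<uint32_t>  cnt; // number of source cells merged into each destination cell

    // pass 1: group cells whose positions are within compress_delta of the group head
    uint32_t id = 0;
    for (uint32_t i0 = 0; i0 < kv_size; ++i0) {
        if (pos[i0] < 0 || infos[i0].merged) {
            continue;
        }
        infos[i0] = { true, id };
        cnt.push_back(1);
        for (uint32_t i1 = i0 + 1; i1 < kv_size; ++i1) {
            if (pos[i1] < 0 || infos[i1].merged) {
                continue;
            }
            if (std::abs(pos[i0] - pos[i1]) <= compress_delta) {
                infos[i1] = { true, id };
                cnt[id]++;
            }
        }
        ++id;
    }

    // pass 2: each destination row becomes the average of the source rows mapped to it
    std::vector<float> dst((size_t) id*n_embd, 0.0f);
    for (uint32_t i = 0; i < kv_size; ++i) {
        if (!infos[i].merged) {
            continue;
        }
        const uint32_t d     = infos[i].id;
        const float    scale = 1.0f/float(cnt[d]);
        for (uint32_t j = 0; j < n_embd; ++j) {
            dst[(size_t) d*n_embd + j] += k[(size_t) i*n_embd + j]*scale;
        }
    }
    k = std::move(dst);

    return id;
}

int main() {
    // 4 cells, 2 floats each; cells 0 and 1 are within delta = 1 of each other, cell 3 is empty
    std::vector<float>   k   = { 1, 1,  3, 3,  10, 10,  0, 0 };
    std::vector<int32_t> pos = { 0, 1, 50, -1 };

    const uint32_t n = merge_cells_avg(k, pos, 4, 2, /*compress_delta =*/ 1);

    printf("cells after merge: %u\n", n);       // 2
    printf("cell 0: %.1f %.1f\n", k[0], k[1]);  // 2.0 2.0 (average of cells 0 and 1)
    printf("cell 1: %.1f %.1f\n", k[2], k[3]);  // 10.0 10.0

    return 0;
}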