From 65f21ec5d3e774978765f4de82231809c2cc3e72 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 25 Feb 2024 15:00:45 +0200
Subject: [PATCH] llama : add llama_kv_cache_defrag

---
 examples/passkey/passkey.cpp |   2 +
 llama.cpp                    | 601 +++++++++++++++++++++--------
 llama.h                      |  11 +-
 3 files changed, 383 insertions(+), 231 deletions(-)

diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp
index e2725aaa6..4c8a04135 100644
--- a/examples/passkey/passkey.cpp
+++ b/examples/passkey/passkey.cpp
@@ -183,6 +183,7 @@ int main(int argc, char ** argv) {
             llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
             llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
+            llama_kv_cache_defrag (ctx);
             llama_kv_cache_update (ctx);
 
             n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
@@ -213,6 +214,7 @@ int main(int argc, char ** argv) {
             llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
             llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
+            llama_kv_cache_defrag (ctx);
             llama_kv_cache_update (ctx);
 
             n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
diff --git a/llama.cpp b/llama.cpp
index 2c05921bb..61539b24a 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1722,6 +1722,7 @@ struct llama_kv_cell {
 // ring-buffer of cached KV data
 struct llama_kv_cache {
     bool has_shift = false;
+    bool do_defrag = false;
 
     // Note: The value of head isn't only used to optimize searching
     // for a free KV slot. llama_decode_internal also uses it, so it
@@ -2278,6 +2279,10 @@ static void llama_kv_cache_compress(struct llama_kv_cache & cache, llama_pos del
     cache.compress_delta = delta;
 }
 
+static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
+    cache.do_defrag = true;
+}
+
 //
 // model loading and saving
 //
@@ -8029,6 +8034,359 @@ static int llama_decode_internal(
     return 0;
 }
 
+// summary:
+//
+// - determine which KV cell pairs (i0, i1) to merge:
+//
+//     abs(cell[i0].pos - cell[i1].pos) <= compress_delta
+//
+// - move the KV cache to the Host memory for easier manipulation
+// - processing is done layer-by-layer
+// - convert the KV data to F32
+// - merge the KV data (different ways to merge)
+// - convert the KV data back to the original type
+// - move the KV cache back to the device memory
+// - update the KV cache metadata
+//
+// as a side effect, the new KV cache is defragmented
+//
+static void llama_kv_cache_compress_internal(struct llama_context & lctx) {
+    auto & kv_self = lctx.kv_self;
+
+    const auto & hparams = lctx.model.hparams;
+
+    const uint32_t n_layer       = hparams.n_layer;
+    const uint32_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
+    const uint32_t n_embd_v_gqa  = hparams.n_embd_v_gqa();
+    const uint32_t n_embd_head_k = hparams.n_embd_head_k; GGML_UNUSED(n_embd_head_k);
+    const uint32_t n_embd_head_v = hparams.n_embd_head_v; GGML_UNUSED(n_embd_head_v);
+    const uint32_t n_head_kv     = hparams.n_head_kv;     GGML_UNUSED(n_head_kv);
+    const uint32_t kv_size       = kv_self.size;
+
+    const int64_t t_start = ggml_time_us();
+
+    std::vector<uint8_t> buf_q;
+
+    std::vector<float> buf_src_f32;
+    std::vector<float> buf_dst_f32;
+
+    struct c_pair { uint32_t i0, i1; };
+    struct c_info { bool merged; uint32_t id, cnt, r; };
+
+    std::vector<c_info> infos(kv_size, { false, 0, 0, 0 });
+
+    // the destination cell in the new KV cache
+    uint32_t id = 0;
+
+    // number of pairs merged
+    uint32_t n_merges = 0;
+
+    // determine which KV cells to merge
+    for (uint32_t i0 = 0; i0 < kv_size; ++i0) {
+        const auto & cell0 = kv_self.cells[i0];
+
+        if (!cell0.is_empty() && !infos[i0].merged) {
+            infos[i0] = { true, id, 0, 0 };
+            infos[id].cnt = 1;
+
+            const llama_pos p0 = cell0.pos;
+
+            for (uint32_t i1 = i0 + 1; i1 < kv_size; ++i1) {
+                const auto & cell1 = kv_self.cells[i1];
+
+                if (i0 != i1 && cell0.is_same_seq(cell1)) {
+                    const llama_pos p1 = cell1.pos;
+
+                    if (std::abs(p0 - p1) <= kv_self.compress_delta) {
+                        infos[i1] = { true, id, 0, 0 };
+                        infos[id].cnt++;
+                        n_merges++;
+                    }
+                }
+            }
+
+            if (i0 != id) {
+                kv_self.cells[id] = cell0;
+            }
+
+            id++;
+        }
+    }
+
+    kv_self.head = id;
+    kv_self.used = id;
+
+    for (uint32_t i = id; i < kv_size; ++i) {
+        kv_self.cells[i] = llama_kv_cell();
+    }
+
+    LLAMA_LOG_INFO("(tmp log) KV compress pairs: %u\n", n_merges);
+
+    ggml_type_traits_t tt_k;
+    ggml_type_traits_t tt_v;
+
+    tt_k = ggml_internal_get_type_traits(kv_self.type_k);
+    tt_v = ggml_internal_get_type_traits(kv_self.type_v);
+
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        for (uint32_t i = 0; i < kv_size; ++i) {
+            infos[i].r = 0;
+        }
+
+        // update keys
+        {
+            const int64_t ne = n_embd_k_gqa*kv_size;
+
+            const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, ne);
+
+            buf_q.resize(k_size);
+
+            buf_src_f32.resize(ne);
+            buf_dst_f32.resize(ne);
+
+            ggml_backend_tensor_get(kv_self.k_l[il], buf_q.data(), 0, buf_q.size());
+
+            tt_k.to_float(buf_q.data(), buf_src_f32.data(), ne);
+
+            std::fill(buf_dst_f32.begin(), buf_dst_f32.end(), 0);
+
+            for (uint32_t i = 0; i < kv_size; ++i) {
+                if (!infos[i].merged) {
+                    continue;
+                }
+
+                const uint32_t id = infos[i].id;
+
+#if 1
+                // merge using averaging
+                {
+                    const float scale = 1.0f/float(infos[id].cnt);
+
+                    const int64_t os = i*n_embd_k_gqa;
+                    const int64_t od = id*n_embd_k_gqa;
+
+                    for (uint32_t j = 0; j < n_embd_k_gqa; ++j) {
+                        buf_dst_f32[od + j] += buf_src_f32[os + j]*scale;
+                    }
+                }
+#else
+                // merge separate heads
+                {
+                    for (uint32_t h = 0; h < n_head_kv; ++h) {
+                        if ((h + il) % infos[id].cnt != infos[id].r) {
+                            continue;
+                        }
+
+                        const int64_t os = i*n_embd_k_gqa + h*n_embd_head_k;
+                        const int64_t od = id*n_embd_k_gqa + h*n_embd_head_k;
+
+                        for (uint32_t j = 0; j < n_embd_head_k; ++j) {
+                            buf_dst_f32[od + j] = buf_src_f32[os + j];
+                        }
+                    }
+                }
+
+                infos[id].r++;
+#endif
+            }
+
+            tt_k.from_float(buf_dst_f32.data(), buf_q.data(), ne);
+
+            ggml_backend_tensor_set(kv_self.k_l[il], buf_q.data(), 0, buf_q.size());
+        }
+
+        for (uint32_t i = 0; i < kv_size; ++i) {
+            infos[i].r = 0;
+        }
+
+        // update values (note: they are transposed)
+        {
+            const int64_t ne = n_embd_v_gqa*kv_size;
+
+            const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, ne);
+
+            buf_q.resize(v_size);
+
+            buf_src_f32.resize(ne);
+            buf_dst_f32.resize(ne);
+
+            ggml_backend_tensor_get(kv_self.v_l[il], buf_q.data(), 0, buf_q.size());
+
+            tt_v.to_float(buf_q.data(), buf_src_f32.data(), ne);
+
+            std::fill(buf_dst_f32.begin(), buf_dst_f32.end(), 0);
+
+            for (uint32_t i = 0; i < kv_size; ++i) {
+                if (!infos[i].merged) {
+                    continue;
+                }
+
+                const uint32_t id = infos[i].id;
+
+#if 1
+                // merge using averaging
+                {
+                    const float scale = 1.0f/float(infos[id].cnt);
+                    //printf("i: %d -> id: %d, scale: %f\n", i, id, scale);
+
+                    const int64_t os = i;
+                    const int64_t od = id;
+
+                    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                        buf_dst_f32[od + j*kv_size] += buf_src_f32[os + j*kv_size]*scale;
+                    }
+                }
+#else
+                // merge separate heads
+                {
+                    for (uint32_t h = 0; h < n_head_kv; ++h) {
+                        if ((h + il) % infos[id].cnt != infos[id].r) {
+                            continue;
+                        }
+
+                        const int64_t os = i;
+                        const int64_t od = id;
+
+                        for (uint32_t j = h*n_embd_head_v; j < (h + 1)*n_embd_head_v; ++j) {
+                            buf_dst_f32[od + j*kv_size] = buf_src_f32[os + j*kv_size];
+                        }
+                    }
+                }
+
+                infos[id].r++;
+#endif
+            }
+
+            tt_v.from_float(buf_dst_f32.data(), buf_q.data(), ne);
+
+            ggml_backend_tensor_set(kv_self.v_l[il], buf_q.data(), 0, buf_q.size());
+        }
+    }
+
+    const int64_t t_end = ggml_time_us();
+
+    LLAMA_LOG_INFO("(tmp log) KV compress time: %.3f ms\n", (t_end - t_start)/1000.0);
+}
+
+// copy the KV cache to the host memory and reshuffle the cells to the beginning of the cache
+// removing any empty segments that may have been left by previous KV cache operations
+// TODO: optimizations are possible:
+//   - multiple threads
+//   - avoid copying to the host memory when already there
+// TODO: can we do all this on-device?
+static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
+    auto & kv_self = lctx.kv_self;
+
+    const auto & hparams = lctx.model.hparams;
+
+    const uint32_t n_layer      = hparams.n_layer;
+    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+    const uint32_t n_kv         = llama_kv_cache_cell_max(kv_self);
+
+    const uint32_t kv_size = kv_self.size;
+
+    const int64_t t_start = ggml_time_us();
+
+    std::vector<uint8_t> buf_k;
+    std::vector<uint8_t> buf_v;
+
+    // the destination cell in the new KV cache
+    uint32_t id = 0;
+
+    // number of cells moved
+    uint32_t n_moves = 0;
+
+    // determine which KV cells to move where
+    std::vector<uint32_t> ids(n_kv, n_kv);
+
+    for (uint32_t i0 = 0; i0 < n_kv; ++i0) {
+        const auto & cell0 = kv_self.cells[i0];
+
+        if (!cell0.is_empty()) {
+            ids[i0] = id;
+
+            if (i0 != id) {
+                kv_self.cells[id] = cell0;
+                n_moves++;
+            }
+
+            id++;
+        }
+    }
+
+    if (n_moves == 0) {
+        return;
+    }
+
+    LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
+
+    kv_self.head = id;
+    kv_self.used = id;
+
+    // zero the rest of the cells
+    for (uint32_t i = id; i < n_kv; ++i) {
+        kv_self.cells[i] = llama_kv_cell();
+    }
+
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
+        const size_t k_size     = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_size);
+
+        const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
+        const size_t v_size    = ggml_row_size (kv_self.v_l[il]->type, n_embd_v_gqa*kv_size);
+
+        buf_k.resize(k_size);
+        buf_v.resize(v_size);
+
+        ggml_backend_tensor_get(kv_self.k_l[il], buf_k.data(), 0, buf_k.size());
+        ggml_backend_tensor_get(kv_self.v_l[il], buf_v.data(), 0, buf_v.size());
+
+        // batch move [i, i+nm) to [id, id+nm)
+        // note: cells can move only to a lower index
+        for (uint32_t i = 0; i < n_kv; ++i) {
+            const uint32_t id = ids[i];
+
+            if (i == id || id == n_kv) {
+                continue;
+            }
+
+            uint32_t nm = 1;
+
+            while (i + nm < n_kv && ids[i + nm] == id + nm) {
+                nm++;
+            }
+
+            // move keys
+            {
+                const int64_t os = i*k_size_row;
+                const int64_t od = id*k_size_row;
+
+                memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
+            }
+
+            // move values (note: they are transposed)
+            {
+                const int64_t os = i;
+                const int64_t od = id;
+
+                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                    memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
+                }
+            }
+
+            i += nm - 1;
+        }
+
+        ggml_backend_tensor_set(kv_self.k_l[il], buf_k.data(), 0, buf_k.size());
+        ggml_backend_tensor_set(kv_self.v_l[il], buf_v.data(), 0, buf_v.size());
+    }
+
+    const int64_t t_end = ggml_time_us();
+
+    LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
+}
+
 static void llama_kv_cache_update_internal(struct llama_context & lctx) {
     // apply K-shift if needed
     if
(lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) { @@ -8051,240 +8409,19 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { } } - // compress the KV cache data if needed: - // - // - determine which KV cell pairs (i0, i1) to merge: - // - // abs(cell[i0].pos - cell[i1].pos) <= compress_delta - // - // - move the KV cache to the Host memory for easier maniiplation - // - processing is done layer-by-layer - // - convert the KV data to F32 - // - merge the KV data (different ways to merge) - // - convert the KV data back to the original type - // - move the KV cache back to the device memory - // - update the KV cache metadata - // - // as a side effect, the new KV cache is defragmented - // + // compress the KV cache data if needed if (lctx.kv_self.compress_delta >= 0) { - auto & kv_self = lctx.kv_self; + llama_kv_cache_compress_internal(lctx); - const auto & hparams = lctx.model.hparams; + lctx.kv_self.compress_delta = -1; + lctx.kv_self.do_defrag = false; + } - const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - const uint32_t n_embd_head_k = hparams.n_embd_head_k; GGML_UNUSED(n_embd_head_k); - const uint32_t n_embd_head_v = hparams.n_embd_head_v; GGML_UNUSED(n_embd_head_v); - const uint32_t n_head_kv = hparams.n_head_kv; GGML_UNUSED(n_head_kv); - const uint32_t kv_size = kv_self.size; + // defragment the KV cache if needed + if (lctx.kv_self.do_defrag) { + llama_kv_cache_defrag_internal(lctx); - std::vector buf_q; - - std::vector buf_src_f32; - std::vector buf_dst_f32; - - const int64_t t_start = ggml_time_us(); - - struct c_pair { uint32_t i0, i1; }; - struct c_info { bool merged; uint32_t id, cnt, r; }; - - std::vector infos(kv_size, { false, 0, 0, 0 }); - - // the destination cell in the new KV cache - uint32_t id = 0; - - // number of pairs merged - uint32_t n_merges = 0; - - // determine which KV cells to merge - for (uint32_t i0 = 0; i0 < kv_size; ++i0) { - const auto & cell0 = kv_self.cells[i0]; - - if (!cell0.is_empty() && !infos[i0].merged) { - infos[i0] = { true, id, 0, 0 }; - infos[id].cnt = 1; - - const llama_pos p0 = cell0.pos; - - for (uint32_t i1 = i0 + 1; i1 < kv_size; ++i1) { - const auto & cell1 = kv_self.cells[i1]; - - if (i0 != i1 && cell0.is_same_seq(cell1)) { - const llama_pos p1 = cell1.pos; - - if (std::abs(p0 - p1) <= kv_self.compress_delta) { - infos[i1] = { true, id, 0, 0 }; - infos[id].cnt++; - n_merges++; - } - } - } - - if (i0 != id) { - kv_self.cells[id] = cell0; - } - - id++; - } - } - - kv_self.head = id; - kv_self.used = id; - - for (uint32_t i = id; i < kv_size; ++i) { - kv_self.cells[i] = llama_kv_cell(); - } - - LLAMA_LOG_INFO("(tmp log) KV compress pairs: %u\n", n_merges); - - ggml_type_traits_t tt_k; - ggml_type_traits_t tt_v; - - tt_k = ggml_internal_get_type_traits(kv_self.type_k); - tt_v = ggml_internal_get_type_traits(kv_self.type_v); - - for (uint32_t il = 0; il < n_layer; ++il) { - for (uint32_t i = 0; i < kv_size; ++i) { - infos[i].r = 0; - } - - // update keys - { - const int64_t ne = n_embd_k_gqa*kv_size; - - const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, ne); - - buf_q.resize(k_size); - - buf_src_f32.resize(ne); - buf_dst_f32.resize(ne); - - ggml_backend_tensor_get(kv_self.k_l[il], buf_q.data(), 0, buf_q.size()); - - tt_k.to_float(buf_q.data(), buf_src_f32.data(), ne); - - std::fill(buf_dst_f32.begin(), buf_dst_f32.end(), 0); - - for (uint32_t i = 0; i < kv_size; ++i) { 
- if (!infos[i].merged) { - continue; - } - - const uint32_t id = infos[i].id; - -#if 1 - // merge using averaging - { - const float scale = 1.0f/float(infos[id].cnt); - - const int64_t os = i*n_embd_k_gqa; - const int64_t od = id*n_embd_k_gqa; - - for (uint32_t j = 0; j < n_embd_k_gqa; ++j) { - buf_dst_f32[od + j] += buf_src_f32[os + j]*scale; - } - } -#else - // merge separate heads - { - for (uint32_t h = 0; h < n_head_kv; ++h) { - if ((h + il) % infos[id].cnt != infos[id].r) { - continue; - } - - const int64_t os = i*n_embd_k_gqa + h*n_embd_head_k; - const int64_t od = id*n_embd_k_gqa + h*n_embd_head_k; - - for (uint32_t j = 0; j < n_embd_head_k; ++j) { - buf_dst_f32[od + j] = buf_src_f32[os + j]; - } - } - } - - infos[id].r++; -#endif - } - - tt_k.from_float(buf_dst_f32.data(), buf_q.data(), ne); - - ggml_backend_tensor_set(kv_self.k_l[il], buf_q.data(), 0, buf_q.size()); - } - - for (uint32_t i = 0; i < kv_size; ++i) { - infos[i].r = 0; - } - - // update values (note: they are transposed) - { - const int64_t ne = n_embd_v_gqa*kv_size; - - const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, ne); - - buf_q.resize(v_size); - - buf_src_f32.resize(ne); - buf_dst_f32.resize(ne); - - ggml_backend_tensor_get(kv_self.v_l[il], buf_q.data(), 0, buf_q.size()); - - tt_v.to_float(buf_q.data(), buf_src_f32.data(), ne); - - std::fill(buf_dst_f32.begin(), buf_dst_f32.end(), 0); - - for (uint32_t i = 0; i < kv_size; ++i) { - if (!infos[i].merged) { - continue; - } - - const uint32_t id = infos[i].id; - -#if 1 - // merge using averaging - { - const float scale = 1.0f/float(infos[id].cnt); - //printf("i: %d -> id: %d, scale: %f\n", i, id, scale); - - const int64_t os = i; - const int64_t od = id; - - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - buf_dst_f32[od + j*kv_size] += buf_src_f32[os + j*kv_size]*scale; - } - } -#else - // merge separate heads - { - for (uint32_t h = 0; h < n_head_kv; ++h) { - if ((h + il) % infos[id].cnt != infos[id].r) { - continue; - } - - const int64_t os = i; - const int64_t od = id; - - for (uint32_t j = h*n_embd_head_v; j < (h + 1)*n_embd_head_v; ++j) { - buf_dst_f32[od + j*kv_size] = buf_src_f32[os + j*kv_size]; - } - } - } - - infos[id].r++; -#endif - } - - tt_v.from_float(buf_dst_f32.data(), buf_q.data(), ne); - - ggml_backend_tensor_set(kv_self.v_l[il], buf_q.data(), 0, buf_q.size()); - } - } - - const int64_t t_end = ggml_time_us(); - - LLAMA_LOG_INFO("(tmp log) KV compress time: %.3f ms\n", (t_end - t_start)/1000.0); - - kv_self.compress_delta = -1; + lctx.kv_self.do_defrag = false; } } @@ -12360,6 +12497,10 @@ void llama_kv_cache_compress(struct llama_context * ctx, llama_pos delta) { llama_kv_cache_compress(ctx->kv_self, delta); } +void llama_kv_cache_defrag(struct llama_context * ctx) { + llama_kv_cache_defrag(ctx->kv_self); +} + void llama_kv_cache_update(struct llama_context * ctx) { llama_kv_cache_update_internal(*ctx); } diff --git a/llama.h b/llama.h index 8f959824f..862d555e2 100644 --- a/llama.h +++ b/llama.h @@ -555,11 +555,20 @@ extern "C" { llama_seq_id seq_id); // [EXPERIMENTAL] Compress the data in the KV cache + // This will be applied: + // - lazily on next llama_decode() + // - explicitly with llama_kv_cache_update() LLAMA_API void llama_kv_cache_compress( struct llama_context * ctx, llama_pos delta); - // Apply the KV cache updates (such as K-shifts) to the KV data + // Defragment the KV cache + // This will be applied: + // - lazily on next llama_decode() + // - explicitly with llama_kv_cache_update() + LLAMA_API void 
llama_kv_cache_defrag(struct llama_context * ctx);
+
+    // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
     LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
 
     //
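
For reference, a minimal usage sketch of the new call sequence, mirroring the passkey example above. It assumes ctx, n_keep, n_discard, n_ctx and n_past are already set up by the caller (as in passkey.cpp); per the header comments, the defrag is applied lazily on the next llama_decode() or immediately by llama_kv_cache_update():

    // discard n_discard tokens after the first n_keep, then shift the rest back
    llama_kv_cache_seq_rm (ctx, 0, n_keep, n_keep + n_discard);
    llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);

    llama_kv_cache_defrag (ctx);   // only schedules the defrag (sets do_defrag)
    llama_kv_cache_update (ctx);   // applies the K-shift and the defragmentation now

    n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;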