llama : add llama_kv_cache_defrag

This commit is contained in:
Georgi Gerganov 2024-02-25 15:00:45 +02:00
parent 9ec749df59
commit 65f21ec5d3
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 383 additions and 231 deletions

View file

@@ -183,6 +183,7 @@ int main(int argc, char ** argv) {
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
llama_kv_cache_defrag (ctx);
llama_kv_cache_update (ctx); llama_kv_cache_update (ctx);
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
@@ -213,6 +214,7 @@ int main(int argc, char ** argv) {
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
llama_kv_cache_defrag (ctx);
llama_kv_cache_update (ctx); llama_kv_cache_update (ctx);
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;

195
llama.cpp
View file

@@ -1722,6 +1722,7 @@ struct llama_kv_cell {
// ring-buffer of cached KV data // ring-buffer of cached KV data
struct llama_kv_cache { struct llama_kv_cache {
bool has_shift = false; bool has_shift = false;
bool do_defrag = false;
// Note: The value of head isn't only used to optimize searching // Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_internal also uses it, so it // for a free KV slot. llama_decode_internal also uses it, so it
@@ -2278,6 +2279,10 @@ static void llama_kv_cache_compress(struct llama_kv_cache & cache, llama_pos del
cache.compress_delta = delta; cache.compress_delta = delta;
} }
// mark the cache for defragmentation; the actual work is performed lazily,
// on the next llama_kv_cache_update() / llama_decode() (see do_defrag consumers)
static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
cache.do_defrag = true;
}
// //
// model loading and saving // model loading and saving
// //
@@ -8029,29 +8034,7 @@ static int llama_decode_internal(
return 0; return 0;
} }
static void llama_kv_cache_update_internal(struct llama_context & lctx) { // summary:
// apply K-shift if needed
if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
llama_set_k_shift(lctx);
{
ggml_cgraph * gf = llama_build_graph_k_shift(lctx);
llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
}
{
auto & kv_self = lctx.kv_self;
kv_self.has_shift = false;
for (uint32_t i = 0; i < kv_self.size; ++i) {
kv_self.cells[i].delta = 0;
}
}
}
// compress the KV cache data if needed:
// //
// - determine which KV cell pairs (i0, i1) to merge: // - determine which KV cell pairs (i0, i1) to merge:
// //
@@ -8067,7 +8050,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
// //
// as a side effect, the new KV cache is defragmented // as a side effect, the new KV cache is defragmented
// //
if (lctx.kv_self.compress_delta >= 0) { static void llama_kv_cache_compress_internal(struct llama_context & lctx) {
auto & kv_self = lctx.kv_self; auto & kv_self = lctx.kv_self;
const auto & hparams = lctx.model.hparams; const auto & hparams = lctx.model.hparams;
@@ -8080,13 +8063,13 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
const uint32_t n_head_kv = hparams.n_head_kv; GGML_UNUSED(n_head_kv); const uint32_t n_head_kv = hparams.n_head_kv; GGML_UNUSED(n_head_kv);
const uint32_t kv_size = kv_self.size; const uint32_t kv_size = kv_self.size;
const int64_t t_start = ggml_time_us();
std::vector<uint8_t> buf_q; std::vector<uint8_t> buf_q;
std::vector<float> buf_src_f32; std::vector<float> buf_src_f32;
std::vector<float> buf_dst_f32; std::vector<float> buf_dst_f32;
const int64_t t_start = ggml_time_us();
struct c_pair { uint32_t i0, i1; }; struct c_pair { uint32_t i0, i1; };
struct c_info { bool merged; uint32_t id, cnt, r; }; struct c_info { bool merged; uint32_t id, cnt, r; };
@@ -8283,8 +8266,162 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
const int64_t t_end = ggml_time_us(); const int64_t t_end = ggml_time_us();
LLAMA_LOG_INFO("(tmp log) KV compress time: %.3f ms\n", (t_end - t_start)/1000.0); LLAMA_LOG_INFO("(tmp log) KV compress time: %.3f ms\n", (t_end - t_start)/1000.0);
}
kv_self.compress_delta = -1; // copy the KV cache to the host memory and reshuffle the cells to the beginning of the cache
// removing any empty segments that may have been left by previous KV cache operations
// TODO: optimizations are possible:
// - multiple threads
// - avoid copying to the host memory when already there
// TODO: can we do all this on-device?
// compact the KV cache: move all non-empty cells to the front so the used
// region is contiguous, then apply the same permutation to each layer's
// K/V tensor data via a host-side round trip
static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
auto & kv_self = lctx.kv_self;
const auto & hparams = lctx.model.hparams;
const uint32_t n_layer = hparams.n_layer;
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
// n_kv: presumably an upper bound on indices of cells currently in use
// (based on the name llama_kv_cache_cell_max - TODO confirm)
const uint32_t n_kv = llama_kv_cache_cell_max(kv_self);
const uint32_t kv_size = kv_self.size;
const int64_t t_start = ggml_time_us();
// host-side staging buffers, reused across layers (resize keeps capacity)
std::vector<uint8_t> buf_k;
std::vector<uint8_t> buf_v;
// the destination cell in the new KV cache
uint32_t id = 0;
// number of cells moved
uint32_t n_moves = 0;
// determine which KV cells to move where
// ids[i] == n_kv marks cell i as empty (no destination)
std::vector<uint32_t> ids(n_kv, n_kv);
for (uint32_t i0 = 0; i0 < n_kv; ++i0) {
const auto & cell0 = kv_self.cells[i0];
if (!cell0.is_empty()) {
ids[i0] = id;
if (i0 != id) {
// compact the cell metadata in place; safe because id <= i0
kv_self.cells[id] = cell0;
n_moves++;
}
id++;
}
}
// already contiguous - nothing to do
if (n_moves == 0) {
return;
}
LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
// after compaction, exactly cells [0, id) are occupied
kv_self.head = id;
kv_self.used = id;
// zero the rest of the cells
for (uint32_t i = id; i < n_kv; ++i) {
kv_self.cells[i] = llama_kv_cell();
}
for (uint32_t il = 0; il < n_layer; ++il) {
// K is stored row-per-cell; k_size_row is one cell's worth of K data
const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_size);
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
const size_t v_size = ggml_row_size (kv_self.v_l[il]->type, n_embd_v_gqa*kv_size);
buf_k.resize(k_size);
buf_v.resize(v_size);
// pull the whole layer's K and V data to host memory
ggml_backend_tensor_get(kv_self.k_l[il], buf_k.data(), 0, buf_k.size());
ggml_backend_tensor_get(kv_self.v_l[il], buf_v.data(), 0, buf_v.size());
// batch move [i, i+nm) to [id, id+nm)
// note: cells can move only to a lower index
for (uint32_t i = 0; i < n_kv; ++i) {
// NOTE(review): this 'id' shadows the outer compaction counter 'id'
const uint32_t id = ids[i];
if (i == id || id == n_kv) {
continue;
}
// grow the run while source/destination stay contiguous, so one
// memcpy can move several cells at once
uint32_t nm = 1;
while (i + nm < n_kv && ids[i + nm] == id + nm) {
nm++;
}
// move keys
{
const int64_t os = i*k_size_row;
const int64_t od = id*k_size_row;
memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
}
// move values (note: they are transposed)
{
const int64_t os = i;
const int64_t od = id;
// V is transposed: elements of one cell are strided by kv_size,
// so each of the n_embd_v_gqa rows needs its own copy
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
}
}
// skip the cells already handled as part of this run
i += nm - 1;
}
// push the compacted data back to the backend
ggml_backend_tensor_set(kv_self.k_l[il], buf_k.data(), 0, buf_k.size());
ggml_backend_tensor_set(kv_self.v_l[il], buf_v.data(), 0, buf_v.size());
}
const int64_t t_end = ggml_time_us();
LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
}
// apply all pending lazy KV cache operations for this context:
// K-shift (RoPE), compression, and defragmentation, in that order
static void llama_kv_cache_update_internal(struct llama_context & lctx) {
    // apply K-shift if needed
    if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
        llama_set_k_shift(lctx);

        {
            ggml_cgraph * gf = llama_build_graph_k_shift(lctx);

            llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
        }

        // the shift has been applied to the K data - reset the bookkeeping
        {
            auto & kv_self = lctx.kv_self;

            kv_self.has_shift = false;

            for (uint32_t i = 0; i < kv_self.size; ++i) {
                kv_self.cells[i].delta = 0;
            }
        }
    }

    // compress the KV cache data if needed
    if (lctx.kv_self.compress_delta >= 0) {
        llama_kv_cache_compress_internal(lctx);

        lctx.kv_self.compress_delta = -1;
        // compression defragments the cache as a side effect,
        // so a pending defrag request is satisfied here as well
        lctx.kv_self.do_defrag = false;
    }

    // defragment the KV cache if needed
    if (lctx.kv_self.do_defrag) {
        llama_kv_cache_defrag_internal(lctx);

        lctx.kv_self.do_defrag = false;
    }
}
@@ -12360,6 +12497,10 @@ void llama_kv_cache_compress(struct llama_context * ctx, llama_pos delta) {
llama_kv_cache_compress(ctx->kv_self, delta); llama_kv_cache_compress(ctx->kv_self, delta);
} }
// public API: request a lazy defragmentation of the context's KV cache;
// forwards to the internal overload that sets kv_self.do_defrag
void llama_kv_cache_defrag(struct llama_context * ctx) {
llama_kv_cache_defrag(ctx->kv_self);
}
void llama_kv_cache_update(struct llama_context * ctx) { void llama_kv_cache_update(struct llama_context * ctx) {
llama_kv_cache_update_internal(*ctx); llama_kv_cache_update_internal(*ctx);
} }

11
llama.h
View file

@@ -555,11 +555,20 @@ extern "C" {
llama_seq_id seq_id); llama_seq_id seq_id);
// [EXPERIMENTAL] Compress the data in the KV cache // [EXPERIMENTAL] Compress the data in the KV cache
// This will be applied:
// - lazily on next llama_decode()
// - explicitly with llama_kv_cache_update()
LLAMA_API void llama_kv_cache_compress( LLAMA_API void llama_kv_cache_compress(
struct llama_context * ctx, struct llama_context * ctx,
llama_pos delta); llama_pos delta);
// Apply the KV cache updates (such as K-shifts) to the KV data // Defragment the KV cache
// This will be applied:
// - lazily on next llama_decode()
// - explicitly with llama_kv_cache_update()
LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx);
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
LLAMA_API void llama_kv_cache_update(struct llama_context * ctx); LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
// //