llama : experimental DeepSeek2 MLA implementation that caches latent kv representations

Stanisław Szymczyk 2025-01-22 15:19:34 +01:00
parent 6369f867a4
commit 93864cda8a
3 changed files with 106 additions and 16 deletions
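Background for reviewers: with MLA, the DeepSeek2 attention only needs to cache, per layer and per token, the compressed latent c^KV (kv_lora_rank values) plus the shared RoPE-ed key K^R (n_embd_head_qk_rope values); full keys and values are re-expanded from the latents through wkv_b at graph-build time, which is what the changes below implement. A minimal standalone sketch of the per-token footprint follows; the DeepSeek2-like head counts and dimensions in it are illustrative assumptions, not values read from any model.

// Per-token, per-layer cache footprint: naive full K/V vs. the latent cache added below.
// All dimensions are assumed, DeepSeek2-like example values.
#include <cstdio>

int main() {
    const int n_head              = 128;  // assumed
    const int n_embd_head_qk_nope = 128;  // assumed
    const int n_embd_head_qk_rope = 64;   // hparams.n_rot in the diff (assumed value)
    const int n_embd_head_v       = 128;  // assumed
    const int kv_lora_rank        = 512;  // hparams.n_lora_kv in the diff (assumed value)

    // naive cache: full K and V for every head, F16 (2 bytes per element)
    const int naive_bytes  = 2*(n_head*(n_embd_head_qk_nope + n_embd_head_qk_rope) + n_head*n_embd_head_v);

    // latent cache as allocated below: c^KV plus K^R, F32 (4 bytes per element)
    const int latent_bytes = 4*(kv_lora_rank + n_embd_head_qk_rope);

    printf("naive  K/V cache: %d bytes per token per layer\n", naive_bytes);
    printf("latent MLA cache: %d bytes per token per layer\n", latent_bytes);
    printf("ratio           : %.1fx\n", (double) naive_bytes / latent_bytes);
    return 0;
}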


@@ -53,7 +53,7 @@ bool llama_kv_cache_init(
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
             struct ggml_init_params params = {
-                /*.mem_size   =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
+                /*.mem_size   =*/ size_t(4u*n_layer*ggml_tensor_overhead()),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc   =*/ true,
             };
@@ -71,6 +71,10 @@ bool llama_kv_cache_init(
     cache.k_l.reserve(n_layer);
     cache.v_l.reserve(n_layer);
 
+    // DeepSeek MLA
+    cache.kr_l.reserve(n_layer);
+    cache.kv_l.reserve(n_layer);
+
     for (int i = 0; i < n_layer; i++) {
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
@@ -97,6 +101,16 @@ bool llama_kv_cache_init(
         ggml_format_name(v, "cache_v_l%d", i);
         cache.k_l.push_back(k);
         cache.v_l.push_back(v);
+
+        // DeepSeek MLA
+        const uint32_t n_embd_head_qk_rope = hparams.n_rot;
+        const uint32_t kv_lora_rank = hparams.n_lora_kv;
+        ggml_tensor * kr = ggml_new_tensor_1d(ctx, cache.type_kr, n_embd_head_qk_rope*kv_size);
+        ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size);
+        ggml_format_name(kr, "cache_kr_l%d", i);
+        ggml_format_name(kv, "cache_kv_l%d", i);
+        cache.kr_l.push_back(kr);
+        cache.kv_l.push_back(kv);
     }
 
     // allocate tensors and initialize the buffers to avoid NaNs in the padding


@@ -49,11 +49,18 @@ struct llama_kv_cache {
     ggml_type type_k = GGML_TYPE_F16;
     ggml_type type_v = GGML_TYPE_F16;
 
+    ggml_type type_kr = GGML_TYPE_F32;
+    ggml_type type_kv = GGML_TYPE_F32;
+
     std::vector<llama_kv_cell> cells;
 
     std::vector<struct ggml_tensor *> k_l; // per layer
     std::vector<struct ggml_tensor *> v_l;
 
+    // DeepSeek MLA
+    std::vector<struct ggml_tensor *> kr_l; // per layer
+    std::vector<struct ggml_tensor *> kv_l;
+
     std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;


@@ -8860,32 +8860,37 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(kv_compressed, "kv_compressed", il);
 
+                struct ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)*kv_head);
+                cb(kv_cache_view, "kv_cache_view", il);
+
+                // note: storing c^KV in the KV cache
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, kv_compressed, kv_cache_view));
+
+                struct ggml_tensor * kv_cache =
+                    ggml_view_2d(ctx0, kv_self.kv_l[il],
+                            kv_lora_rank, n_kv,
+                            ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank),
+                            0);
+                cb(kv_cache, "kv_cache", il);
+
                 // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
-                struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
+                struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache);
                 cb(kv, "kv", il);
 
                 // split into {n_head * n_embd_head_qk_nope, n_tokens}
-                struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
+                struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_kv,
                         ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
                         ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
                         0);
                 cb(k_nope, "k_nope", il);
 
                 // and {n_head * n_embd_head_v, n_tokens}
-                struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
+                struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_kv,
                         ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
                         ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
                         ggml_row_size(kv->type, (n_embd_head_qk_nope)));
                 cb(v_states, "v_states", il);
 
-                v_states = ggml_cont(ctx0, v_states);
-                cb(v_states, "v_states", il);
-
-                v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
-                        ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
-                        0);
-                cb(v_states, "v_states", il);
-
                 q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
                 q_pe = ggml_rope_ext(
                         ctx0, q_pe, inp_pos, nullptr,
@@ -8903,15 +8908,61 @@ struct llm_build_context {
                 );
                 cb(k_pe, "k_pe", il);
 
+                struct ggml_tensor * kr_cache_view = ggml_view_1d(ctx0, kv_self.kr_l[il], n_tokens*n_embd_head_qk_rope, ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope)*kv_head);
+                cb(kr_cache_view, "kr_cache_view", il);
+
+                // note: storing RoPE-ed version of K^R in the KV cache
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_pe, kr_cache_view));
+
                 struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
                 cb(q_states, "q_states", il);
 
-                struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
+                struct ggml_tensor * kr_cache =
+                    ggml_view_2d(ctx0, kv_self.kr_l[il],
+                            n_embd_head_qk_rope, n_kv,
+                            ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope),
+                            0);
+                cb(kr_cache, "kr_cache", il);
+
+                // TODO is there a better way?
+                struct ggml_tensor * kr_rep_shape = ggml_new_tensor_3d(ctx0, kr_cache->type, kr_cache->ne[0], kr_cache->ne[1], n_head);
+                struct ggml_tensor * kr_rep = ggml_repeat(ctx0, kr_cache, kr_rep_shape);
+                kr_rep = ggml_permute(ctx0, kr_rep, 0, 2, 1, 3);
+
+                struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, kr_rep, 0);
                 cb(k_states, "k_states", il);
 
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+                q_states = ggml_permute(ctx0, q_states, 0, 2, 1, 3);
+                cb(q_states, "q_states", il);
+
+                k_states = ggml_permute(ctx0, k_states, 0, 2, 1, 3);
+                cb(k_states, "k_states", il);
+
+                struct ggml_tensor * kq = ggml_mul_mat(ctx0, k_states, q_states);
+                cb(kq, "kq", il);
+
+                kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
+                cb(kq, "kq_soft_max_ext", il);
+
+                v_states = ggml_permute(ctx0, v_states, 1, 2, 0, 3);
+                cb(v_states, "v_states", il);
+
+                v_states = ggml_cont(ctx0, v_states);
+
+                struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v_states, kq);
+                cb(kqv, "kqv", il);
+
+                GGML_ASSERT(kv_self.size == n_ctx);
+
+                struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+                cb(kqv_merged, "kqv_merged", il);
+
+                cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
+                cb(cur, "kqv_merged_cont", il);
+
+                ggml_build_forward_expand(gf, cur);
+
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
+                cb(cur, "kqv_out", il);
             }
 
             if (il == n_layer - 1) {
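Since the hunk above replaces the llm_build_kv helper with a hand-rolled attention over the cached latents, here is a standalone shape walkthrough of that path in ggml's ne[] convention (ne[0] is the innermost dimension; ggml_permute moves input dimension i to position axis_i; ggml_mul_mat on a:{k,m,h} and b:{k,n,h} yields {m,n,h}). The dimensions are assumed, DeepSeek2-like example values, not code from the tree:

// Shape walkthrough of the hand-rolled MLA attention above, in ggml's ne[] convention
// (ne[0] = innermost dim). Dimensions are assumed, DeepSeek2-like example values.
#include <array>
#include <cstdio>

using shape = std::array<long, 3>;

// ggml_permute: input dimension i moves to position ax[i]
static shape permute(shape s, int a0, int a1, int a2) {
    shape r{};
    const int ax[3] = { a0, a1, a2 };
    for (int i = 0; i < 3; ++i) r[ax[i]] = s[i];
    return r;
}

// ggml_mul_mat: a:{k, m, h} x b:{k, n, h} -> {m, n, h}
static shape mul_mat(shape a, shape b) { return { a[1], b[1], b[2] }; }

static void show(const char * name, shape s) { printf("%-10s = {%ld, %ld, %ld}\n", name, s[0], s[1], s[2]); }

int main() {
    const long qk_nope = 128, qk_rope = 64, v_dim = 128, n_head = 16;  // assumed
    const long kv_lora_rank = 512, n_kv = 512, n_tokens = 8;           // assumed

    shape kv_cache = { kv_lora_rank, n_kv, 1 };                        // 2D view of the c^KV cache
    shape wkv_b    = { kv_lora_rank, n_head*(qk_nope + v_dim), 1 };
    show("kv", mul_mat(wkv_b, kv_cache));                              // decompressed K^nope/V for all cells

    shape k_nope   = { qk_nope, n_head, n_kv };                        // view into kv
    shape v_states = { v_dim,   n_head, n_kv };                        // view into kv
    shape q_states = { qk_nope + qk_rope, n_head, n_tokens };          // concat(q_nope, q_pe)
    shape k_states = { k_nope[0] + qk_rope, n_head, n_kv };            // concat(k_nope, repeated K^R)

    q_states = permute(q_states, 0, 2, 1);                             // {D, n_tokens, n_head}
    k_states = permute(k_states, 0, 2, 1);                             // {D, n_kv,     n_head}
    shape kq = mul_mat(k_states, q_states);
    show("kq", kq);                                                    // {n_kv, n_tokens, n_head}

    v_states = permute(v_states, 1, 2, 0);                             // {n_kv, v_dim, n_head}
    shape kqv = mul_mat(v_states, kq);
    show("kqv", kqv);                                                  // {v_dim, n_tokens, n_head}

    shape kqv_merged = permute(kqv, 0, 2, 1);                          // {v_dim, n_head, n_tokens}
    show("cur", { kqv_merged[0]*kqv_merged[1], kqv_merged[2], 1 });    // ggml_cont_2d
    return 0;
}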
@@ -12004,6 +12055,24 @@ struct llama_context * llama_new_context_with_model(
                       ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
         }
 
+        {
+            size_t memory_size_kr = 0;
+            size_t memory_size_kv = 0;
+
+            for (auto & kr : ctx->kv_self.kr_l) {
+                memory_size_kr += ggml_nbytes(kr);
+            }
+
+            for (auto & kv : ctx->kv_self.kv_l) {
+                memory_size_kv += ggml_nbytes(kv);
+            }
+
+            LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K^R (%s): %7.2f MiB, c^KV (%s): %7.2f MiB\n", __func__,
+                      (float)(memory_size_kr + memory_size_kv) / (1024.0f * 1024.0f),
+                      ggml_type_name(ctx->kv_self.type_kr), (float)memory_size_kr / (1024.0f * 1024.0f),
+                      ggml_type_name(ctx->kv_self.type_kv), (float)memory_size_kv / (1024.0f * 1024.0f));
+        }
+
         // graph outputs buffer
         {
             // resized during inference when a batch uses more outputs
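To put the new log line in context, a standalone estimate of the whole-context footprint of the kr_l/kv_l caches, mirroring the per-layer allocation in llama_kv_cache_init above (n_embd_head_qk_rope*kv_size and kv_lora_rank*kv_size F32 elements per layer). The layer count, context size and dimensions are assumed example values:

// Whole-context size of the kr_l/kv_l caches, mirroring the allocation above.
// n_layer, kv_size and the dimensions are assumed example values.
#include <cstdio>

int main() {
    const size_t n_layer             = 27;    // assumed
    const size_t kv_size             = 4096;  // assumed number of cache cells (context size)
    const size_t n_embd_head_qk_rope = 64;    // hparams.n_rot (assumed value)
    const size_t kv_lora_rank        = 512;   // hparams.n_lora_kv (assumed value)

    const size_t memory_size_kr = n_layer * n_embd_head_qk_rope * kv_size * sizeof(float);
    const size_t memory_size_kv = n_layer * kv_lora_rank        * kv_size * sizeof(float);

    printf("K^R  : %7.2f MiB\n", memory_size_kr / (1024.0 * 1024.0));
    printf("c^KV : %7.2f MiB\n", memory_size_kv / (1024.0 * 1024.0));
    printf("total: %7.2f MiB\n", (memory_size_kr + memory_size_kv) / (1024.0 * 1024.0));
    return 0;
}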