llama : experimental DeepSeek2 MLA implementation that caches latent kv representations
parent 6369f867a4
commit 93864cda8a

3 changed files with 106 additions and 16 deletions
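For a sense of scale: with MLA, the per-layer cache only needs the compressed latent c^KV (kv_lora_rank values per token) plus the shared, RoPE-ed decoupled key k^R (n_embd_head_qk_rope values per token); the regular k_l/v_l tensors are still allocated in this experimental version. Below is a minimal sketch of the per-token footprint, using assumed DeepSeek2-like hyperparameters (the concrete numbers are illustrative and not taken from this commit):

    // Rough per-token cache footprint comparison (illustrative only).
    #include <cstdint>
    #include <cstdio>

    int main() {
        // assumed DeepSeek2-like hyperparameters (not read from any model file)
        const uint32_t n_head              = 128;
        const uint32_t n_embd_head_qk_nope = 128;
        const uint32_t n_embd_head_qk_rope = 64;   // hparams.n_rot in the patch
        const uint32_t n_embd_head_v       = 128;
        const uint32_t kv_lora_rank        = 512;  // hparams.n_lora_kv in the patch

        // regular cache: full per-head K and V rows for every token
        const uint32_t regular = n_head * (n_embd_head_qk_nope + n_embd_head_qk_rope)  // K
                               + n_head * n_embd_head_v;                               // V

        // MLA cache: one c^KV latent row plus one shared k^R row per token
        const uint32_t mla = kv_lora_rank + n_embd_head_qk_rope;

        printf("regular K/V cache: %u elements per token per layer\n", regular);
        printf("MLA latent cache : %u elements per token per layer\n", mla);
        return 0;
    }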
@@ -53,7 +53,7 @@ bool llama_kv_cache_init(
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
             struct ggml_init_params params = {
-                /*.mem_size   =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
+                /*.mem_size   =*/ size_t(4u*n_layer*ggml_tensor_overhead()),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc   =*/ true,
             };
@@ -71,6 +71,10 @@ bool llama_kv_cache_init(
     cache.k_l.reserve(n_layer);
     cache.v_l.reserve(n_layer);
 
+    // DeepSeek MLA
+    cache.kr_l.reserve(n_layer);
+    cache.kv_l.reserve(n_layer);
+
     for (int i = 0; i < n_layer; i++) {
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
@@ -97,6 +101,16 @@ bool llama_kv_cache_init(
         ggml_format_name(v, "cache_v_l%d", i);
         cache.k_l.push_back(k);
         cache.v_l.push_back(v);
+
+        // DeepSeek MLA
+        const uint32_t n_embd_head_qk_rope = hparams.n_rot;
+        const uint32_t kv_lora_rank = hparams.n_lora_kv;
+        ggml_tensor * kr = ggml_new_tensor_1d(ctx, cache.type_kr, n_embd_head_qk_rope*kv_size);
+        ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size);
+        ggml_format_name(kr, "cache_kr_l%d", i);
+        ggml_format_name(kv, "cache_kv_l%d", i);
+        cache.kr_l.push_back(kr);
+        cache.kv_l.push_back(kv);
     }
 
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
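The two new per-layer tensors above are allocated flat (1D), one fixed-size row per cell of the kv_size-slot cache, so a token slot is located by a plain row-sized byte offset. A small sketch of that addressing, assuming the F32 cache types introduced in the header below; the helper is hypothetical and only mirrors the ggml_row_size(type, row_len)*kv_head arithmetic used later in the graph code:

    // Sketch of slot addressing in the flat per-layer MLA cache tensors
    // (hypothetical helper, not part of the patch).
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    // byte size of one row of `row_len` F32 elements (type_kr/type_kv default to F32)
    static size_t row_bytes(uint32_t row_len) {
        return (size_t) row_len * sizeof(float);
    }

    // byte offset where the batch starting at cell `kv_head` is written;
    // row_len is kv_lora_rank for kv_l and n_embd_head_qk_rope for kr_l
    static size_t cell_offset(uint32_t row_len, uint32_t kv_head) {
        return row_bytes(row_len) * kv_head;
    }

    int main() {
        printf("offset of cell 42 in a kv_lora_rank=512 cache: %zu bytes\n", cell_offset(512, 42));
        return 0;
    }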
@@ -49,11 +49,18 @@ struct llama_kv_cache {
     ggml_type type_k = GGML_TYPE_F16;
     ggml_type type_v = GGML_TYPE_F16;
 
+    ggml_type type_kr = GGML_TYPE_F32;
+    ggml_type type_kv = GGML_TYPE_F32;
+
     std::vector<llama_kv_cell> cells;
 
     std::vector<struct ggml_tensor *> k_l; // per layer
     std::vector<struct ggml_tensor *> v_l;
 
+    // DeepSeek MLA
+    std::vector<struct ggml_tensor *> kr_l; // per layer
+    std::vector<struct ggml_tensor *> kv_l;
+
     std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
 
@@ -8860,32 +8860,37 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(kv_compressed, "kv_compressed", il);
 
+                struct ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)*kv_head);
+                cb(kv_cache_view, "kv_cache_view", il);
+
+                // note: storing c^KV in the KV cache
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, kv_compressed, kv_cache_view));
+
+                struct ggml_tensor * kv_cache =
+                    ggml_view_2d(ctx0, kv_self.kv_l[il],
+                        kv_lora_rank, n_kv,
+                        ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank),
+                        0);
+                cb(kv_cache, "kv_cache", il);
+
                 // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
-                struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
+                struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache);
                 cb(kv, "kv", il);
 
                 // split into {n_head * n_embd_head_qk_nope, n_tokens}
-                struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
+                struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_kv,
                         ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
                         ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
                         0);
                 cb(k_nope, "k_nope", il);
 
                 // and {n_head * n_embd_head_v, n_tokens}
-                struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
+                struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_kv,
                         ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
                         ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
                         ggml_row_size(kv->type, (n_embd_head_qk_nope)));
                 cb(v_states, "v_states", il);
 
-                v_states = ggml_cont(ctx0, v_states);
-                cb(v_states, "v_states", il);
-
-                v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
-                    ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
-                    0);
-                cb(v_states, "v_states", il);
-
                 q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
                 q_pe = ggml_rope_ext(
                         ctx0, q_pe, inp_pos, nullptr,
@@ -8903,15 +8908,61 @@ struct llm_build_context {
                 );
                 cb(k_pe, "k_pe", il);
 
+                struct ggml_tensor * kr_cache_view = ggml_view_1d(ctx0, kv_self.kr_l[il], n_tokens*n_embd_head_qk_rope, ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope)*kv_head);
+                cb(kr_cache_view, "kr_cache_view", il);
+
+                // note: storing RoPE-ed version of K^R in the KV cache
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_pe, kr_cache_view));
+
                 struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
                 cb(q_states, "q_states", il);
 
-                struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
+                struct ggml_tensor * kr_cache =
+                    ggml_view_2d(ctx0, kv_self.kr_l[il],
+                        n_embd_head_qk_rope, n_kv,
+                        ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope),
+                        0);
+                cb(kr_cache, "kr_cache", il);
+
+                // TODO is there a better way?
+                struct ggml_tensor * kr_rep_shape = ggml_new_tensor_3d(ctx0, kr_cache->type, kr_cache->ne[0], kr_cache->ne[1], n_head);
+                struct ggml_tensor * kr_rep = ggml_repeat(ctx0, kr_cache, kr_rep_shape);
+                kr_rep = ggml_permute(ctx0, kr_rep, 0, 2, 1, 3);
+                struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, kr_rep, 0);
                 cb(k_states, "k_states", il);
 
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+                q_states = ggml_permute(ctx0, q_states, 0, 2, 1, 3);
+                cb(q_states, "q_states", il);
+
+                k_states = ggml_permute(ctx0, k_states, 0, 2, 1, 3);
+                cb(k_states, "k_states", il);
+
+                struct ggml_tensor * kq = ggml_mul_mat(ctx0, k_states, q_states);
+                cb(kq, "kq", il);
+
+                kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
+                cb(kq, "kq_soft_max_ext", il);
+
+                v_states = ggml_permute(ctx0, v_states, 1, 2, 0, 3);
+                cb(v_states, "v_states", il);
+
+                v_states = ggml_cont(ctx0, v_states);
+
+                struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v_states, kq);
+                cb(kqv, "kqv", il);
+
+                GGML_ASSERT(kv_self.size == n_ctx);
+
+                struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+                cb(kqv_merged, "kqv_merged", il);
+
+                cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
+                cb(cur, "kqv_merged_cont", il);
+
+                ggml_build_forward_expand(gf, cur);
+
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
+                cb(cur, "kqv_out", il);
             }
 
             if (il == n_layer - 1) {
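To summarize what the rebuilt attention block above computes, here is a paraphrase in DeepSeek-V2's MLA notation (not text from the patch; W^{UK}_i and W^{UV}_i stand for the per-head blocks of the model tensor wkv_b):

    $$
    c^{KV}_t = \mathrm{RMSNorm}\!\left(W^{DKV} h_t\right), \qquad
    k^{R}_t = \mathrm{RoPE}\!\left(W^{KR} h_t\right),
    $$
    $$
    k^{C}_{t,i} = W^{UK}_i\, c^{KV}_t, \qquad
    v_{t,i} = W^{UV}_i\, c^{KV}_t, \qquad
    k_{t,i} = \left[\,k^{C}_{t,i}\,;\,k^{R}_t\,\right].
    $$

Only $c^{KV}_t$ (stored in kv_l) and $k^{R}_t$ (stored in kr_l) go into the cache; keys and values for all n_kv cached positions are re-materialized from them through wkv_b on every graph build, and the ggml_repeat/ggml_permute pair broadcasts the single shared $k^{R}_t$ across the n_head heads before it is concatenated with the per-head $k^{C}_{t,i}$.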
@@ -12004,6 +12055,24 @@ struct llama_context * llama_new_context_with_model(
                 ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
         }
 
+        {
+            size_t memory_size_kr = 0;
+            size_t memory_size_kv = 0;
+
+            for (auto & kr : ctx->kv_self.kr_l) {
+                memory_size_kr += ggml_nbytes(kr);
+            }
+
+            for (auto & kv : ctx->kv_self.kv_l) {
+                memory_size_kv += ggml_nbytes(kv);
+            }
+
+            LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K^R (%s): %7.2f MiB, c^KV (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_kr + memory_size_kv) / (1024.0f * 1024.0f),
+                ggml_type_name(type_k), (float)memory_size_kr / (1024.0f * 1024.0f),
+                ggml_type_name(type_k), (float)memory_size_kv / (1024.0f * 1024.0f));
+        }
+
         // graph outputs buffer
         {
             // resized during inference when a batch uses more outputs