llama : add a second copy of c^KV cache in DeepSeek2 MLA to avoid transposing the cache during inference
This commit is contained in:
parent
ce730637e8
commit
202f323e66
3 changed files with 19 additions and 4 deletions
|
@ -53,7 +53,7 @@ bool llama_kv_cache_init(
|
|||
auto it = ctx_map.find(buft);
|
||||
if (it == ctx_map.end()) {
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ size_t(4u*n_layer*ggml_tensor_overhead()),
|
||||
/*.mem_size =*/ size_t(5u*n_layer*ggml_tensor_overhead()),
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
@ -74,6 +74,7 @@ bool llama_kv_cache_init(
|
|||
// DeepSeek MLA
|
||||
cache.kr_l.reserve(n_layer);
|
||||
cache.kv_l.reserve(n_layer);
|
||||
cache.kvt_l.reserve(n_layer);
|
||||
|
||||
for (int i = 0; i < n_layer; i++) {
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
|
||||
|
@ -108,10 +109,13 @@ bool llama_kv_cache_init(
|
|||
LLAMA_LOG_DEBUG("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank);
|
||||
ggml_tensor * kr = ggml_new_tensor_1d(ctx, cache.type_kr, n_embd_head_qk_rope*kv_size);
|
||||
ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size);
|
||||
ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size);
|
||||
ggml_format_name(kr, "cache_kr_l%d", i);
|
||||
ggml_format_name(kv, "cache_kv_l%d", i);
|
||||
ggml_format_name(kvt, "cache_kvt_l%d", i);
|
||||
cache.kr_l.push_back(kr);
|
||||
cache.kv_l.push_back(kv);
|
||||
cache.kvt_l.push_back(kvt);
|
||||
}
|
||||
|
||||
// allocate tensors and initialize the buffers to avoid NaNs in the padding
|
||||
|
|
|
@ -60,6 +60,7 @@ struct llama_kv_cache {
|
|||
// DeepSeek MLA
|
||||
std::vector<struct ggml_tensor *> kr_l; // per layer
|
||||
std::vector<struct ggml_tensor *> kv_l;
|
||||
std::vector<struct ggml_tensor *> kvt_l;
|
||||
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
|
|
|
@ -6476,6 +6476,12 @@ struct llm_build_context {
|
|||
// note: storing c^KV in the KV cache
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, kv_compressed, kv_cache_view));
|
||||
|
||||
struct ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_self.size), ggml_row_size(kv_self.kv_l[il]->type, kv_head));
|
||||
cb(kv_cache_trans_view, "kv_cache_trans_view", il);
|
||||
|
||||
// note: storing transposed c^KV in the transposed KV cache
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_transpose(ctx0, kv_compressed), kv_cache_trans_view));
|
||||
|
||||
struct ggml_tensor * kv_cache =
|
||||
ggml_view_2d(ctx0, kv_self.kv_l[il],
|
||||
kv_lora_rank, n_kv,
|
||||
|
@ -6483,6 +6489,13 @@ struct llm_build_context {
|
|||
0);
|
||||
cb(kv_cache, "kv_cache", il);
|
||||
|
||||
struct ggml_tensor * kv_cache_trans =
|
||||
ggml_view_2d(ctx0, kv_self.kvt_l[il],
|
||||
n_kv, kv_lora_rank,
|
||||
ggml_row_size(kv_self.kv_l[il]->type, kv_self.size),
|
||||
0);
|
||||
cb(kv_cache_trans, "kv_cache_trans", il);
|
||||
|
||||
q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
|
||||
q_pe = ggml_rope_ext(
|
||||
ctx0, q_pe, inp_pos, nullptr,
|
||||
|
@ -6552,9 +6565,6 @@ struct llm_build_context {
|
|||
struct ggml_tensor * kq_perm = ggml_permute(ctx0, kq, 0, 2, 3, 1);
|
||||
cb(kq_perm, "kq_soft_max_ext_perm", il);
|
||||
|
||||
struct ggml_tensor * kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache));
|
||||
cb(kv_cache_trans, "kv_cache_trans", il);
|
||||
|
||||
struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq_perm);
|
||||
cb(kqv_compressed, "kqv_compressed", il);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue