llama : optimize DeepSeek MLA implementation
This commit is contained in:
parent
f0ce53f158
commit
de538aa329
10 changed files with 96 additions and 41 deletions
|
@ -4136,6 +4136,29 @@ class DeepseekV2Model(Model):
|
|||
else:
|
||||
return []
|
||||
|
||||
if name.endswith("kv_b_proj.weight"):
|
||||
name_kb = name.replace("kv_b_proj", "k_b_proj")
|
||||
name_vb = name.replace("kv_b_proj", "v_b_proj")
|
||||
|
||||
n_head_kv = self.hparams["num_key_value_heads"]
|
||||
v_head_dim = self.hparams["v_head_dim"]
|
||||
qk_nope_head_dim = self.hparams["qk_nope_head_dim"]
|
||||
|
||||
assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)
|
||||
|
||||
kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
|
||||
k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
|
||||
k_b = k_b.transpose(1, 2);
|
||||
k_b = k_b.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim)
|
||||
v_b = v_b.reshape(n_head_kv * v_head_dim, data_torch.shape[-1])
|
||||
|
||||
return [
|
||||
(self.map_tensor_name(name), data_torch),
|
||||
(self.map_tensor_name(name_kb), k_b),
|
||||
(self.map_tensor_name(name_vb), v_b)
|
||||
]
|
||||
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
def prepare_tensors(self):
|
||||
|
|
|
@ -356,6 +356,8 @@ class MODEL_TENSOR(IntEnum):
|
|||
ATTN_Q_B = auto()
|
||||
ATTN_KV_A_MQA = auto()
|
||||
ATTN_KV_B = auto()
|
||||
ATTN_K_B = auto()
|
||||
ATTN_V_B = auto()
|
||||
ATTN_Q_A_NORM = auto()
|
||||
ATTN_KV_A_NORM = auto()
|
||||
FFN_SUB_NORM = auto()
|
||||
|
@ -543,6 +545,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
|||
MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
|
||||
MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
|
||||
MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
|
||||
MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b",
|
||||
MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b",
|
||||
MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
|
||||
MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
|
||||
MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",
|
||||
|
@ -1333,6 +1337,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.ATTN_Q_B,
|
||||
MODEL_TENSOR.ATTN_KV_A_MQA,
|
||||
MODEL_TENSOR.ATTN_KV_B,
|
||||
MODEL_TENSOR.ATTN_K_B,
|
||||
MODEL_TENSOR.ATTN_V_B,
|
||||
MODEL_TENSOR.ATTN_Q_A_NORM,
|
||||
MODEL_TENSOR.ATTN_KV_A_NORM,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
|
|
|
@ -586,6 +586,14 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_K_B: (
|
||||
"model.layers.{bid}.self_attn.k_b_proj", # deepseek2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_V_B: (
|
||||
"model.layers.{bid}.self_attn.v_b_proj", # deepseek2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_A_NORM: (
|
||||
"model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
|
||||
),
|
||||
|
|
|
@ -999,6 +999,8 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|||
{ LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
|
||||
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
|
||||
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
|
||||
{ LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
|
||||
{ LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
|
@ -1330,6 +1332,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
|||
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
|
@ -1347,6 +1351,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
|||
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
|
|
|
@ -277,6 +277,8 @@ enum llm_tensor {
|
|||
LLM_TENSOR_ATTN_Q_B,
|
||||
LLM_TENSOR_ATTN_KV_A_MQA,
|
||||
LLM_TENSOR_ATTN_KV_B,
|
||||
LLM_TENSOR_ATTN_K_B,
|
||||
LLM_TENSOR_ATTN_V_B,
|
||||
LLM_TENSOR_ATTN_Q_A_NORM,
|
||||
LLM_TENSOR_ATTN_KV_A_NORM,
|
||||
LLM_TENSOR_ATTN_SUB_NORM,
|
||||
|
|
|
@ -105,6 +105,7 @@ bool llama_kv_cache_init(
|
|||
// DeepSeek MLA
|
||||
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
|
||||
const uint32_t kv_lora_rank = hparams.n_lora_kv;
|
||||
LLAMA_LOG_DEBUG("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank);
|
||||
ggml_tensor * kr = ggml_new_tensor_1d(ctx, cache.type_kr, n_embd_head_qk_rope*kv_size);
|
||||
ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size);
|
||||
ggml_format_name(kr, "cache_kr_l%d", i);
|
||||
|
|
|
@ -49,8 +49,8 @@ struct llama_kv_cache {
|
|||
ggml_type type_k = GGML_TYPE_F16;
|
||||
ggml_type type_v = GGML_TYPE_F16;
|
||||
|
||||
ggml_type type_kr = GGML_TYPE_F32;
|
||||
ggml_type type_kv = GGML_TYPE_F32;
|
||||
ggml_type type_kr = GGML_TYPE_F16;
|
||||
ggml_type type_kv = GGML_TYPE_F16;
|
||||
|
||||
std::vector<llama_kv_cell> cells;
|
||||
|
||||
|
|
|
@ -2870,6 +2870,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
|
||||
layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0);
|
||||
layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
|
||||
layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, n_head * kv_lora_rank}, 0);
|
||||
layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_head * n_embd_head_v}, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
|
|
@ -161,6 +161,8 @@ struct llama_layer {
|
|||
struct ggml_tensor * wq_b = nullptr;
|
||||
struct ggml_tensor * wkv_a_mqa = nullptr;
|
||||
struct ggml_tensor * wkv_b = nullptr;
|
||||
struct ggml_tensor * wk_b = nullptr;
|
||||
struct ggml_tensor * wv_b = nullptr;
|
||||
struct ggml_tensor * wq_cross = nullptr;
|
||||
struct ggml_tensor * wk_cross = nullptr;
|
||||
struct ggml_tensor * wv_cross = nullptr;
|
||||
|
|
|
@ -6483,24 +6483,6 @@ struct llm_build_context {
|
|||
0);
|
||||
cb(kv_cache, "kv_cache", il);
|
||||
|
||||
// {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
|
||||
struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache);
|
||||
cb(kv, "kv", il);
|
||||
|
||||
// split into {n_head * n_embd_head_qk_nope, n_tokens}
|
||||
struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_kv,
|
||||
ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
|
||||
ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
|
||||
0);
|
||||
cb(k_nope, "k_nope", il);
|
||||
|
||||
// and {n_head * n_embd_head_v, n_tokens}
|
||||
struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_kv,
|
||||
ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
|
||||
ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
|
||||
ggml_row_size(kv->type, (n_embd_head_qk_nope)));
|
||||
cb(v_states, "v_states", il);
|
||||
|
||||
q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
|
||||
q_pe = ggml_rope_ext(
|
||||
ctx0, q_pe, inp_pos, nullptr,
|
||||
|
@ -6524,9 +6506,6 @@ struct llm_build_context {
|
|||
// note: storing RoPE-ed version of K^R in the KV cache
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_pe, kr_cache_view));
|
||||
|
||||
struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
|
||||
cb(q_states, "q_states", il);
|
||||
|
||||
struct ggml_tensor * kr_cache =
|
||||
ggml_view_2d(ctx0, kv_self.kr_l[il],
|
||||
n_embd_head_qk_rope, n_kv,
|
||||
|
@ -6534,36 +6513,62 @@ struct llm_build_context {
|
|||
0);
|
||||
cb(kr_cache, "kr_cache", il);
|
||||
|
||||
// TODO is there a better way?
|
||||
struct ggml_tensor * kr_rep_shape = ggml_new_tensor_3d(ctx0, kr_cache->type, kr_cache->ne[0], kr_cache->ne[1], n_head);
|
||||
struct ggml_tensor * kr_rep = ggml_repeat(ctx0, kr_cache, kr_rep_shape);
|
||||
kr_rep = ggml_permute(ctx0, kr_rep, 0, 2, 1, 3);
|
||||
struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, kr_rep, 0);
|
||||
cb(k_states, "k_states", il);
|
||||
struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope), 0);
|
||||
cb(wk_b, "wk_b", il);
|
||||
|
||||
q_states = ggml_permute(ctx0, q_states, 0, 2, 1, 3);
|
||||
cb(q_states, "q_states", il);
|
||||
struct ggml_tensor * q_nope_perm = ggml_permute(ctx0, q_nope, 0, 2, 3, 1);
|
||||
cb(q_nope_perm, "q_nope_perm", il);
|
||||
|
||||
k_states = ggml_permute(ctx0, k_states, 0, 2, 1, 3);
|
||||
cb(k_states, "k_states", il);
|
||||
struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope_perm);
|
||||
cb(q_nope2, "q_nope2", il);
|
||||
|
||||
struct ggml_tensor * kq = ggml_mul_mat(ctx0, k_states, q_states);
|
||||
cb(kq, "kq", il);
|
||||
struct ggml_tensor * q_nope2_perm = ggml_permute(ctx0, q_nope2, 0, 1, 3, 2);
|
||||
cb(q_nope2_perm, "q_nope2_perm", il);
|
||||
|
||||
struct ggml_tensor * kv_cache_perm = ggml_cont(ctx0, ggml_permute(ctx0, kv_cache, 1, 0, 2, 3));
|
||||
cb(kv_cache_perm, "kv_cache_perm", il);
|
||||
|
||||
struct ggml_tensor * scores1 = ggml_mul_mat(ctx0, kv_cache, q_nope2_perm);
|
||||
cb(scores1, "scores1", il);
|
||||
|
||||
struct ggml_tensor * q_pe_perm = ggml_permute(ctx0, q_pe, 0, 3, 2, 1);
|
||||
cb(q_pe_perm, "q_pe_perm", il);
|
||||
|
||||
struct ggml_tensor * kr_cache_perm = ggml_permute(ctx0, kr_cache, 0, 2, 3, 1);
|
||||
cb(kr_cache_perm, "kr_cache_perm", il);
|
||||
|
||||
struct ggml_tensor * scores2 = ggml_mul_mat(ctx0, kr_cache, q_pe_perm);
|
||||
cb(scores2, "scores2", il);
|
||||
|
||||
struct ggml_tensor * scores = ggml_add(ctx0, scores1, scores2);
|
||||
cb(scores, "scores", il);
|
||||
|
||||
struct ggml_tensor * kq = ggml_permute(ctx0, scores, 0, 3, 1, 2);
|
||||
|
||||
struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0);
|
||||
cb(wv_b, "wv_b", il);
|
||||
|
||||
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
|
||||
cb(kq, "kq_soft_max_ext", il);
|
||||
|
||||
v_states = ggml_permute(ctx0, v_states, 1, 2, 0, 3);
|
||||
cb(v_states, "v_states", il);
|
||||
struct ggml_tensor * kq_perm = ggml_permute(ctx0, kq, 0, 2, 3, 1);
|
||||
cb(kq_perm, "kq_perm", il);
|
||||
|
||||
v_states = ggml_cont(ctx0, v_states);
|
||||
struct ggml_tensor * kqv1 = ggml_mul_mat(ctx0, kv_cache_perm, kq_perm);
|
||||
cb(kqv1, "kqv1", il);
|
||||
|
||||
struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v_states, kq);
|
||||
cb(kqv, "kqv", il);
|
||||
struct ggml_tensor * kqv1_trans = ggml_permute(ctx0, kqv1, 0, 1, 3, 2);
|
||||
cb(kqv1_trans, "kqv1_trans", il);
|
||||
|
||||
struct ggml_tensor * kqv2 = ggml_mul_mat(ctx0, wv_b, kqv1_trans);
|
||||
cb(kqv2, "kqv2", il);
|
||||
|
||||
struct ggml_tensor * kqv2_trans = ggml_permute(ctx0, kqv2, 0, 3, 2, 1);
|
||||
cb(kqv2_trans, "kqv2_trans", il);
|
||||
|
||||
GGML_ASSERT(kv_self.size == n_ctx);
|
||||
|
||||
struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
|
||||
struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv2_trans, 0, 2, 1, 3);
|
||||
cb(kqv_merged, "kqv_merged", il);
|
||||
|
||||
cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue