llama : optimize DeepSeek MLA implementation

Stanisław Szymczyk 2025-01-25 18:10:22 +01:00
parent f0ce53f158
commit de538aa329
10 changed files with 96 additions and 41 deletions

@@ -4136,6 +4136,29 @@ class DeepseekV2Model(Model):
            else:
                return []

        if name.endswith("kv_b_proj.weight"):
            name_kb = name.replace("kv_b_proj", "k_b_proj")
            name_vb = name.replace("kv_b_proj", "v_b_proj")

            n_head_kv = self.hparams["num_key_value_heads"]
            v_head_dim = self.hparams["v_head_dim"]
            qk_nope_head_dim = self.hparams["qk_nope_head_dim"]

            assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)

            kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
            k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
            k_b = k_b.transpose(1, 2)
            k_b = k_b.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim)
            v_b = v_b.reshape(n_head_kv * v_head_dim, data_torch.shape[-1])

            return [
                (self.map_tensor_name(name), data_torch),
                (self.map_tensor_name(name_kb), k_b),
                (self.map_tensor_name(name_vb), v_b)
            ]

        return [(self.map_tensor_name(name), data_torch)]

    def prepare_tensors(self):
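For context, the hunk above splits the fused kv_b_proj checkpoint weight of shape (n_head_kv * (qk_nope_head_dim + v_head_dim), kv_lora_rank) into separate per-head k_b and v_b factors so the optimized MLA path can apply them against the compressed KV latent directly. The following standalone sketch walks through the same reshaping with small made-up dimensions (illustrative values only, not DeepSeek's real hyperparameters) to show the resulting tensor shapes:

import torch

# Illustrative dimensions only (not the actual DeepSeek-V2 config).
n_head_kv        = 4    # num_key_value_heads
qk_nope_head_dim = 8    # per-head non-rotary key dimension
v_head_dim       = 6    # per-head value dimension
kv_lora_rank     = 16   # width of the compressed KV latent

# Fused kv_b_proj weight as stored in the HF checkpoint:
# rows are the [k_nope | v] blocks stacked per head, columns are the latent rank.
kv_b_proj = torch.randn(n_head_kv * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# Same steps as in modify_tensors() above.
kv_b = kv_b_proj.view(n_head_kv, qk_nope_head_dim + v_head_dim, kv_lora_rank)
k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
k_b = k_b.transpose(1, 2).reshape(n_head_kv * kv_lora_rank, qk_nope_head_dim)
v_b = v_b.reshape(n_head_kv * v_head_dim, kv_lora_rank)

print(k_b.shape)  # torch.Size([64, 8])  -> (n_head_kv * kv_lora_rank, qk_nope_head_dim)
print(v_b.shape)  # torch.Size([24, 16]) -> (n_head_kv * v_head_dim, kv_lora_rank)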