Merge 76543311ac
into d7b31a9d84
This commit is contained in:
commit
8bb3cb5e9c
10 changed files with 203 additions and 30 deletions
|
@ -4136,6 +4136,28 @@ class DeepseekV2Model(Model):
|
|||
else:
|
||||
return []
|
||||
|
||||
if name.endswith("kv_b_proj.weight"):
|
||||
name_kb = name.replace("kv_b_proj", "k_b_proj")
|
||||
name_vb = name.replace("kv_b_proj", "v_b_proj")
|
||||
|
||||
n_head_kv = self.hparams["num_key_value_heads"]
|
||||
v_head_dim = self.hparams["v_head_dim"]
|
||||
qk_nope_head_dim = self.hparams["qk_nope_head_dim"]
|
||||
|
||||
assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)
|
||||
|
||||
kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
|
||||
k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
|
||||
k_b = k_b.transpose(1, 2)
|
||||
k_b = k_b.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim)
|
||||
v_b = v_b.reshape(n_head_kv * v_head_dim, data_torch.shape[-1])
|
||||
|
||||
return [
|
||||
(self.map_tensor_name(name), data_torch),
|
||||
(self.map_tensor_name(name_kb), k_b),
|
||||
(self.map_tensor_name(name_vb), v_b)
|
||||
]
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
def prepare_tensors(self):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue