llama: dbrx: fix experts 3D tensor layout (again)

Pierrick HYMBERT 2024-04-08 19:37:23 +02:00
parent 18a84fedda
commit 9968952921

@@ -1522,16 +1522,16 @@ class DbrxModel(Model):
         n_ff = self.hparams["ffn_config"]["ffn_hidden_size"]
         n_embd = self.hparams["d_model"]
-        # Specific behavior for experts tensors: suffix .weight, reshape to 3D
+        # Specific behavior for experts tensors: suffix .weight, reshape to 3D and transpose
         # original implementation expects (n_expert, n_ff, n_embd)
-        exp_tensor_names = {"ffn.experts.mlp.v1",  # LLM_TENSOR_FFN_GATE_EXPS ne {n_embd, n_ff, n_expert}
-                            "ffn.experts.mlp.w2",  # LLM_TENSOR_FFN_DOWN_EXPS ne {n_ff, n_embd, n_expert}
-                            "ffn.experts.mlp.w1"}  # LLM_TENSOR_FFN_UP_EXPS   ne {n_embd, n_ff, n_expert}
+        exp_tensor_names = {"ffn.experts.mlp.v1": (0, 1, 2),  # LLM_TENSOR_FFN_GATE_EXPS ne {n_embd, n_ff, n_expert}
+                            "ffn.experts.mlp.w2": (0, 2, 1),  # LLM_TENSOR_FFN_DOWN_EXPS ne {n_ff, n_embd, n_expert}
+                            "ffn.experts.mlp.w1": (0, 1, 2)}  # LLM_TENSOR_FFN_UP_EXPS   ne {n_embd, n_ff, n_expert}
         experts = False
-        for exp_tensor_name in exp_tensor_names:
+        for exp_tensor_name in exp_tensor_names.keys():
             if name.find(exp_tensor_name) != -1 and name.find(".weight") == -1:
                 experts = True
-                data_torch = data_torch.view(n_expert, n_ff, n_embd)
+                data_torch = data_torch.view(n_expert, n_ff, n_embd).permute(*exp_tensor_names[exp_tensor_name])
                 break
         old_dtype = data_torch.dtype
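
For readers following the layout change: below is a minimal, standalone sketch (not part of the commit) of what the new view + permute does to one experts tensor. The tensor sizes and the `perms` dict are hypothetical stand-ins. Only w2 (the FFN down projection) swaps its last two dimensions, because ggml lists dimensions in the reverse order of PyTorch, so the down projection must end up with n_ff innermost to yield ne {n_ff, n_embd, n_expert}.

```python
# Minimal sketch with toy dimensions; not taken from the commit.
import torch

n_expert, n_ff, n_embd = 4, 8, 6  # hypothetical small sizes

# An experts tensor as it might arrive flattened from the HF checkpoint
w2_flat = torch.randn(n_expert * n_ff * n_embd)

# Same permutations as exp_tensor_names above: only w2 swaps its last two dims
perms = {"v1": (0, 1, 2), "w2": (0, 2, 1), "w1": (0, 1, 2)}

w2 = w2_flat.view(n_expert, n_ff, n_embd).permute(*perms["w2"])
print(w2.shape)  # torch.Size([4, 6, 8]), i.e. (n_expert, n_embd, n_ff)

# Reversed for ggml, that is ne {n_ff, n_embd, n_expert}, matching the
# LLM_TENSOR_FFN_DOWN_EXPS comment in the diff. v1/w1 keep (0, 1, 2),
# which reversed gives ne {n_embd, n_ff, n_expert}.
```

Note that `permute` only swaps strides rather than copying; the actual data movement happens whenever the tensor is next materialized in contiguous form.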