model: dbrx: fix expert reshape

Pierrick HYMBERT 2024-04-07 19:38:35 +02:00
parent 7dd84b0924
commit c9bddbf253

@@ -1516,12 +1516,19 @@ class DbrxModel(Model):
         block_count = self.hparams.get("n_layers")
         tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
         for name, data_torch in self.get_tensors():
+            n_expert = self.hparams["ffn_config"]["moe_num_experts"]
+            n_ff = self.hparams["ffn_config"]["ffn_hidden_size"]
+            n_embd = self.hparams["d_model"]
+
             # Specific behavior for experts tensors: reshape to 3D and add suffix .weight
-            exp_tensor_names = ["ffn.experts.mlp.v1", "ffn.experts.mlp.w1", "ffn.experts.mlp.w2"]
+            exp_tensor_names = {"ffn.experts.mlp.v1": (n_embd, n_ff, n_expert),  # LLM_TENSOR_FFN_GATE_EXPS
+                                "ffn.experts.mlp.w1": (n_embd, n_ff, n_expert),  # LLM_TENSOR_FFN_DOWN_EXPS
+                                "ffn.experts.mlp.w2": (n_ff, n_embd, n_expert)}  # LLM_TENSOR_FFN_UP_EXPS
             experts = False
-            for exp_tensor_name in exp_tensor_names:
+            for exp_tensor_name in exp_tensor_names.keys():
                 if name.find(exp_tensor_name) != -1 and name.find(".weight") == -1:
                     experts = True
+                    expert_reshape = exp_tensor_names[exp_tensor_name]
                     break

             old_dtype = data_torch.dtype
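
To make the new mapping concrete, here is a minimal, self-contained sketch of the name-to-shape lookup introduced above. The tensor name and the toy dimensions are illustrative stand-ins, not values from a real DBRX checkpoint:

    n_expert, n_ff, n_embd = 4, 8, 6  # toy dimensions

    exp_tensor_names = {"ffn.experts.mlp.v1": (n_embd, n_ff, n_expert),
                        "ffn.experts.mlp.w1": (n_embd, n_ff, n_expert),
                        "ffn.experts.mlp.w2": (n_ff, n_embd, n_expert)}

    name = "transformer.blocks.0.ffn.experts.mlp.w2"  # hypothetical checkpoint tensor name

    experts = False
    expert_reshape = None
    for exp_tensor_name in exp_tensor_names.keys():
        # Expert tensors are matched by substring; they carry no ".weight" suffix yet.
        if name.find(exp_tensor_name) != -1 and name.find(".weight") == -1:
            experts = True
            expert_reshape = exp_tensor_names[exp_tensor_name]
            break

    print(experts, expert_reshape)  # True (8, 6, 4)
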
@@ -1548,7 +1555,7 @@ class DbrxModel(Model):

             # Reshape experts tensors from 2D to 3D as expected by GeLU
             if experts and n_dims == 2:
-                data = data.reshape((self.hparams["ffn_config"]["moe_num_experts"], self.hparams["ffn_config"]["ffn_hidden_size"], self.hparams["d_model"]))
+                data = data.reshape(expert_reshape)
                 n_dims = len(data.shape)

             # if f32 desired, convert any float16 to float32
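
And the effect of the reshape itself, sketched with NumPy on a toy 2D tensor; the flattened (n_expert * n_ff, n_embd) source layout is an assumption for illustration:

    import numpy as np

    n_expert, n_ff, n_embd = 4, 8, 6           # same toy dimensions as above
    expert_reshape = (n_ff, n_embd, n_expert)  # as resolved above for "ffn.experts.mlp.w2"
    experts = True

    # Toy stand-in for a flattened expert weight loaded from the checkpoint.
    data = np.zeros((n_expert * n_ff, n_embd), dtype=np.float32)
    n_dims = len(data.shape)

    if experts and n_dims == 2:
        data = data.reshape(expert_reshape)  # per-tensor target instead of one hardcoded shape
        n_dims = len(data.shape)

    print(data.shape, n_dims)  # (8, 6, 4) 3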