model: dbrx: fix expert reshape

This commit is contained in:
Pierrick HYMBERT 2024-04-07 19:12:24 +02:00
parent dbfd59114f
commit 7dd84b0924

View file

@@ -1548,7 +1548,7 @@ class DbrxModel(Model):
# Reshape experts tensors from 2D to 3D as expected by GeLU # Reshape experts tensors from 2D to 3D as expected by GeLU
if experts and n_dims == 2: if experts and n_dims == 2:
data = data.reshape((self.hparams["d_model"], self.hparams["ffn_config"]["ffn_hidden_size"], self.hparams["ffn_config"]["moe_num_experts"])) data = data.reshape((self.hparams["ffn_config"]["moe_num_experts"], self.hparams["ffn_config"]["ffn_hidden_size"], self.hparams["d_model"]))
n_dims = len(data.shape) n_dims = len(data.shape)
# if f32 desired, convert any float16 to float32 # if f32 desired, convert any float16 to float32