model: dbrx: fix expert reshape
This commit is contained in:
parent
dbfd59114f
commit
7dd84b0924
1 changed file with 1 addition and 1 deletion
|
@@ -1548,7 +1548,7 @@ class DbrxModel(Model):
|
|||
|
||||
# Reshape experts tensors from 2D to 3D as expected by GeLU
|
||||
if experts and n_dims == 2:
|
||||
- data = data.reshape((self.hparams["d_model"], self.hparams["ffn_config"]["ffn_hidden_size"], self.hparams["ffn_config"]["moe_num_experts"]))
|
||||
+ data = data.reshape((self.hparams["ffn_config"]["moe_num_experts"], self.hparams["ffn_config"]["ffn_hidden_size"], self.hparams["d_model"]))
|
||||
n_dims = len(data.shape)
|
||||
|
||||
# if f32 desired, convert any float16 to float32
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue