diff --git a/convert_grok.py b/convert_grok.py index 7d0dbfc77..4475da3c1 100644 --- a/convert_grok.py +++ b/convert_grok.py @@ -270,13 +270,9 @@ def convert_weight(name, weight, scales, config, dtype=torch.float32, device=Non else: weight = weight * scale - if name == "token_embd": - weight *= config.embedding_multiplier_scale - elif len(weight.shape) >= 2: + if name != "token_embd" and len(weight.shape) >= 2: # Transpose linear matrix weight = weight.transpose(-1, -2) - - if name.endswith("ffn_gate_inp") or name.endswith("_exps"): weight = weight[config.experts] # gather.