convert: dbrx: remove wrong ATTN_OUT_NORM tensor, add output layer mapping

Pierrick HYMBERT 2024-04-06 20:30:30 +02:00
parent c8e6f903e0
commit 916b91852b
2 changed files with 9 additions and 8 deletions

gguf-py/gguf/constants.py

@@ -648,7 +648,6 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.ATTN_QKV,
         MODEL_TENSOR.ATTN_NORM,
         MODEL_TENSOR.ATTN_NORM_2,
-        MODEL_TENSOR.ATTN_OUT_NORM,
         MODEL_TENSOR.FFN_GATE_INP,
         MODEL_TENSOR.FFN_GATE_EXP,
         MODEL_TENSOR.FFN_DOWN_EXP,
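
For context, MODEL_TENSORS is the per-architecture list of tensor kinds a model is allowed to declare, so dropping MODEL_TENSOR.ATTN_OUT_NORM here removes it from the dbrx set. A minimal sketch of checking that, assuming this branch's gguf-py is importable (expects_tensor is a hypothetical helper, not gguf-py API):

from gguf.constants import MODEL_ARCH, MODEL_TENSOR, MODEL_TENSORS

def expects_tensor(arch: MODEL_ARCH, tensor: MODEL_TENSOR) -> bool:
    # Hypothetical helper: True if the architecture declares this tensor kind.
    return tensor in MODEL_TENSORS[arch]

assert expects_tensor(MODEL_ARCH.DBRX, MODEL_TENSOR.ATTN_QKV)           # still declared
assert not expects_tensor(MODEL_ARCH.DBRX, MODEL_TENSOR.ATTN_OUT_NORM)  # removed by this commit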

gguf-py/gguf/tensor_mapping.py

@@ -52,6 +52,7 @@ class TensorNameMap:
             "output",                    # llama-pth bloom internlm2
             "word_embeddings_for_head",  # persimmon
             "lm_head.linear",            # phi2
+            "transformer.wte.weight",    # dbrx
         ),

         # Output norm
@@ -68,6 +69,7 @@ class TensorNameMap:
             "model.norm_f",              # mamba-qbert
             "backbone.norm_f",           # mamba
             "transformer.rms_norm",      # Grok
+            "transformer.norm_f.weight", # dbrx
         ),

         # Rope frequencies
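
With the two added entries, TensorNameMap can now classify the dbrx checkpoint's output head and final norm. A quick sanity check, sketched under the assumption that this branch's gguf-py is importable and that get_tensor_name_map and get_type keep their current signatures:

import gguf  # gguf-py from this branch, where MODEL_ARCH.DBRX is defined

# n_blocks only sizes the per-layer tables; the two names below are global.
name_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.DBRX, 1)

assert name_map.get_type("transformer.wte.weight")    is gguf.MODEL_TENSOR.OUTPUT
assert name_map.get_type("transformer.norm_f.weight") is gguf.MODEL_TENSOR.OUTPUT_NORM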
@@ -179,7 +181,6 @@ class TensorNameMap:
             "encoder.layer.{bid}.attention.output.LayerNorm",               # bert
             "encoder.layers.{bid}.norm1",                                   # nomic-bert
             "transformer.decoder_layer.{bid}.rms_norm_1",                   # Grok
-            "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj.weight", # dbrx
         ),

         # Rotary embeddings
@@ -310,6 +311,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.output.LayerNorm",                         # bert
             "encoder.layers.{bid}.norm2",                                   # nomic-bert
             "transformer.decoder_layer.{bid}.rms_norm_3",                   # Grok
+            "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj.weight", # dbrx
         ),

         MODEL_TENSOR.SSM_IN: (
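
Net effect of the last two hunks: the dbrx out_proj name moves from the attention-output-norm group to the layer-output-norm group. That is visible directly in the raw per-layer tables, assuming they are still exposed as the TensorNameMap.block_mappings_cfg class attribute:

from gguf.constants import MODEL_TENSOR
from gguf.tensor_mapping import TensorNameMap

key    = "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj.weight"
blocks = TensorNameMap.block_mappings_cfg  # raw name table, before {bid} expansion

assert key in blocks[MODEL_TENSOR.LAYER_OUT_NORM]     # added in this commit
assert key not in blocks[MODEL_TENSOR.ATTN_OUT_NORM]  # removed in this commit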