fix(bamba conv): Fixes in tensor name and hparam conversion for llama.cpp parsing
Branch: BambaArchitecture
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
parent e0af809b05
commit fd3bb30118
4 changed files with 17 additions and 29 deletions
@@ -3094,8 +3094,7 @@ class Mamba2Model(Model):
# TODO: Switch to BambaForCausalLM once ready in transformers
# @Model.register("BambaForCausalLM")
@Model.register("JambaForCausalLM")
@Model.register("BambaForCausalLM")
class BambaModel(Mamba2Model):
    """Bamba is a hybrid SSM + Attention model that uses Mamba2 SSM layers"""
    model_arch = gguf.MODEL_ARCH.BAMBA
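For context (not part of this diff): @Model.register ties the "architectures" string from an HF config.json to a converter class, so Bamba is now picked up under its own "BambaForCausalLM" name instead of piggybacking on the Jamba registration. A minimal sketch of that registry pattern, using a hypothetical standalone register decorator rather than the converter's real Model.register:

_model_registry: dict[str, type] = {}

def register(*arch_names: str):
    """Hypothetical stand-in for Model.register: map architecture names to classes."""
    def decorator(cls: type) -> type:
        for name in arch_names:
            _model_registry[name] = cls
        return cls
    return decorator

@register("BambaForCausalLM")
class BambaModelSketch:
    """Stand-in for the real BambaModel(Mamba2Model) converter class."""

# The converter would then select the class from config.json's "architectures":
# _model_registry["BambaForCausalLM"] is BambaModelSketch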
@@ -3159,8 +3158,6 @@ class BambaModel(Mamba2Model):
        self.gguf_writer.add_ssm_inner_size(self.d_inner)
        self.gguf_writer.add_ssm_head_count(self.find_hparam(["n_heads"]))
        self.gguf_writer.add_ssm_head_dim(d_head := self.find_hparam(["d_head"]))
        self.gguf_writer.add_ssm_conv_bias(self.find_hparam(["conv_bias"], optional=True) or False)
        self.gguf_writer.add_ssm_proj_bias(self.find_hparam(["proj_bias"], optional=True) or False)
        self.gguf_writer.add_ssm_chunk_size(self.find_hparam(["chunk_size"]))

        ## Attention params ##
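For context (not part of this diff): the hparam calls above rely on a first-match lookup over candidate config keys, with the "optional=True ... or False" idiom supplying a default when a key is absent. A simplified stand-in for that helper, with illustrative Bamba-style values:

from typing import Any

def find_hparam(hparams: dict[str, Any], keys: list[str], optional: bool = False) -> Any:
    """Simplified stand-in for the converter's find_hparam helper."""
    for key in keys:
        if key in hparams:
            return hparams[key]
    if optional:
        return None
    raise KeyError(f"none of {keys} found in hparams")

hparams = {"n_heads": 128, "d_head": 64, "chunk_size": 256}  # illustrative values only

head_count = find_hparam(hparams, ["n_heads"])                            # 128
head_dim = find_hparam(hparams, ["d_head"])                               # 64
conv_bias = find_hparam(hparams, ["conv_bias"], optional=True) or False   # False (key absent)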
@@ -3187,6 +3184,7 @@ class BambaModel(Mamba2Model):
    def modify_tensors(
        self, data_torch: Tensor, name: str, bid: int | None
    ) -> Iterable[tuple[str, Tensor]]:

        # Determine whether this is a mamba layer or an attention layer
        if bid in self._ssm_layers:
            for mamba_new_name, data_torch in super().modify_tensors(
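For context (not part of this diff): the "bid in self._ssm_layers" check is what routes a tensor to the Mamba2 path versus the attention path. One way such a split can be derived, assuming the attention layer indices come from the config (the same list written via add_attn_layer_indices further down):

n_layers = 32                     # illustrative
attn_layer_indices = [9, 18, 27]  # illustrative; e.g. read from the HF config

# Every layer that is not an attention layer is treated as an SSM (Mamba2) layer.
ssm_layers = [bid for bid in range(n_layers) if bid not in attn_layer_indices]

def is_ssm_layer(bid: int) -> bool:
    return bid in ssm_layers

assert is_ssm_layer(0) and not is_ssm_layer(9)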
@@ -3199,13 +3197,11 @@ class BambaModel(Mamba2Model):
            ):
                yield llama_new_name, data_torch
        else:
            yield name, data_torch
            yield self.map_tensor_name(name), data_torch


    def reshape_tensors(
        self,
        data_torch: Tensor,
        new_name: str, bid: int | None,
        self, data_torch: Tensor, new_name: str, bid: int | None,
    ) -> Tensor:
        if bid in self._ssm_layers:
            return super().reshape_tensors(data_torch, new_name, bid)
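For context (not part of this diff): the fix in the attention branch is to yield self.map_tensor_name(name) rather than the raw HF name, so llama.cpp sees canonical GGUF tensor names. A rough sketch of what such a name mapping does (hypothetical table and helper, not the converter's real mapping tables):

import re

# Illustrative subset of an HF -> GGUF base-name map; {bid} stands for the layer index.
NAME_MAP = {
    "model.layers.{bid}.self_attn.q_proj": "blk.{bid}.attn_q",
    "model.layers.{bid}.mamba.in_proj": "blk.{bid}.ssm_in",
}

def map_tensor_name(name: str) -> str:
    """Map e.g. 'model.layers.3.self_attn.q_proj.weight' -> 'blk.3.attn_q.weight'."""
    m = re.match(r"(.*)\.(weight|bias)$", name)
    base, suffix = (m.group(1), "." + m.group(2)) if m else (name, "")
    for hf_tmpl, gguf_tmpl in NAME_MAP.items():
        pattern = "^" + re.escape(hf_tmpl).replace(r"\{bid\}", r"(\d+)") + "$"
        hit = re.match(pattern, base)
        if hit:
            return gguf_tmpl.replace("{bid}", hit.group(1)) + suffix
    raise ValueError(f"unmapped tensor name: {name}")

assert map_tensor_name("model.layers.3.self_attn.q_proj.weight") == "blk.3.attn_q.weight"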
@@ -154,8 +154,6 @@ class Keys:
        HEAD_COUNT = "{arch}.ssm.head_count"
        HEAD_DIM = "{arch}.ssm.head_dim"
        CHUNK_SIZE = "{arch}.ssm.chunk_size"
        CONV_BIAS = "{arch}.ssm.conv_bias"
        PROJ_BIAS = "{arch}.ssm.proj_bias"

    class HybridMamba:
        ATTN_LAYER_INDICES = "{arch}.attention.layer_indices"
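For context (not part of this diff): these key strings are {arch}-templated, so the dropped conv_bias/proj_bias keys are simply never emitted, while the remaining ones expand per architecture. A quick check of the expansion, assuming the "bamba" architecture name:

HEAD_COUNT = "{arch}.ssm.head_count"
ATTN_LAYER_INDICES = "{arch}.attention.layer_indices"

print(HEAD_COUNT.format(arch="bamba"))           # bamba.ssm.head_count
print(ATTN_LAYER_INDICES.format(arch="bamba"))   # bamba.attention.layer_indices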
@@ -799,12 +799,6 @@ class GGUFWriter:
    def add_ssm_chunk_size(self, value: int) -> None:
        self.add_uint32(Keys.SSM.CHUNK_SIZE.format(arch=self.arch), value)

    def add_ssm_conv_bias(self, value: bool) -> None:
        self.add_bool(Keys.SSM.CONV_BIAS.format(arch=self.arch), value)

    def add_ssm_proj_bias(self, value: bool) -> None:
        self.add_bool(Keys.SSM.PROJ_BIAS.format(arch=self.arch), value)

    def add_attn_layer_indices(self, values: list[int]) -> None:
        self.add_array(Keys.HybridMamba.ATTN_LAYER_INDICES.format(arch=self.arch), values)
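For context (not part of this diff): the remaining writer methods are thin wrappers that format the templated key and record a typed value. A hypothetical stand-in writer that only collects key/value pairs shows the shape of the calls (the real GGUFWriter also tracks value types and serializes the GGUF file):

class KVSketchWriter:
    def __init__(self, arch: str) -> None:
        self.arch = arch
        self.kv: dict[str, object] = {}

    def add_uint32(self, key: str, value: int) -> None:
        self.kv[key] = value

    def add_array(self, key: str, values: list[int]) -> None:
        self.kv[key] = list(values)

    def add_ssm_chunk_size(self, value: int) -> None:
        self.add_uint32("{arch}.ssm.chunk_size".format(arch=self.arch), value)

    def add_attn_layer_indices(self, values: list[int]) -> None:
        self.add_array("{arch}.attention.layer_indices".format(arch=self.arch), values)

writer = KVSketchWriter("bamba")
writer.add_ssm_chunk_size(256)              # illustrative value
writer.add_attn_layer_indices([9, 18, 27])  # illustrative value
# writer.kv == {"bamba.ssm.chunk_size": 256, "bamba.attention.layer_indices": [9, 18, 27]}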
@@ -13,7 +13,7 @@ class TensorNameMap:
            "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
            "transformer.word_embeddings", # falcon
            "word_embeddings", # bloom
            "model.embed_tokens", # llama-hf nemotron olmoe olmo_1124
            "model.embed_tokens", # llama-hf nemotron olmoe olmo_1124 bamba
            "tok_embeddings", # llama-pth
            "embeddings.word_embeddings", # bert nomic-bert
            "language_model.embedding.word_embeddings", # persimmon
@@ -101,7 +101,7 @@ class TensorNameMap:
            "transformer.h.{bid}.input_layernorm", # falcon7b
            "h.{bid}.input_layernorm", # bloom
            "transformer.h.{bid}.ln_mlp", # falcon40b
            "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe jamba
            "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe bamba
            "layers.{bid}.attention_norm", # llama-pth
            "language_model.encoder.layers.{bid}.input_layernorm", # persimmon
            "model.layers.{bid}.ln1", # yi
@@ -241,7 +241,7 @@ class TensorNameMap:
            "transformer.decoder_layer.{bid}.rms_norm_2", # Grok
            "encoder.layers.{bid}.post_attention_layernorm", # chatglm
            "transformer.layers.{bid}.ffn_norm", # openelm
            "model.layers.{bid}.pre_ff_layernorm.weight", # jamba
            "model.layers.{bid}.pre_ff_layernorm", # bamba
        ),

        # Post feed-forward norm
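For context (not part of this diff): besides the jamba-to-bamba comment, this entry drops the trailing ".weight", since mapping entries are base names and the weight/bias suffix is matched separately during lookup and re-attached to the resolved GGUF name. A sketch of that suffix handling, assuming the same convention as the map_tensor_name sketch above:

name = "model.layers.7.pre_ff_layernorm.weight"
base, _, suffix = name.rpartition(".")
# base   == "model.layers.7.pre_ff_layernorm"  -> matches "model.layers.{bid}.pre_ff_layernorm"
# suffix == "weight"                           -> re-attached, e.g. "blk.7.ffn_norm.weight"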
@@ -294,7 +294,7 @@ class TensorNameMap:
            "model.layers.{bid}.residual_mlp.w3", # arctic
            "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
            "transformer.h.{bid}.mlp.c_fc_1", # exaone
            "model.layers.{bid}.feed_forward.up_proj", # jamba
            "model.layers.{bid}.feed_forward.up_proj", # bamba
        ),

        MODEL_TENSOR.FFN_UP_EXP: (
@@ -327,7 +327,7 @@ class TensorNameMap:
            "transformer.h.{bid}.mlp.linear_1", # refact
            "model.layers.{bid}.residual_mlp.w1", # arctic
            "transformer.h.{bid}.mlp.c_fc_0", # exaone
            "model.layers.{bid}.feed_forward.gate_proj", # jamba
            "model.layers.{bid}.feed_forward.gate_proj", # bamba
        ),

        MODEL_TENSOR.FFN_GATE_EXP: (
@@ -368,7 +368,7 @@ class TensorNameMap:
            "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
            "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
            "model.layers.h.{bid}.mlp.c_proj", # exaone
            "model.layers.{bid}.feed_forward.down_proj", # jamba
            "model.layers.{bid}.feed_forward.down_proj", # bamba
        ),

        MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -417,13 +417,13 @@ class TensorNameMap:
        MODEL_TENSOR.SSM_IN: (
            "model.layers.{bid}.in_proj",
            "backbone.layers.{bid}.mixer.in_proj",
            "model.layers.{bid}.mamba.in_proj", # jamba
            "model.layers.{bid}.mamba.in_proj", # bamba
        ),

        MODEL_TENSOR.SSM_CONV1D: (
            "model.layers.{bid}.conv1d",
            "backbone.layers.{bid}.mixer.conv1d",
            "model.layers.{bid}.mamba.conv1d", # jamba
            "model.layers.{bid}.mamba.conv1d", # bamba
        ),

        MODEL_TENSOR.SSM_X: (
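For context (not part of this diff): each tuple above lists HF-side name templates for one GGUF tensor, and {bid} is expanded per layer when the map is built. A small sketch of that expansion for SSM_IN (hypothetical helper; "blk.{bid}.ssm_in" is the conventional GGUF base name):

SSM_IN_SOURCES = (
    "model.layers.{bid}.in_proj",
    "backbone.layers.{bid}.mixer.in_proj",
    "model.layers.{bid}.mamba.in_proj",  # jamba / bamba style
)

def build_ssm_in_lookup(n_blocks: int) -> dict[str, str]:
    """Map every concrete HF name variant to the GGUF base name 'blk.{bid}.ssm_in'."""
    lookup: dict[str, str] = {}
    for bid in range(n_blocks):
        for template in SSM_IN_SOURCES:
            lookup[template.format(bid=bid)] = f"blk.{bid}.ssm_in"
    return lookup

lookup = build_ssm_in_lookup(n_blocks=2)
# lookup["model.layers.1.mamba.in_proj"] == "blk.1.ssm_in"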
@@ -434,30 +434,30 @@ class TensorNameMap:
        MODEL_TENSOR.SSM_DT: (
            "model.layers.{bid}.dt_proj",
            "backbone.layers.{bid}.mixer.dt_proj",
            "model.layers.{bid}.mamba.dt_proj", # jamba
            "model.layers.{bid}.mamba.dt_proj", # bamba
        ),

        MODEL_TENSOR.SSM_A: (
            "model.layers.{bid}.A_log",
            "backbone.layers.{bid}.mixer.A_log",
            "model.layers.{bid}.mamba.A_log", # jamba
            "model.layers.{bid}.mamba.A_log", # bamba
        ),

        MODEL_TENSOR.SSM_D: (
            "model.layers.{bid}.D",
            "backbone.layers.{bid}.mixer.D",
            "model.layers.{bid}.mamba.D", # jamba
            "model.layers.{bid}.mamba.D", # bamba
        ),

        MODEL_TENSOR.SSM_NORM: (
            "backbone.layers.{bid}.mixer.norm", # mamba2
            "model.layers.{bid}.mamba.norm", # jamba
            "model.layers.{bid}.mamba.norm", # bamba
        ),

        MODEL_TENSOR.SSM_OUT: (
            "model.layers.{bid}.out_proj",
            "backbone.layers.{bid}.mixer.out_proj",
            "model.layers.{bid}.mamba.out_proj", # jamba
            "model.layers.{bid}.mamba.out_proj", # bamba
        ),

        MODEL_TENSOR.TIME_MIX_W1: (