fix(bamba conv): Fixes in tensor name and hparam conversion for llama.cpp parsing

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Gabe Goodhart 2024-12-04 12:00:46 -07:00
parent e0af809b05
commit fd3bb30118
4 changed files with 17 additions and 29 deletions


@@ -3094,8 +3094,7 @@ class Mamba2Model(Model):
# TODO: Switch to BambaForCausalLM once ready in transformers
# @Model.register("BambaForCausalLM")
@Model.register("JambaForCausalLM")
@Model.register("BambaForCausalLM")
class BambaModel(Mamba2Model):
"""Bamba is a hybrid SSM + Attention model that uses Mamba2 SSM layers"""
model_arch = gguf.MODEL_ARCH.BAMBA
@@ -3159,8 +3158,6 @@ class BambaModel(Mamba2Model):
self.gguf_writer.add_ssm_inner_size(self.d_inner)
self.gguf_writer.add_ssm_head_count(self.find_hparam(["n_heads"]))
self.gguf_writer.add_ssm_head_dim(d_head := self.find_hparam(["d_head"]))
self.gguf_writer.add_ssm_conv_bias(self.find_hparam(["conv_bias"], optional=True) or False)
self.gguf_writer.add_ssm_proj_bias(self.find_hparam(["proj_bias"], optional=True) or False)
self.gguf_writer.add_ssm_chunk_size(self.find_hparam(["chunk_size"]))
## Attention params ##
@@ -3187,6 +3184,7 @@ class BambaModel(Mamba2Model):
def modify_tensors(
self, data_torch: Tensor, name: str, bid: int | None
) -> Iterable[tuple[str, Tensor]]:
# Determine whether this is a mamba layer or an attention layer
if bid in self._ssm_layers:
for mamba_new_name, data_torch in super().modify_tensors(
@@ -3199,13 +3197,11 @@ class BambaModel(Mamba2Model):
):
yield llama_new_name, data_torch
else:
yield name, data_torch
yield self.map_tensor_name(name), data_torch
def reshape_tensors(
self,
data_torch: Tensor,
new_name: str, bid: int | None,
self, data_torch: Tensor, new_name: str, bid: int | None,
) -> Tensor:
if bid in self._ssm_layers:
return super().reshape_tensors(data_torch, new_name, bid)
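For context, here is a minimal standalone sketch of the per-layer routing that modify_tensors now performs, assuming _ssm_layers holds the block indices of the Mamba2 layers; the helper function and the example indices below are hypothetical and not part of the commit.

```python
# Hypothetical sketch of the routing in BambaModel.modify_tensors: SSM blocks
# go through the Mamba2 mapping, attention blocks through the llama-style
# mapping, and everything else falls back to map_tensor_name().
def route_tensor(name: str, bid: int | None, ssm_layers: set[int]) -> str:
    if bid is not None and bid in ssm_layers:
        return "mamba2 path"     # super().modify_tensors() from Mamba2Model
    if bid is not None:
        return "attention path"  # llama-style attention tensor mapping
    return "direct map"          # yield self.map_tensor_name(name), data_torch

# Example: attention at blocks 9, 18, 27; the remaining blocks are SSM layers.
ssm_layers = set(range(32)) - {9, 18, 27}
assert route_tensor("model.layers.0.mamba.in_proj.weight", 0, ssm_layers) == "mamba2 path"
assert route_tensor("model.layers.9.self_attn.q_proj.weight", 9, ssm_layers) == "attention path"
assert route_tensor("model.embed_tokens.weight", None, ssm_layers) == "direct map"
```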


@@ -154,8 +154,6 @@ class Keys:
HEAD_COUNT = "{arch}.ssm.head_count"
HEAD_DIM = "{arch}.ssm.head_dim"
CHUNK_SIZE = "{arch}.ssm.chunk_size"
CONV_BIAS = "{arch}.ssm.conv_bias"
PROJ_BIAS = "{arch}.ssm.proj_bias"
class HybridMamba:
ATTN_LAYER_INDICES = "{arch}.attention.layer_indices"
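As a quick illustration (not part of the commit), these Keys entries are format templates that are expanded with the architecture name when the metadata is written; the "bamba" arch string below is an assumption based on MODEL_ARCH.BAMBA above.

```python
# The metadata keys are per-arch templates; formatting in the arch name
# yields the final key strings for the remaining Bamba SSM/attention hparams.
HEAD_COUNT         = "{arch}.ssm.head_count"
HEAD_DIM           = "{arch}.ssm.head_dim"
CHUNK_SIZE         = "{arch}.ssm.chunk_size"
ATTN_LAYER_INDICES = "{arch}.attention.layer_indices"

for template in (HEAD_COUNT, HEAD_DIM, CHUNK_SIZE, ATTN_LAYER_INDICES):
    print(template.format(arch="bamba"))
# bamba.ssm.head_count
# bamba.ssm.head_dim
# bamba.ssm.chunk_size
# bamba.attention.layer_indices
```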


@@ -799,12 +799,6 @@ class GGUFWriter:
def add_ssm_chunk_size(self, value: int) -> None:
self.add_uint32(Keys.SSM.CHUNK_SIZE.format(arch=self.arch), value)
def add_ssm_conv_bias(self, value: bool) -> None:
self.add_bool(Keys.SSM.CONV_BIAS.format(arch=self.arch), value)
def add_ssm_proj_bias(self, value: bool) -> None:
self.add_bool(Keys.SSM.PROJ_BIAS.format(arch=self.arch), value)
def add_attn_layer_indices(self, values: list[int]) -> None:
self.add_array(Keys.HybridMamba.ATTN_LAYER_INDICES.format(arch=self.arch), values)
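Below is a hedged sketch of how a conversion could record these values through GGUFWriter after this change. The file name, arch string, and all numeric values are illustrative only; the add_ssm_* convenience methods referenced in the diff exist only on this branch, so the sketch writes the raw keys directly.

```python
import gguf

# Illustrative values only; a real conversion reads these from the HF config.
writer = gguf.GGUFWriter("bamba-f16.gguf", "bamba")

# conv_bias/proj_bias are no longer written; only these SSM/attention
# hyperparameters remain as typed KV pairs.
writer.add_uint32("bamba.ssm.head_count", 128)
writer.add_uint32("bamba.ssm.head_dim", 64)
writer.add_uint32("bamba.ssm.chunk_size", 256)
writer.add_array("bamba.attention.layer_indices", [9, 18, 27])
```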


@@ -13,7 +13,7 @@ class TensorNameMap:
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
"transformer.word_embeddings", # falcon
"word_embeddings", # bloom
"model.embed_tokens", # llama-hf nemotron olmoe olmo_1124
"model.embed_tokens", # llama-hf nemotron olmoe olmo_1124 bamba
"tok_embeddings", # llama-pth
"embeddings.word_embeddings", # bert nomic-bert
"language_model.embedding.word_embeddings", # persimmon
@@ -101,7 +101,7 @@ class TensorNameMap:
"transformer.h.{bid}.input_layernorm", # falcon7b
"h.{bid}.input_layernorm", # bloom
"transformer.h.{bid}.ln_mlp", # falcon40b
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe jamba
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe bamba
"layers.{bid}.attention_norm", # llama-pth
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
"model.layers.{bid}.ln1", # yi
@@ -241,7 +241,7 @@ class TensorNameMap:
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
"transformer.layers.{bid}.ffn_norm", # openelm
"model.layers.{bid}.pre_ff_layernorm.weight", # jamba
"model.layers.{bid}.pre_ff_layernorm", # bamba
),
# Post feed-forward norm
@@ -294,7 +294,7 @@ class TensorNameMap:
"model.layers.{bid}.residual_mlp.w3", # arctic
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
"transformer.h.{bid}.mlp.c_fc_1", # exaone
"model.layers.{bid}.feed_forward.up_proj", # jamba
"model.layers.{bid}.feed_forward.up_proj", # bamba
),
MODEL_TENSOR.FFN_UP_EXP: (
@@ -327,7 +327,7 @@ class TensorNameMap:
"transformer.h.{bid}.mlp.linear_1", # refact
"model.layers.{bid}.residual_mlp.w1", # arctic
"transformer.h.{bid}.mlp.c_fc_0", # exaone
"model.layers.{bid}.feed_forward.gate_proj", # jamba
"model.layers.{bid}.feed_forward.gate_proj", # bamba
),
MODEL_TENSOR.FFN_GATE_EXP: (
@@ -368,7 +368,7 @@ class TensorNameMap:
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
"model.layers.h.{bid}.mlp.c_proj", # exaone
"model.layers.{bid}.feed_forward.down_proj", # jamba
"model.layers.{bid}.feed_forward.down_proj", # bamba
),
MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -417,13 +417,13 @@ class TensorNameMap:
MODEL_TENSOR.SSM_IN: (
"model.layers.{bid}.in_proj",
"backbone.layers.{bid}.mixer.in_proj",
"model.layers.{bid}.mamba.in_proj", # jamba
"model.layers.{bid}.mamba.in_proj", # bamba
),
MODEL_TENSOR.SSM_CONV1D: (
"model.layers.{bid}.conv1d",
"backbone.layers.{bid}.mixer.conv1d",
"model.layers.{bid}.mamba.conv1d", # jamba
"model.layers.{bid}.mamba.conv1d", # bamba
),
MODEL_TENSOR.SSM_X: (
@@ -434,30 +434,30 @@ class TensorNameMap:
MODEL_TENSOR.SSM_DT: (
"model.layers.{bid}.dt_proj",
"backbone.layers.{bid}.mixer.dt_proj",
"model.layers.{bid}.mamba.dt_proj", # jamba
"model.layers.{bid}.mamba.dt_proj", # bamba
),
MODEL_TENSOR.SSM_A: (
"model.layers.{bid}.A_log",
"backbone.layers.{bid}.mixer.A_log",
"model.layers.{bid}.mamba.A_log", # jamba
"model.layers.{bid}.mamba.A_log", # bamba
),
MODEL_TENSOR.SSM_D: (
"model.layers.{bid}.D",
"backbone.layers.{bid}.mixer.D",
"model.layers.{bid}.mamba.D", # jamba
"model.layers.{bid}.mamba.D", # bamba
),
MODEL_TENSOR.SSM_NORM: (
"backbone.layers.{bid}.mixer.norm", # mamba2
"model.layers.{bid}.mamba.norm", # jamba
"model.layers.{bid}.mamba.norm", # bamba
),
MODEL_TENSOR.SSM_OUT: (
"model.layers.{bid}.out_proj",
"backbone.layers.{bid}.mixer.out_proj",
"model.layers.{bid}.mamba.out_proj", # jamba
"model.layers.{bid}.mamba.out_proj", # bamba
),
MODEL_TENSOR.TIME_MIX_W1: (
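Finally, a sketch of how the renamed mappings are expected to resolve HF tensor names during conversion. MODEL_ARCH.BAMBA and the resulting GGUF tensor names exist only on this branch, so the printed output is an assumption rather than upstream behaviour.

```python
import gguf

# Build the name map for the Bamba architecture (branch-only enum value).
tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.BAMBA, n_blocks=32)

# "model.layers.{bid}.mamba.in_proj" is now listed under MODEL_TENSOR.SSM_IN,
# so the HF checkpoint name maps onto the GGUF ssm_in tensor for that block.
print(tmap.get_name("model.layers.0.mamba.in_proj.weight",
                    try_suffixes=(".weight", ".bias")))
# expected: blk.0.ssm_in.weight

print(tmap.get_name("model.layers.9.feed_forward.up_proj.weight",
                    try_suffixes=(".weight", ".bias")))
# expected: blk.9.ffn_up.weight
```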