fix(bamba conv): Fixes in tensor name and hparam conversion for llama.cpp parsing

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Gabe Goodhart 2024-12-04 12:00:46 -07:00
parent e0af809b05
commit fd3bb30118
4 changed files with 17 additions and 29 deletions

convert_hf_to_gguf.py

@@ -3094,8 +3094,7 @@ class Mamba2Model(Model):
 # TODO: Switch to BambaForCausalLM once ready in transformers
-# @Model.register("BambaForCausalLM")
-@Model.register("JambaForCausalLM")
+@Model.register("BambaForCausalLM")
 class BambaModel(Mamba2Model):
     """Bamba is a hybrid SSM + Attention model that uses Mamba2 SSM layers"""
     model_arch = gguf.MODEL_ARCH.BAMBA
@@ -3159,8 +3158,6 @@ class BambaModel(Mamba2Model):
         self.gguf_writer.add_ssm_inner_size(self.d_inner)
         self.gguf_writer.add_ssm_head_count(self.find_hparam(["n_heads"]))
         self.gguf_writer.add_ssm_head_dim(d_head := self.find_hparam(["d_head"]))
-        self.gguf_writer.add_ssm_conv_bias(self.find_hparam(["conv_bias"], optional=True) or False)
-        self.gguf_writer.add_ssm_proj_bias(self.find_hparam(["proj_bias"], optional=True) or False)
         self.gguf_writer.add_ssm_chunk_size(self.find_hparam(["chunk_size"]))

         ## Attention params ##
@@ -3187,6 +3184,7 @@ class BambaModel(Mamba2Model):
     def modify_tensors(
         self, data_torch: Tensor, name: str, bid: int | None
     ) -> Iterable[tuple[str, Tensor]]:
         # Determine whether this is a mamaba layer or an attention layer
         if bid in self._ssm_layers:
             for mamba_new_name, data_torch in super().modify_tensors(
@@ -3199,13 +3197,11 @@ class BambaModel(Mamba2Model):
             ):
                 yield llama_new_name, data_torch
         else:
-            yield name, data_torch
+            yield self.map_tensor_name(name), data_torch

     def reshape_tensors(
-        self,
-        data_torch: Tensor,
-        new_name: str, bid: int | None,
+        self, data_torch: Tensor, new_name: str, bid: int | None,
     ) -> Tensor:
         if bid in self._ssm_layers:
             return super().reshape_tensors(data_torch, new_name, bid)
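For context, a minimal sketch (not part of the commit) of what the registration fix enables: Model.register keys converter classes by the HF `architectures` string, so looking up "BambaForCausalLM" now resolves to BambaModel. It assumes the branch checkout's convert_hf_to_gguf.py is importable (e.g. run from the repo root).

# Sketch only: resolve the HF `architectures` entry to its converter class.
from convert_hf_to_gguf import Model

model_cls = Model.from_model_architecture("BambaForCausalLM")
print(model_cls.__name__)  # expected: BambaModel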

gguf-py/gguf/constants.py

@@ -154,8 +154,6 @@ class Keys:
         HEAD_COUNT = "{arch}.ssm.head_count"
         HEAD_DIM = "{arch}.ssm.head_dim"
         CHUNK_SIZE = "{arch}.ssm.chunk_size"
-        CONV_BIAS = "{arch}.ssm.conv_bias"
-        PROJ_BIAS = "{arch}.ssm.proj_bias"

     class HybridMamba:
         ATTN_LAYER_INDICES = "{arch}.attention.layer_indices"
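For reference, a small sketch (not part of the commit) of how the key patterns kept above expand into GGUF metadata keys. The arch string "bamba" is an assumption about the branch's MODEL_ARCH_NAMES entry, and it assumes HEAD_COUNT/HEAD_DIM sit under Keys.SSM alongside CHUNK_SIZE as the writer calls suggest.

# Sketch only: expand the remaining key patterns for an assumed arch string "bamba".
from gguf.constants import Keys

arch = "bamba"
print(Keys.SSM.HEAD_COUNT.format(arch=arch))                  # bamba.ssm.head_count
print(Keys.SSM.HEAD_DIM.format(arch=arch))                    # bamba.ssm.head_dim
print(Keys.SSM.CHUNK_SIZE.format(arch=arch))                  # bamba.ssm.chunk_size
print(Keys.HybridMamba.ATTN_LAYER_INDICES.format(arch=arch))  # bamba.attention.layer_indices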

gguf-py/gguf/gguf_writer.py

@@ -799,12 +799,6 @@ class GGUFWriter:
     def add_ssm_chunk_size(self, value: int) -> None:
         self.add_uint32(Keys.SSM.CHUNK_SIZE.format(arch=self.arch), value)

-    def add_ssm_conv_bias(self, value: bool) -> None:
-        self.add_bool(Keys.SSM.CONV_BIAS.format(arch=self.arch), value)
-
-    def add_ssm_proj_bias(self, value: bool) -> None:
-        self.add_bool(Keys.SSM.PROJ_BIAS.format(arch=self.arch), value)
-
     def add_attn_layer_indices(self, values: list[int]) -> None:
         self.add_array(Keys.HybridMamba.ATTN_LAYER_INDICES.format(arch=self.arch), values)
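With conv_bias/proj_bias gone, the writer methods below are what the converter still calls for these hparams. A minimal sketch (not part of the commit), assuming the branch's gguf-py; the values and output path are illustrative only.

# Sketch only: buffer the remaining Bamba SSM / hybrid-attention metadata.
# A real conversion also writes the header, KV data and tensors via the usual
# write_*_to_file calls; values and the path below are made up.
from gguf import GGUFWriter

writer = GGUFWriter("bamba-example.gguf", "bamba")  # arch string assumed
writer.add_ssm_inner_size(4096)                     # d_inner
writer.add_ssm_head_count(64)                       # n_heads
writer.add_ssm_head_dim(64)                         # d_head
writer.add_ssm_chunk_size(256)                      # chunk_size
writer.add_attn_layer_indices([9, 18, 27])          # attention (non-SSM) layer ids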

gguf-py/gguf/tensor_mapping.py

@@ -13,7 +13,7 @@ class TensorNameMap:
             "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
             "transformer.word_embeddings", # falcon
             "word_embeddings", # bloom
-            "model.embed_tokens", # llama-hf nemotron olmoe olmo_1124
+            "model.embed_tokens", # llama-hf nemotron olmoe olmo_1124 bamba
             "tok_embeddings", # llama-pth
             "embeddings.word_embeddings", # bert nomic-bert
             "language_model.embedding.word_embeddings", # persimmon
@@ -101,7 +101,7 @@ class TensorNameMap:
             "transformer.h.{bid}.input_layernorm", # falcon7b
             "h.{bid}.input_layernorm", # bloom
             "transformer.h.{bid}.ln_mlp", # falcon40b
-            "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe jamba
+            "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe bamba
             "layers.{bid}.attention_norm", # llama-pth
             "language_model.encoder.layers.{bid}.input_layernorm", # persimmon
             "model.layers.{bid}.ln1", # yi
@@ -241,7 +241,7 @@ class TensorNameMap:
             "transformer.decoder_layer.{bid}.rms_norm_2", # Grok
             "encoder.layers.{bid}.post_attention_layernorm", # chatglm
             "transformer.layers.{bid}.ffn_norm", # openelm
-            "model.layers.{bid}.pre_ff_layernorm.weight", # jamba
+            "model.layers.{bid}.pre_ff_layernorm", # bamba
         ),

         # Post feed-forward norm
@@ -294,7 +294,7 @@ class TensorNameMap:
             "model.layers.{bid}.residual_mlp.w3", # arctic
             "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
             "transformer.h.{bid}.mlp.c_fc_1", # exaone
-            "model.layers.{bid}.feed_forward.up_proj", # jamba
+            "model.layers.{bid}.feed_forward.up_proj", # bamba
         ),

         MODEL_TENSOR.FFN_UP_EXP: (
@@ -327,7 +327,7 @@ class TensorNameMap:
             "transformer.h.{bid}.mlp.linear_1", # refact
             "model.layers.{bid}.residual_mlp.w1", # arctic
             "transformer.h.{bid}.mlp.c_fc_0", # exaone
-            "model.layers.{bid}.feed_forward.gate_proj", # jamba
+            "model.layers.{bid}.feed_forward.gate_proj", # bamba
         ),

         MODEL_TENSOR.FFN_GATE_EXP: (
@@ -368,7 +368,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
             "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
             "model.layers.h.{bid}.mlp.c_proj", # exaone
-            "model.layers.{bid}.feed_forward.down_proj", # jamba
+            "model.layers.{bid}.feed_forward.down_proj", # bamba
         ),

         MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -417,13 +417,13 @@ class TensorNameMap:
         MODEL_TENSOR.SSM_IN: (
             "model.layers.{bid}.in_proj",
             "backbone.layers.{bid}.mixer.in_proj",
-            "model.layers.{bid}.mamba.in_proj", # jamba
+            "model.layers.{bid}.mamba.in_proj", # bamba
         ),

         MODEL_TENSOR.SSM_CONV1D: (
             "model.layers.{bid}.conv1d",
             "backbone.layers.{bid}.mixer.conv1d",
-            "model.layers.{bid}.mamba.conv1d", # jamba
+            "model.layers.{bid}.mamba.conv1d", # bamba
         ),

         MODEL_TENSOR.SSM_X: (
@@ -434,30 +434,30 @@ class TensorNameMap:
         MODEL_TENSOR.SSM_DT: (
             "model.layers.{bid}.dt_proj",
             "backbone.layers.{bid}.mixer.dt_proj",
-            "model.layers.{bid}.mamba.dt_proj", # jamba
+            "model.layers.{bid}.mamba.dt_proj", # bamba
         ),

         MODEL_TENSOR.SSM_A: (
             "model.layers.{bid}.A_log",
             "backbone.layers.{bid}.mixer.A_log",
-            "model.layers.{bid}.mamba.A_log", # jamba
+            "model.layers.{bid}.mamba.A_log", # bamba
         ),

         MODEL_TENSOR.SSM_D: (
             "model.layers.{bid}.D",
             "backbone.layers.{bid}.mixer.D",
-            "model.layers.{bid}.mamba.D", # jamba
+            "model.layers.{bid}.mamba.D", # bamba
         ),

         MODEL_TENSOR.SSM_NORM: (
             "backbone.layers.{bid}.mixer.norm", # mamba2
-            "model.layers.{bid}.mamba.norm", # jamba
+            "model.layers.{bid}.mamba.norm", # bamba
         ),

         MODEL_TENSOR.SSM_OUT: (
             "model.layers.{bid}.out_proj",
             "backbone.layers.{bid}.mixer.out_proj",
-            "model.layers.{bid}.mamba.out_proj", # jamba
+            "model.layers.{bid}.mamba.out_proj", # bamba
         ),

         MODEL_TENSOR.TIME_MIX_W1: (
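Taken together, the corrected entries let a Bamba checkpoint's tensor names resolve through the generic mapper, which is what map_tensor_name in the convert script wraps. A minimal sketch, not part of the commit; it assumes MODEL_ARCH.BAMBA exists on this branch and registers the SSM and FFN-norm tensors.

# Sketch only: resolve a few Bamba HF tensor names to their GGUF names.
import gguf

tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.BAMBA, n_blocks=1)
for hf_name in (
    "model.layers.0.mamba.in_proj.weight",     # expected: blk.0.ssm_in.weight
    "model.layers.0.mamba.A_log",              # expected: blk.0.ssm_a
    "model.layers.0.pre_ff_layernorm.weight",  # expected: blk.0.ffn_norm.weight, now that the
                                               # mapping entry no longer hard-codes ".weight"
):
    print(hf_name, "->", tmap.get_name(hf_name, try_suffixes=(".weight", ".bias")))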