From 9a68f7537b39541afd771c96389ba3740ad8be4b Mon Sep 17 00:00:00 2001
From: Gabe Goodhart
Date: Tue, 26 Nov 2024 14:29:24 -0700
Subject: [PATCH] feat(jamba): First pass at GGUF conversion for Jamba models

There are likely still some missing hparams, but the tensor mapping should be correct

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart
---
 convert_hf_to_gguf.py          | 73 +++++++++++++++++++++++++++++++++-
 gguf-py/gguf/constants.py      | 27 +++++++++++++
 gguf-py/gguf/tensor_mapping.py | 13 +++++-
 3 files changed, 110 insertions(+), 3 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index e14ad9b01..fa67350ab 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -516,7 +516,10 @@ class Model:
 
         from transformers import AutoTokenizer
         tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
-        vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))
+        vocab_size = max(
+            self.hparams.get("vocab_size", len(tokenizer.vocab)),
+            len(tokenizer.vocab)
+        )
         assert max(tokenizer.vocab.values()) < vocab_size
 
         tokpre = self.get_vocab_base_pre(tokenizer)
@@ -3036,7 +3039,7 @@ class Mamba2Model(Model):
 
         # Fail early for models which don't have a block expansion factor of 2
         # TODO: does this really matter?
-        assert d_inner == 2 * d_model
+        # assert d_inner == 2 * d_model
         assert d_inner % head_dim == 0
 
         self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
@@ -3088,6 +3091,72 @@ class Mamba2Model(Model):
         return data_torch.squeeze()
 
 
+@Model.register("JambaForCausalLM")
+class JambaModel(Model):
+    """Jamba is a hybrid SSM + Attention model and can support either Mamba or
+    Mamba2 style SSMs
+    """
+    model_arch = gguf.MODEL_ARCH.JAMBA
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # Determine if this is using Mamba or Mamba2
+        self._mamba_version = self.hparams.get("mamba_version", "v1")
+        self._mamba_model_class: type[Model] = {
+            "v1": MambaModel,
+            "v2": Mamba2Model,
+        }.get(self._mamba_version, Model)
+        assert (
+            self._mamba_model_class is not Model
+        ), f"Unsupported mamba_version: {self._mamba_version}"
+
+        # Use Llama conversion for attention / FF / MoE
+        self._transformer_model_class: type[Model] = LlamaModel
+
+        # Lists of which layers use ssm vs attention
+        self._attn_layers = self.hparams.get("attn_layer_indices", [])
+        if not self._attn_layers:
+            attn_period = self.hparams.get("attn_layer_period")
+            assert attn_period, "Didn't find attn_layer_indices or attn_layer_period"
+            attn_offset = self.hparams.get("attn_layer_offset")
+            assert attn_offset is not None, "No attention layer offset set with attn_layer_period"
+            self._attn_layers = [
+                i for i in range(self.block_count)
+                if i % attn_period == attn_offset
+            ]
+        self._ssm_layers = [
+            i for i in range(self.block_count)
+            if i not in self._attn_layers
+        ]
+
+    def set_vocab(self):
+        self._mamba_model_class.set_vocab(self)
+
+    def set_gguf_parameters(self):
+        # Set the mamba-type parameters
+        self._mamba_model_class.set_gguf_parameters(self)
+
+        # TODO: All the rest!
+
+    def modify_tensors(
+        self, data_torch: Tensor, name: str, bid: int | None
+    ) -> Iterable[tuple[str, Tensor]]:
+        # Determine whether this is a mamba layer or an attention layer
+        if bid in self._ssm_layers:
+            for mamba_new_name, data_torch in self._mamba_model_class.modify_tensors(
+                self, data_torch, name, bid
+            ):
+                yield mamba_new_name, data_torch
+        elif bid in self._attn_layers:
+            for llama_new_name, data_torch in self._transformer_model_class.modify_tensors(
+                self, data_torch, name, bid
+            ):
+                yield llama_new_name, data_torch
+        else:
+            yield name, data_torch
+
+
 @Model.register("CohereForCausalLM")
 class CommandR2Model(Model):
     model_arch = gguf.MODEL_ARCH.COMMAND_R
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index c7bd9acd9..4739024f8 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -241,6 +241,7 @@ class MODEL_ARCH(IntEnum):
     RWKV6 = auto()
     MAMBA = auto()
     MAMBA2 = auto()
+    JAMBA = auto()
     XVERSE = auto()
     COMMAND_R = auto()
     DBRX = auto()
@@ -405,6 +406,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.RWKV6: "rwkv6",
     MODEL_ARCH.MAMBA: "mamba",
     MODEL_ARCH.MAMBA2: "mamba2",
+    MODEL_ARCH.JAMBA: "jamba",
     MODEL_ARCH.XVERSE: "xverse",
     MODEL_ARCH.COMMAND_R: "command-r",
     MODEL_ARCH.DBRX: "dbrx",
@@ -1035,6 +1037,31 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.SSM_NORM,
         MODEL_TENSOR.SSM_OUT,
     ],
+    MODEL_ARCH.JAMBA: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.SSM_IN,
+        MODEL_TENSOR.SSM_CONV1D,
+        MODEL_TENSOR.SSM_DT,
+        MODEL_TENSOR.SSM_A,
+        MODEL_TENSOR.SSM_D,
+        MODEL_TENSOR.SSM_NORM,
+        MODEL_TENSOR.SSM_OUT,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+    ],
     MODEL_ARCH.XVERSE: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index fe4bfa3d0..7e126d9bf 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -101,7 +101,7 @@ class TensorNameMap:
             "transformer.h.{bid}.input_layernorm", # falcon7b
             "h.{bid}.input_layernorm", # bloom
             "transformer.h.{bid}.ln_mlp", # falcon40b
-            "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe
+            "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe jamba
             "layers.{bid}.attention_norm", # llama-pth
             "language_model.encoder.layers.{bid}.input_layernorm", # persimmon
             "model.layers.{bid}.ln1", # yi
@@ -241,6 +241,7 @@
             "transformer.decoder_layer.{bid}.rms_norm_2", # Grok
             "encoder.layers.{bid}.post_attention_layernorm", # chatglm
             "transformer.layers.{bid}.ffn_norm", # openelm
+            "model.layers.{bid}.pre_ff_layernorm.weight", # jamba
         ),
 
         # Post feed-forward norm
@@ -293,6 +294,7 @@
             "model.layers.{bid}.residual_mlp.w3", # arctic
             "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
             "transformer.h.{bid}.mlp.c_fc_1", # exaone
+            "model.layers.{bid}.feed_forward.up_proj", # jamba
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
@@ -325,6 +327,7 @@
             "transformer.h.{bid}.mlp.linear_1", # refact
             "model.layers.{bid}.residual_mlp.w1", # arctic
             "transformer.h.{bid}.mlp.c_fc_0", # exaone
+            "model.layers.{bid}.feed_forward.gate_proj", # jamba
         ),
 
         MODEL_TENSOR.FFN_GATE_EXP: (
@@ -365,6 +368,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
             "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
             "model.layers.h.{bid}.mlp.c_proj", # exaone
+            "model.layers.{bid}.feed_forward.down_proj", # jamba
         ),
 
         MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -413,11 +417,13 @@
         MODEL_TENSOR.SSM_IN: (
             "model.layers.{bid}.in_proj",
             "backbone.layers.{bid}.mixer.in_proj",
+            "model.layers.{bid}.mamba.in_proj", # jamba
         ),
 
         MODEL_TENSOR.SSM_CONV1D: (
             "model.layers.{bid}.conv1d",
             "backbone.layers.{bid}.mixer.conv1d",
+            "model.layers.{bid}.mamba.conv1d", # jamba
         ),
 
         MODEL_TENSOR.SSM_X: (
@@ -428,25 +434,30 @@ class TensorNameMap:
         MODEL_TENSOR.SSM_DT: (
             "model.layers.{bid}.dt_proj",
             "backbone.layers.{bid}.mixer.dt_proj",
+            "model.layers.{bid}.mamba.dt_proj", # jamba
         ),
 
         MODEL_TENSOR.SSM_A: (
             "model.layers.{bid}.A_log",
             "backbone.layers.{bid}.mixer.A_log",
+            "model.layers.{bid}.mamba.A_log", # jamba
         ),
 
         MODEL_TENSOR.SSM_D: (
             "model.layers.{bid}.D",
             "backbone.layers.{bid}.mixer.D",
+            "model.layers.{bid}.mamba.D", # jamba
         ),
 
         MODEL_TENSOR.SSM_NORM: (
             "backbone.layers.{bid}.mixer.norm", # mamba2
+            "model.layers.{bid}.mamba.norm", # jamba
         ),
 
         MODEL_TENSOR.SSM_OUT: (
             "model.layers.{bid}.out_proj",
             "backbone.layers.{bid}.mixer.out_proj",
+            "model.layers.{bid}.mamba.out_proj", # jamba
         ),
 
         MODEL_TENSOR.TIME_MIX_W1: (
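
As a quick illustration of the attention/SSM layer split that JambaModel.__init__ derives from attn_layer_period and attn_layer_offset, here is a standalone sketch. The hparam values (8 blocks, an attention layer every 4 blocks at offset 2) are hypothetical and not taken from any real Jamba config:

# Standalone sketch of the layer-split logic above (hypothetical hparams,
# not from an actual Jamba config.json)
block_count = 8   # stand-in for self.block_count
attn_period = 4   # stand-in for hparams["attn_layer_period"]
attn_offset = 2   # stand-in for hparams["attn_layer_offset"]

attn_layers = [i for i in range(block_count) if i % attn_period == attn_offset]
ssm_layers = [i for i in range(block_count) if i not in attn_layers]

print(attn_layers)  # [2, 6]             -> converted via the LlamaModel tensor mapping
print(ssm_layers)   # [0, 1, 3, 4, 5, 7] -> converted via the Mamba/Mamba2 tensor mapping

Assuming the patch is applied on the BambaArchitecture branch, conversion would then be invoked the same way as for other architectures, e.g. python convert_hf_to_gguf.py <jamba_model_dir> --outfile jamba.gguf.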