diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index da4526ee6..13f3b570a 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -3093,12 +3093,12 @@ class Mamba2Model(Model):
         return data_torch.squeeze()
 
 
+# TODO: Switch to BambaForCausalLM once ready in transformers
+# @Model.register("BambaForCausalLM")
 @Model.register("JambaForCausalLM")
-class JambaModel(Model):
-    """Jamba is a hybrid SSM + Attention model and can support either Mamba or
-    Mamba2 style SSMs
-    """
-    model_arch = gguf.MODEL_ARCH.JAMBA
+class BambaModel(Mamba2Model):
+    """Bamba is a hybrid SSM + Attention model that uses Mamba2 SSM layers"""
+    model_arch = gguf.MODEL_ARCH.BAMBA
 
     def __init__(self, *args, **kwargs):
@@ -3108,17 +3108,7 @@ class JambaModel(Model):
 
         super().__init__(*args, **kwargs)
 
-        # Determine if this is using Mamba or Mamba2
-        self._mamba_version = self.hparams.get("mamba_version", "v1")
-        self._mamba_model_class: type[Model] = {
-            "v1": MambaModel,
-            "v2": Mamba2Model,
-        }.get(self._mamba_version, Model)
-        assert (
-            self._mamba_model_class is not Model
-        ), f"Unsupported mamba_version: {self._mamba_version}"
-
-        # Use Llama conversion for attention / FF / MoE
+        # Use Llama conversion for attention
         self._transformer_model_class: type[Model] = LlamaModel
 
         # Lists of which layers use ssm vs attention
@@ -3152,17 +3142,14 @@ class JambaModel(Model):
             keys = list(keys) + prefixed
         return super().find_hparam(keys, *args, **kwargs)
 
-    def set_vocab(self):
-        self._mamba_model_class.set_vocab(self)
-
     def set_gguf_parameters(self):
 
         ## General Params ##
         self.gguf_writer.add_embedding_length(self.d_model)
-        self.gguf_writer.add_mamba_version(self._mamba_version)
         self.gguf_writer.add_block_count(self.block_count)
         self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
         self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
 
         ## Mamba mixer params ##
         self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
@@ -3175,8 +3162,6 @@ class JambaModel(Model):
         self.gguf_writer.add_ssm_conv_bias(self.find_hparam(["conv_bias"], optional=True) or False)
         self.gguf_writer.add_ssm_proj_bias(self.find_hparam(["proj_bias"], optional=True) or False)
         self.gguf_writer.add_ssm_chunk_size(self.find_hparam(["chunk_size"]))
-        # TODO: I think this will always be true if available?
-        # "use_mamba_kernels": true,
 
         ## Attention params ##
         self.gguf_writer.add_attn_layer_indices(self._attn_layers)
@@ -3185,33 +3170,27 @@ class JambaModel(Model):
         self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads", "n_head_kv"]))
 
         ## Feed Forward Params ##
-        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
+        self.gguf_writer.add_layer_norm_rms_eps(
+            self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
+        )
 
         ## Validation ##
         assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
         assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"
-        # TODO: Support MoE FFN configurations
-        # "num_experts"
-        # "num_experts_per_tok"
-        # "expert_layer_offset"
-        # "expert_layer_period"
-        assert self.hparams.get("num_experts") in [None, 1], "MoE not currently supported"
 
         ## UNUSED?? ##
-        # "tie_word_embeddings" <-- Implied by presence of output weights
-        # "router_aux_loss_coef" <-- Only used if outputting router logits
-        # "num_logits_to_keep" <-- Always only keep final token logits
-        # "output_router_logits" <-- Never output router logits since only doing generate
-        # "use_cache" <-- KV Cache always enabled
-        # "sliding_window" <-- Used for flash attention in transformers
+        # "tie_word_embeddings" <-- Implied by presence of output weights
+        # "num_logits_to_keep" <-- Always only keep final token logits
+        # "use_cache" <-- KV Cache always enabled
+        # "use_mamba_kernels" <-- I think this will always be true if available?
 
     def modify_tensors(
         self, data_torch: Tensor, name: str, bid: int | None
     ) -> Iterable[tuple[str, Tensor]]:
 
         # Determine whether this is a mamaba layer or an attention layer
         if bid in self._ssm_layers:
-            for mamba_new_name, data_torch in self._mamba_model_class.modify_tensors(
-                self, data_torch, name, bid
+            for mamba_new_name, data_torch in super().modify_tensors(
+                data_torch, name, bid
             ):
                 yield mamba_new_name, data_torch
         elif bid in self._attn_layers:
@@ -3229,9 +3208,7 @@ class JambaModel(Model):
         new_name: str,
         bid: int | None,
     ) -> Tensor:
         if bid in self._ssm_layers:
-            return self._mamba_model_class.reshape_tensors(
-                self, data_torch, new_name, bid
-            )
+            return super().reshape_tensors(data_torch, new_name, bid)
         elif bid in self._attn_layers:
             return self._transformer_model_class.reshape_tensors(
                 self, data_torch, new_name, bid
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 3e5c92863..c2d363309 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -158,8 +158,7 @@ class Keys:
         PROJ_BIAS = "{arch}.ssm.proj_bias"
 
     class HybridMamba:
-        MAMBA_VERSION      = "{arch}.mamba.version"
-        ATTN_LAYER_INDICES = "{arch}.attn.layers"
+        ATTN_LAYER_INDICES = "{arch}.attention.layer_indices"
 
     class WKV:
         HEAD_SIZE = "{arch}.wkv.head_size"
@@ -250,7 +249,7 @@ class MODEL_ARCH(IntEnum):
     RWKV6 = auto()
     MAMBA = auto()
     MAMBA2 = auto()
-    JAMBA = auto()
+    BAMBA = auto()
     XVERSE = auto()
     COMMAND_R = auto()
     DBRX = auto()
@@ -415,7 +414,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.RWKV6: "rwkv6",
     MODEL_ARCH.MAMBA: "mamba",
     MODEL_ARCH.MAMBA2: "mamba2",
-    MODEL_ARCH.JAMBA: "jamba",
+    MODEL_ARCH.BAMBA: "bamba",
     MODEL_ARCH.XVERSE: "xverse",
     MODEL_ARCH.COMMAND_R: "command-r",
     MODEL_ARCH.DBRX: "dbrx",
@@ -1046,7 +1045,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.SSM_NORM,
         MODEL_TENSOR.SSM_OUT,
     ],
-    MODEL_ARCH.JAMBA: [
+    MODEL_ARCH.BAMBA: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
         MODEL_TENSOR.OUTPUT,
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 7b0126ce1..399887b2e 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -805,9 +805,6 @@ class GGUFWriter:
     def add_ssm_proj_bias(self, value: bool) -> None:
         self.add_bool(Keys.SSM.PROJ_BIAS.format(arch=self.arch), value)
 
-    def add_mamba_version(self, value: str) -> None:
-        self.add_string(Keys.HybridMamba.MAMBA_VERSION.format(arch=self.arch), value)
-
     def add_attn_layer_indices(self, values: list[int]) -> None:
         self.add_array(Keys.HybridMamba.ATTN_LAYER_INDICES.format(arch=self.arch), values)