diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index fa67350ab..da4526ee6 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -3007,6 +3007,14 @@ class MambaModel(Model):
 class Mamba2Model(Model):
     model_arch = gguf.MODEL_ARCH.MAMBA2
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # n_groups and d_inner are used during reshaping
+        self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
+        self.n_group = self.find_hparam(["n_groups"], optional=True) or 1
+        self.d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
+
     def set_vocab(self):
         vocab_size = self.hparams["vocab_size"]
         # Round vocab size to next multiple of 16
@@ -3028,30 +3036,27 @@ class Mamba2Model(Model):
             self._set_vocab_builtin("gpt-neox", vocab_size)
 
     def set_gguf_parameters(self):
-        d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
-        d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
-        d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
-        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
-        head_dim = self.find_hparam(["head_dim"], optional=True) or 64
-        n_group = self.find_hparam(["n_groups"], optional=True) or 1
+        d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
+        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
+        head_dim = self.find_hparam(["head_dim"], optional=True) or 64
 
         rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
 
         # Fail early for models which don't have a block expansion factor of 2
         # TODO: does this really matter?
-        # assert d_inner == 2 * d_model
-        assert d_inner % head_dim == 0
+        assert self.d_inner == 2 * self.d_model
+        assert self.d_inner % head_dim == 0
 
         self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
-        self.gguf_writer.add_embedding_length(d_model)
+        self.gguf_writer.add_embedding_length(self.d_model)
         self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
         self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
         self.gguf_writer.add_block_count(self.block_count)
         self.gguf_writer.add_ssm_conv_kernel(d_conv)
-        self.gguf_writer.add_ssm_inner_size(d_inner)
+        self.gguf_writer.add_ssm_inner_size(self.d_inner)
         self.gguf_writer.add_ssm_state_size(d_state)
-        self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim)
-        self.gguf_writer.add_ssm_group_count(n_group)
+        self.gguf_writer.add_ssm_time_step_rank(self.d_inner // head_dim)
+        self.gguf_writer.add_ssm_group_count(self.n_group)
         self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
         self.gguf_writer.add_file_type(self.ftype)
 
@@ -3083,10 +3088,7 @@ class Mamba2Model(Model):
             return data_torch.reshape((*data_torch.shape, 1))
         elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
-            d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
-            d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
-            n_group = self.hparams.get("n_groups", 1)
-            return data_torch.reshape((n_group, d_inner // n_group))
+            return data_torch.reshape((self.n_group, self.d_inner // self.n_group))
 
         return data_torch.squeeze()
 
 
@@ -3099,6 +3101,11 @@ class JambaModel(Model):
     model_arch = gguf.MODEL_ARCH.JAMBA
 
     def __init__(self, *args, **kwargs):
+
+        # Hybrid mamba models use a prefix for the mamba-specific params.
+        # TODO: Extend this if the prefix(es) need to be configurable
+        self.hparam_prefixes = ["mamba"]
+
         super().__init__(*args, **kwargs)
 
         # Determine if this is using Mamba or Mamba2
@@ -3130,14 +3137,73 @@ class JambaModel(Model):
             if i not in self._attn_layers
         ]
 
+        # n_group and d_inner are used during reshape_tensors for mamba2
+        self.d_model = self.find_hparam(["hidden_size", "d_model"])
+        self.n_group = self.find_hparam(["n_groups"])
+        self.d_inner = self.find_hparam(["expand"]) * self.d_model
+
+    def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
+        prefixed = []
+        for pfx in self.hparam_prefixes:
+            prefixed.extend(
+                "_".join([pfx, k])
+                for k in keys
+            )
+        keys = list(keys) + prefixed
+        return super().find_hparam(keys, *args, **kwargs)
+
     def set_vocab(self):
         self._mamba_model_class.set_vocab(self)
 
     def set_gguf_parameters(self):
-        # Set the mamba-type parameters
-        self._mamba_model_class.set_gguf_parameters(self)
-        # TODO: All the rest!
+        ## General Params ##
+        self.gguf_writer.add_embedding_length(self.d_model)
+        self.gguf_writer.add_mamba_version(self._mamba_version)
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
+        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+
+        ## Mamba mixer params ##
+        self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
+        self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
+        self.gguf_writer.add_ssm_group_count(self.n_group)
+        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["time_step_rank", "dt_rank"]))
+        self.gguf_writer.add_ssm_inner_size(self.d_inner)
+        self.gguf_writer.add_ssm_head_count(self.find_hparam(["n_heads"]))
+        self.gguf_writer.add_ssm_head_dim(d_head := self.find_hparam(["d_head"]))
+        self.gguf_writer.add_ssm_conv_bias(self.find_hparam(["conv_bias"], optional=True) or False)
+        self.gguf_writer.add_ssm_proj_bias(self.find_hparam(["proj_bias"], optional=True) or False)
+        self.gguf_writer.add_ssm_chunk_size(self.find_hparam(["chunk_size"]))
+        # TODO: I think this will always be true if available?
+        # "use_mamba_kernels": true,
+
+        ## Attention params ##
+        self.gguf_writer.add_attn_layer_indices(self._attn_layers)
+        self.gguf_writer.add_rope_dimension_count(self.hparams["attn_rotary_emb"])
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads", "n_head_kv"]))
+
+        ## Feed Forward Params ##
+        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
+
+        ## Validation ##
+        assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
+        assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"
+        # TODO: Support MoE FFN configurations
+        # "num_experts"
+        # "num_experts_per_tok"
+        # "expert_layer_offset"
+        # "expert_layer_period"
+        assert self.hparams.get("num_experts") in [None, 1], "MoE not currently supported"
+
+        ## UNUSED?? ##
+        # "tie_word_embeddings" <-- Implied by presence of output weights
+        # "router_aux_loss_coef" <-- Only used if outputting router logits
+        # "num_logits_to_keep" <-- Always only keep final token logits
+        # "output_router_logits" <-- Never output router logits since only doing generate
+        # "use_cache" <-- KV Cache always enabled
+        # "sliding_window" <-- Used for flash attention in transformers
 
     def modify_tensors(
         self, data_torch: Tensor, name: str, bid: int | None
@@ -3157,6 +3223,22 @@ class JambaModel(Model):
 
         yield name, data_torch
 
+    def reshape_tensors(
+        self,
+        data_torch: Tensor,
+        new_name: str, bid: int | None,
+    ) -> Tensor:
+        if bid in self._ssm_layers:
+            return self._mamba_model_class.reshape_tensors(
+                self, data_torch, new_name, bid
+            )
+        elif bid in self._attn_layers:
+            return self._transformer_model_class.reshape_tensors(
+                self, data_torch, new_name, bid
+            )
+        return data_torch
+
+
 @Model.register("CohereForCausalLM")
 class CommandR2Model(Model):
     model_arch = gguf.MODEL_ARCH.COMMAND_R
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 4739024f8..3e5c92863 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -151,6 +151,15 @@ class Keys:
         TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
         GROUP_COUNT = "{arch}.ssm.group_count"
         DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
+        HEAD_COUNT = "{arch}.ssm.head_count"
+        HEAD_DIM = "{arch}.ssm.head_dim"
+        CHUNK_SIZE = "{arch}.ssm.chunk_size"
+        CONV_BIAS = "{arch}.ssm.conv_bias"
+        PROJ_BIAS = "{arch}.ssm.proj_bias"
+
+    class HybridMamba:
+        MAMBA_VERSION = "{arch}.mamba.version"
+        ATTN_LAYER_INDICES = "{arch}.attn.layers"
 
     class WKV:
         HEAD_SIZE = "{arch}.wkv.head_size"
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 084dd25e2..7b0126ce1 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -790,6 +790,27 @@ class GGUFWriter:
     def add_ssm_dt_b_c_rms(self, value: bool) -> None:
         self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)
 
+    def add_ssm_head_count(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.HEAD_COUNT.format(arch=self.arch), value)
+
+    def add_ssm_head_dim(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.HEAD_DIM.format(arch=self.arch), value)
+
+    def add_ssm_chunk_size(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.CHUNK_SIZE.format(arch=self.arch), value)
+
+    def add_ssm_conv_bias(self, value: bool) -> None:
+        self.add_bool(Keys.SSM.CONV_BIAS.format(arch=self.arch), value)
+
+    def add_ssm_proj_bias(self, value: bool) -> None:
+        self.add_bool(Keys.SSM.PROJ_BIAS.format(arch=self.arch), value)
+
+    def add_mamba_version(self, value: str) -> None:
+        self.add_string(Keys.HybridMamba.MAMBA_VERSION.format(arch=self.arch), value)
+
+    def add_attn_layer_indices(self, values: list[int]) -> None:
+        self.add_array(Keys.HybridMamba.ATTN_LAYER_INDICES.format(arch=self.arch), values)
+
     def add_tokenizer_model(self, model: str) -> None:
         self.add_string(Keys.Tokenizer.MODEL, model)
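
Not part of the patch: a minimal standalone sketch of how the prefixed hyperparameter lookup added in JambaModel.find_hparam resolves keys, for reviewers who want to check the intended behavior. The free-standing find_hparam helper, the HPARAM_PREFIXES constant, and the cfg values below are illustrative assumptions, not code from the tree; the real override simply appends "mamba_"-prefixed variants to the key list before delegating to the base Model.find_hparam.

# sketch_prefixed_hparam_lookup.py -- illustrative only
from typing import Any, Iterable

HPARAM_PREFIXES = ["mamba"]  # mirrors self.hparam_prefixes in the patch

def find_hparam(hparams: dict[str, Any], keys: Iterable[str], optional: bool = False) -> Any:
    # Try the plain keys first, then the same keys with each prefix prepended,
    # e.g. "d_conv" also matches "mamba_d_conv" in a hybrid-mamba config.
    keys = list(keys)
    prefixed = ["_".join([pfx, k]) for pfx in HPARAM_PREFIXES for k in keys]
    for k in keys + prefixed:
        if k in hparams:
            return hparams[k]
    if optional:
        return None
    raise KeyError(f"could not find any of: {keys}")

if __name__ == "__main__":
    cfg = {"hidden_size": 4096, "mamba_d_conv": 4, "mamba_d_state": 16}  # hypothetical values
    assert find_hparam(cfg, ["conv_kernel", "d_conv"]) == 4       # resolved via "mamba_d_conv"
    assert find_hparam(cfg, ["hidden_size", "d_model"]) == 4096   # unprefixed key found directly
    assert find_hparam(cfg, ["n_groups"], optional=True) is None  # optional miss returns None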