fix(bamba conv): Remove chunk size and consolidate head count w/ time step rank

Head count and time step rank are used for the same purpose in the model,
so we stick with the existing time_step_rank key. Chunk size is not used
in this implementation because the mixer is implemented without chunking.

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
author Gabe Goodhart 2024-12-05 10:59:21 -07:00
parent 3ee0ae3b90
commit dfe8d3ddb8
3 changed files with 5 additions and 12 deletions
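
To make the consolidation concrete, here is a minimal sketch of the conversion-side
metadata writes after this change, using this branch's gguf-py writer. The file name,
architecture string and parameter values below are made up for illustration; they are
not taken from a real Bamba config.

    # Illustrative sketch only: example values, assumed arch string "bamba".
    from gguf import GGUFWriter

    writer = GGUFWriter("bamba-example.gguf", arch="bamba")
    n_heads = 128  # HF config "n_heads"; now stored under the time_step_rank key
    d_head = 64    # HF config "d_head"

    writer.add_ssm_time_step_rank(n_heads)  # consolidated head count / dt rank
    writer.add_ssm_head_dim(d_head)         # unchanged, still its own key
    # No add_ssm_chunk_size(): chunking only matters to the transformers mixer.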

@@ -3154,11 +3154,11 @@ class BambaModel(Mamba2Model):
         self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
         self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
         self.gguf_writer.add_ssm_group_count(self.n_group)
-        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["time_step_rank", "dt_rank"]))
         self.gguf_writer.add_ssm_inner_size(self.d_inner)
-        self.gguf_writer.add_ssm_head_count(self.find_hparam(["n_heads"]))
         self.gguf_writer.add_ssm_head_dim(d_head := self.find_hparam(["d_head"]))
-        self.gguf_writer.add_ssm_chunk_size(self.find_hparam(["chunk_size"]))
+        # NOTE: The mamba_dt_rank is _not_ the right field for how this is used
+        # in llama.cpp
+        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
 
         ## Attention params ##
         self.gguf_writer.add_attn_layer_indices(self._attn_layers)
@@ -3175,11 +3175,12 @@ class BambaModel(Mamba2Model):
         assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
         assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"
 
-        ## UNUSED?? ##
+        ## UNUSED ##
         # "tie_word_embeddings" <-- Implied by presence of output weights
         # "num_logits_to_keep" <-- Always only keep final token logits
         # "use_cache" <-- KV Cache always enabled
         # "use_mamba_kernels" <-- I think this will always be true if available?
+        # "chunk_size" <-- This is used in the mixer implementation in transformers, but not here
 
     def modify_tensors(
         self, data_torch: Tensor, name: str, bid: int | None
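
The find_hparam calls in the hunk above take a list of candidate names and use
whichever one the HF config.json actually provides. A standalone sketch of that
fallback, simplified from the Model base class in convert_hf_to_gguf.py, with a
toy config dict for illustration:

    from typing import Any, Iterable

    def find_hparam(hparams: dict[str, Any], keys: Iterable[str], optional: bool = False) -> Any:
        # Return the value of the first candidate key present in the config dict.
        key = next((k for k in keys if k in hparams), None)
        if key is not None:
            return hparams[key]
        if optional:
            return None
        raise KeyError(f"could not find any of: {keys}")

    config = {"n_heads": 128, "d_head": 64}          # toy stand-in for config.json
    assert find_hparam(config, ["n_heads"]) == 128   # now feeds add_ssm_time_step_rank
    assert find_hparam(config, ["d_head"]) == 64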

@@ -151,9 +151,7 @@ class Keys:
         TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
         GROUP_COUNT = "{arch}.ssm.group_count"
         DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
-        HEAD_COUNT = "{arch}.ssm.head_count"
         HEAD_DIM = "{arch}.ssm.head_dim"
-        CHUNK_SIZE = "{arch}.ssm.chunk_size"
 
     class HybridMamba:
         ATTN_LAYER_INDICES = "{arch}.attention.layer_indices"
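
These Keys entries are plain format-string templates; the surviving SSM keys expand
per architecture as shown below (the architecture string "bamba" is assumed here,
not confirmed by this diff).

    TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
    HEAD_DIM = "{arch}.ssm.head_dim"

    print(TIME_STEP_RANK.format(arch="bamba"))  # bamba.ssm.time_step_rank
    print(HEAD_DIM.format(arch="bamba"))        # bamba.ssm.head_dim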

@@ -790,15 +790,9 @@ class GGUFWriter:
     def add_ssm_dt_b_c_rms(self, value: bool) -> None:
         self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)
 
-    def add_ssm_head_count(self, value: int) -> None:
-        self.add_uint32(Keys.SSM.HEAD_COUNT.format(arch=self.arch), value)
-
     def add_ssm_head_dim(self, value: int) -> None:
         self.add_uint32(Keys.SSM.HEAD_DIM.format(arch=self.arch), value)
 
-    def add_ssm_chunk_size(self, value: int) -> None:
-        self.add_uint32(Keys.SSM.CHUNK_SIZE.format(arch=self.arch), value)
-
     def add_attn_layer_indices(self, values: list[int]) -> None:
         self.add_array(Keys.HybridMamba.ATTN_LAYER_INDICES.format(arch=self.arch), values)
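
On the read side, a quick way to inspect what a converted file carries after this
change is gguf-py's GGUFReader. The helper below is a hypothetical sketch, not part
of gguf-py or llama.cpp, and again assumes the architecture string is "bamba".

    from gguf import GGUFReader

    def dump_bamba_ssm_meta(path: str, arch: str = "bamba") -> None:
        reader = GGUFReader(path)
        for suffix in ("time_step_rank", "head_dim", "head_count", "chunk_size"):
            field = reader.get_field(f"{arch}.ssm.{suffix}")
            if field is None:
                # head_count and chunk_size should be absent for files converted
                # after this change
                print(f"{arch}.ssm.{suffix}: <absent>")
            else:
                # scalar KV fields keep their value in the last referenced part
                print(f"{arch}.ssm.{suffix}: {field.parts[field.data[-1]][0]}")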