fix(bamba conv): Remove chunk size and consolidate head count w/ time step rank
Head count and time step rank serve the same purpose in this model, so we stick
with the existing time_step_rank key. Chunk size is not used in this
implementation because the mixer is implemented without chunking.

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
parent 3ee0ae3b90
commit dfe8d3ddb8

3 changed files with 5 additions and 12 deletions
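For context, a minimal sketch of what the consolidated metadata looks like when written through gguf-py after this change. This is not the converter itself; the architecture string "bamba", the output file name, and the numeric hyperparameter values are assumed example values, and only writer methods that appear in the diff below are used.

# Illustrative sketch only: emit the SSM metadata the converter now writes.
# The Mamba2 head count rides on the existing time_step_rank key; there is
# no separate head_count key and no chunk_size key anymore.
from gguf import GGUFWriter

writer = GGUFWriter("bamba-ssm-metadata.gguf", "bamba")  # assumed file name / arch

writer.add_ssm_conv_kernel(4)      # conv_kernel / d_conv (example value)
writer.add_ssm_state_size(128)     # state_size / d_state (example value)
writer.add_ssm_group_count(1)      # n_groups (example value)
writer.add_ssm_inner_size(4096)    # d_inner (example value)
writer.add_ssm_head_dim(64)        # d_head (example value)
writer.add_ssm_time_step_rank(64)  # n_heads, reusing the time_step_rank key

writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.close()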
@@ -3154,11 +3154,11 @@ class BambaModel(Mamba2Model):
         self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
         self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
         self.gguf_writer.add_ssm_group_count(self.n_group)
-        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["time_step_rank", "dt_rank"]))
         self.gguf_writer.add_ssm_inner_size(self.d_inner)
-        self.gguf_writer.add_ssm_head_count(self.find_hparam(["n_heads"]))
         self.gguf_writer.add_ssm_head_dim(d_head := self.find_hparam(["d_head"]))
-        self.gguf_writer.add_ssm_chunk_size(self.find_hparam(["chunk_size"]))
+        # NOTE: The mamba_dt_rank is _not_ the right field for how this is used
+        #   in llama.cpp
+        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))

         ## Attention params ##
         self.gguf_writer.add_attn_layer_indices(self._attn_layers)
@@ -3175,11 +3175,12 @@ class BambaModel(Mamba2Model):
         assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
         assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"

-        ## UNUSED?? ##
+        ## UNUSED ##
         # "tie_word_embeddings" <-- Implied by presence of output weights
         # "num_logits_to_keep" <-- Always only keep final token logits
         # "use_cache" <-- KV Cache always enabled
         # "use_mamba_kernels" <-- I think this will always be true if available?
+        # "chunk_size" <-- This is used in the mixer implementation in transformers, but not here

     def modify_tensors(
         self, data_torch: Tensor, name: str, bid: int | None
@@ -151,9 +151,7 @@ class Keys:
         TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
         GROUP_COUNT = "{arch}.ssm.group_count"
         DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
-        HEAD_COUNT = "{arch}.ssm.head_count"
         HEAD_DIM = "{arch}.ssm.head_dim"
-        CHUNK_SIZE = "{arch}.ssm.chunk_size"

     class HybridMamba:
         ATTN_LAYER_INDICES = "{arch}.attention.layer_indices"
@@ -790,15 +790,9 @@ class GGUFWriter:
     def add_ssm_dt_b_c_rms(self, value: bool) -> None:
         self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)

-    def add_ssm_head_count(self, value: int) -> None:
-        self.add_uint32(Keys.SSM.HEAD_COUNT.format(arch=self.arch), value)
-
     def add_ssm_head_dim(self, value: int) -> None:
         self.add_uint32(Keys.SSM.HEAD_DIM.format(arch=self.arch), value)

-    def add_ssm_chunk_size(self, value: int) -> None:
-        self.add_uint32(Keys.SSM.CHUNK_SIZE.format(arch=self.arch), value)
-
     def add_attn_layer_indices(self, values: list[int]) -> None:
         self.add_array(Keys.HybridMamba.ATTN_LAYER_INDICES.format(arch=self.arch), values)
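For completeness, a hedged consumer-side sketch: after this change, anything that needs the Mamba2 head count reads it back from the time_step_rank key rather than a head_count key. GGUFReader and get_field come from gguf-py; the file name and the "bamba." key prefix are assumptions for illustration, and the scalar-field access pattern below follows gguf-py's reader layout.

# Hedged sketch: read the consolidated head count back out of a converted file.
from gguf import GGUFReader

reader = GGUFReader("bamba-ssm-metadata.gguf")        # assumed file name
field = reader.get_field("bamba.ssm.time_step_rank")  # {arch}.ssm.time_step_rank
if field is not None:
    # For a scalar field, data[0] indexes the part that holds the value.
    n_heads = int(field.parts[field.data[0]][0])
    print("Mamba2 head count (via ssm.time_step_rank):", n_heads)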