fix(bamba conv): Jamba -> Bamba
Branch: BambaArchitecture
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
parent: e3525e9e50
commit: fd98682ec3
3 changed files with 21 additions and 48 deletions
@@ -3093,12 +3093,12 @@ class Mamba2Model(Model):
         return data_torch.squeeze()


+# TODO: Switch to BambaForCausalLM once ready in transformers
+# @Model.register("BambaForCausalLM")
 @Model.register("JambaForCausalLM")
-class JambaModel(Model):
-    """Jamba is a hybrid SSM + Attention model and can support either Mamba or
-    Mamba2 style SSMs
-    """
-    model_arch = gguf.MODEL_ARCH.JAMBA
+class BambaModel(Mamba2Model):
+    """Bamba is a hybrid SSM + Attention model that uses Mamba2 SSM layers"""
+    model_arch = gguf.MODEL_ARCH.BAMBA

     def __init__(self, *args, **kwargs):

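The registration decorator above is what ties a converter class to the `architectures` entry in a model's config.json; since BambaForCausalLM is not yet available in transformers, the class keeps registering under JambaForCausalLM for now. A minimal sketch of how such a registry decorator typically behaves (a dict-backed stand-in, not the actual Model.register implementation):

# Hypothetical, simplified registry; the real Model.register in
# convert_hf_to_gguf.py may differ in details.
_model_classes: dict[str, type] = {}

def register(*names: str):
    def inner(cls: type) -> type:
        for name in names:
            _model_classes[name] = cls  # map HF architecture name -> converter class
        return cls
    return inner

@register("JambaForCausalLM")   # temporary name until BambaForCausalLM lands
class BambaModel:               # placeholder for the real converter class
    pass

# Lookup mirrors what the converter does with config.json's "architectures" field
assert _model_classes["JambaForCausalLM"] is BambaModel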
@@ -3108,17 +3108,7 @@ class JambaModel(Model):

         super().__init__(*args, **kwargs)

-        # Determine if this is using Mamba or Mamba2
-        self._mamba_version = self.hparams.get("mamba_version", "v1")
-        self._mamba_model_class: type[Model] = {
-            "v1": MambaModel,
-            "v2": Mamba2Model,
-        }.get(self._mamba_version, Model)
-        assert (
-            self._mamba_model_class is not Model
-        ), f"Unsupported mamba_version: {self._mamba_version}"
-
-        # Use Llama conversion for attention / FF / MoE
+        # Use Llama conversion for attention
         self._transformer_model_class: type[Model] = LlamaModel

         # Lists of which layers use ssm vs attention

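The constructor no longer selects a Mamba backend class from `mamba_version`; BambaModel now inherits from Mamba2Model directly, so ordinary `super()` dispatch reaches the Mamba2 conversion logic. A toy illustration of that difference, using placeholder classes rather than the real converters:

# Placeholder skeleton only; the real classes do far more than print.
class Model:
    def set_vocab(self):
        print("Model.set_vocab")

class Mamba2Model(Model):
    def set_vocab(self):
        print("Mamba2Model.set_vocab")

class BambaModel(Mamba2Model):
    pass  # inherits Mamba2Model behaviour directly; no _mamba_model_class lookup

BambaModel().set_vocab()                          # -> Mamba2Model.set_vocab
print([c.__name__ for c in BambaModel.__mro__])   # ['BambaModel', 'Mamba2Model', 'Model', 'object']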
@@ -3152,17 +3142,14 @@ class JambaModel(Model):
        keys = list(keys) + prefixed
        return super().find_hparam(keys, *args, **kwargs)

-    def set_vocab(self):
-        self._mamba_model_class.set_vocab(self)
-
     def set_gguf_parameters(self):

         ## General Params ##
         self.gguf_writer.add_embedding_length(self.d_model)
-        self.gguf_writer.add_mamba_version(self._mamba_version)
         self.gguf_writer.add_block_count(self.block_count)
         self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
         self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])

         ## Mamba mixer params ##
         self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))

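set_gguf_parameters reads the Hugging Face hyperparameters and writes them out as architecture-prefixed GGUF metadata. A rough sketch of the mapping the general-params block performs, with made-up values and assuming the usual "{arch}.*" key names under the new "bamba" prefix:

# Example values only; real numbers come from the model's config.json.
hparams = {
    "hidden_size": 4096,
    "num_hidden_layers": 32,
    "max_position_embeddings": 4096,
    "vocab_size": 128256,
    "intermediate_size": 14336,
}

gguf_kv = {
    "bamba.embedding_length":    hparams["hidden_size"],
    "bamba.block_count":         hparams["num_hidden_layers"],
    "bamba.context_length":      hparams.get("max_position_embeddings", 0),
    "bamba.vocab_size":          hparams["vocab_size"],
    "bamba.feed_forward_length": hparams["intermediate_size"],
}
print(gguf_kv)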
@@ -3175,8 +3162,6 @@ class JambaModel(Model):
         self.gguf_writer.add_ssm_conv_bias(self.find_hparam(["conv_bias"], optional=True) or False)
         self.gguf_writer.add_ssm_proj_bias(self.find_hparam(["proj_bias"], optional=True) or False)
         self.gguf_writer.add_ssm_chunk_size(self.find_hparam(["chunk_size"]))
-        # TODO: I think this will always be true if available?
-        # "use_mamba_kernels": true,

         ## Attention params ##
         self.gguf_writer.add_attn_layer_indices(self._attn_layers)

@@ -3185,33 +3170,27 @@ class JambaModel(Model):
         self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads", "n_head_kv"]))

         ## Feed Forward Params ##
-        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
+        self.gguf_writer.add_layer_norm_rms_eps(
+            self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
+        )

         ## Validation ##
         assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
         assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"
-        # TODO: Support MoE FFN configurations
-        # "num_experts"
-        # "num_experts_per_tok"
-        # "expert_layer_offset"
-        # "expert_layer_period"
-        assert self.hparams.get("num_experts") in [None, 1], "MoE not currently supported"

         ## UNUSED?? ##
         # "tie_word_embeddings" <-- Implied by presence of output weights
-        # "router_aux_loss_coef" <-- Only used if outputting router logits
         # "num_logits_to_keep" <-- Always only keep final token logits
-        # "output_router_logits" <-- Never output router logits since only doing generate
         # "use_cache" <-- KV Cache always enabled
-        # "sliding_window" <-- Used for flash attention in transformers
+        # "use_mamba_kernels" <-- I think this will always be true if available?

     def modify_tensors(
         self, data_torch: Tensor, name: str, bid: int | None
     ) -> Iterable[tuple[str, Tensor]]:
         # Determine whether this is a mamaba layer or an attention layer
         if bid in self._ssm_layers:
-            for mamba_new_name, data_torch in self._mamba_model_class.modify_tensors(
-                self, data_torch, name, bid
+            for mamba_new_name, data_torch in super().modify_tensors(
+                data_torch, name, bid
             ):
                 yield mamba_new_name, data_torch
         elif bid in self._attn_layers:

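modify_tensors routes each tensor by its block id: SSM blocks now go through `super().modify_tensors` (the Mamba2Model path), while attention blocks go through the Llama-style conversion. A toy sketch of that routing, with hypothetical layer-index lists standing in for the ones derived from the config:

# Hypothetical layer layout; the real lists come from the model's hparams.
_attn_layers = [9, 18, 27]
_ssm_layers = [i for i in range(32) if i not in _attn_layers]

def route(bid: int | None) -> str:
    # Mirrors the bid-based dispatch in modify_tensors
    if bid in _ssm_layers:
        return "mamba2"      # handled by super().modify_tensors (Mamba2Model)
    elif bid in _attn_layers:
        return "attention"   # handled via the Llama-style conversion path
    return "shared"          # embeddings / output tensors carry no block id

print(route(0), route(9), route(None))   # mamba2 attention shared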
@@ -3229,9 +3208,7 @@ class JambaModel(Model):
         new_name: str, bid: int | None,
     ) -> Tensor:
         if bid in self._ssm_layers:
-            return self._mamba_model_class.reshape_tensors(
-                self, data_torch, new_name, bid
-            )
+            return super().reshape_tensors(data_torch, new_name, bid)
         elif bid in self._attn_layers:
             return self._transformer_model_class.reshape_tensors(
                 self, data_torch, new_name, bid

@@ -158,8 +158,7 @@ class Keys:
        PROJ_BIAS = "{arch}.ssm.proj_bias"

    class HybridMamba:
-        MAMBA_VERSION = "{arch}.mamba.version"
-        ATTN_LAYER_INDICES = "{arch}.attn.layers"
+        ATTN_LAYER_INDICES = "{arch}.attention.layer_indices"

    class WKV:
        HEAD_SIZE = "{arch}.wkv.head_size"

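The ATTN_LAYER_INDICES template is expanded with the architecture name when the writer emits it, so under the renamed architecture it becomes a "bamba."-prefixed key:

# Expansion of the key template added above; the format() call mirrors how
# GGUFWriter fills in the arch name.
ATTN_LAYER_INDICES = "{arch}.attention.layer_indices"
print(ATTN_LAYER_INDICES.format(arch="bamba"))   # -> bamba.attention.layer_indices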
@@ -250,7 +249,7 @@ class MODEL_ARCH(IntEnum):
     RWKV6 = auto()
     MAMBA = auto()
     MAMBA2 = auto()
-    JAMBA = auto()
+    BAMBA = auto()
     XVERSE = auto()
     COMMAND_R = auto()
     DBRX = auto()

@@ -415,7 +414,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.RWKV6: "rwkv6",
     MODEL_ARCH.MAMBA: "mamba",
     MODEL_ARCH.MAMBA2: "mamba2",
-    MODEL_ARCH.JAMBA: "jamba",
+    MODEL_ARCH.BAMBA: "bamba",
     MODEL_ARCH.XVERSE: "xverse",
     MODEL_ARCH.COMMAND_R: "command-r",
     MODEL_ARCH.DBRX: "dbrx",

@@ -1046,7 +1045,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.SSM_NORM,
         MODEL_TENSOR.SSM_OUT,
     ],
-    MODEL_ARCH.JAMBA: [
+    MODEL_ARCH.BAMBA: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
         MODEL_TENSOR.OUTPUT,

@@ -805,9 +805,6 @@ class GGUFWriter:
     def add_ssm_proj_bias(self, value: bool) -> None:
         self.add_bool(Keys.SSM.PROJ_BIAS.format(arch=self.arch), value)

-    def add_mamba_version(self, value: str) -> None:
-        self.add_string(Keys.HybridMamba.MAMBA_VERSION.format(arch=self.arch), value)
-
     def add_attn_layer_indices(self, values: list[int]) -> None:
         self.add_array(Keys.HybridMamba.ATTN_LAYER_INDICES.format(arch=self.arch), values)

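With add_mamba_version gone, the attention layer index list is the only hybrid-specific metadata the writer still emits. A usage sketch, assuming the patched gguf-py from this branch and example layer indices (a real conversion would also write the header, remaining KV data, and tensors):

# Usage sketch only; the file name and indices are illustrative.
import gguf

writer = gguf.GGUFWriter("bamba-example.gguf", arch="bamba")
writer.add_attn_layer_indices([9, 18, 27])   # stored under bamba.attention.layer_indices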