feat(jamba): First pass at GGUF conversion for Jamba models
There are likely still some missing hparams, but the tensor mapping should be correct Branch: BambaArchitecture Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
parent
1ee6c482d0
commit
9a68f7537b
3 changed files with 110 additions and 3 deletions
|
@ -516,7 +516,10 @@ class Model:
|
|||
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
|
||||
vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))
|
||||
vocab_size = max(
|
||||
self.hparams.get("vocab_size", len(tokenizer.vocab)),
|
||||
len(tokenizer.vocab)
|
||||
)
|
||||
assert max(tokenizer.vocab.values()) < vocab_size
|
||||
|
||||
tokpre = self.get_vocab_base_pre(tokenizer)
|
||||
|
@ -3036,7 +3039,7 @@ class Mamba2Model(Model):
|
|||
|
||||
# Fail early for models which don't have a block expansion factor of 2
|
||||
# TODO: does this really matter?
|
||||
assert d_inner == 2 * d_model
|
||||
# assert d_inner == 2 * d_model
|
||||
assert d_inner % head_dim == 0
|
||||
|
||||
self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
|
||||
|
@ -3088,6 +3091,72 @@ class Mamba2Model(Model):
|
|||
return data_torch.squeeze()
|
||||
|
||||
|
||||
@Model.register("JambaForCausalLM")
|
||||
class JambaModel(Model):
|
||||
"""Jamba is a hybrid SSM + Attention model and can support either Mamba or
|
||||
Mamba2 style SSMs
|
||||
"""
|
||||
model_arch = gguf.MODEL_ARCH.JAMBA
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Determine if this is using Mamba or Mamba2
|
||||
self._mamba_version = self.hparams.get("mamba_version", "v1")
|
||||
self._mamba_model_class: type[Model] = {
|
||||
"v1": MambaModel,
|
||||
"v2": Mamba2Model,
|
||||
}.get(self._mamba_version, Model)
|
||||
assert (
|
||||
self._mamba_model_class is not Model
|
||||
), f"Unsupported mamba_version: {self._mamba_version}"
|
||||
|
||||
# Use Llama conversion for attention / FF / MoE
|
||||
self._transformer_model_class: type[Model] = LlamaModel
|
||||
|
||||
# Lists of which layers use ssm vs attention
|
||||
self._attn_layers = self.hparams.get("attn_layer_indices", [])
|
||||
if not self._attn_layers:
|
||||
attn_period = self.hparams.get("attn_layer_period")
|
||||
assert attn_period, "Didn't find attn_layer_indices or attn_layer_period"
|
||||
attn_offset = self.hparams.get("attn_layer_offset")
|
||||
assert attn_offset is not None, "No attention layer offset set with attn_layer_period"
|
||||
self._attn_layers = [
|
||||
i for i in range(self.block_count)
|
||||
if i % attn_period == attn_offset
|
||||
]
|
||||
self._ssm_layers = [
|
||||
i for i in range(self.block_count)
|
||||
if i not in self._attn_layers
|
||||
]
|
||||
|
||||
def set_vocab(self):
|
||||
self._mamba_model_class.set_vocab(self)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
# Set the mamba-type parameters
|
||||
self._mamba_model_class.set_gguf_parameters(self)
|
||||
|
||||
# TODO: All the rest!
|
||||
|
||||
def modify_tensors(
|
||||
self, data_torch: Tensor, name: str, bid: int | None
|
||||
) -> Iterable[tuple[str, Tensor]]:
|
||||
# Determine whether this is a mamaba layer or an attention layer
|
||||
if bid in self._ssm_layers:
|
||||
for mamba_new_name, data_torch in self._mamba_model_class.modify_tensors(
|
||||
self, data_torch, name, bid
|
||||
):
|
||||
yield mamba_new_name, data_torch
|
||||
elif bid in self._attn_layers:
|
||||
for llama_new_name, data_torch in self._transformer_model_class.modify_tensors(
|
||||
self, data_torch, name, bid
|
||||
):
|
||||
yield llama_new_name, data_torch
|
||||
else:
|
||||
yield name, data_torch
|
||||
|
||||
|
||||
@Model.register("CohereForCausalLM")
|
||||
class CommandR2Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.COMMAND_R
|
||||
|
|
|
@ -241,6 +241,7 @@ class MODEL_ARCH(IntEnum):
|
|||
RWKV6 = auto()
|
||||
MAMBA = auto()
|
||||
MAMBA2 = auto()
|
||||
JAMBA = auto()
|
||||
XVERSE = auto()
|
||||
COMMAND_R = auto()
|
||||
DBRX = auto()
|
||||
|
@ -405,6 +406,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||
MODEL_ARCH.RWKV6: "rwkv6",
|
||||
MODEL_ARCH.MAMBA: "mamba",
|
||||
MODEL_ARCH.MAMBA2: "mamba2",
|
||||
MODEL_ARCH.JAMBA: "jamba",
|
||||
MODEL_ARCH.XVERSE: "xverse",
|
||||
MODEL_ARCH.COMMAND_R: "command-r",
|
||||
MODEL_ARCH.DBRX: "dbrx",
|
||||
|
@ -1035,6 +1037,31 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.SSM_NORM,
|
||||
MODEL_TENSOR.SSM_OUT,
|
||||
],
|
||||
MODEL_ARCH.JAMBA: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.SSM_IN,
|
||||
MODEL_TENSOR.SSM_CONV1D,
|
||||
MODEL_TENSOR.SSM_DT,
|
||||
MODEL_TENSOR.SSM_A,
|
||||
MODEL_TENSOR.SSM_D,
|
||||
MODEL_TENSOR.SSM_NORM,
|
||||
MODEL_TENSOR.SSM_OUT,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
],
|
||||
MODEL_ARCH.XVERSE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
|
|
@ -101,7 +101,7 @@ class TensorNameMap:
|
|||
"transformer.h.{bid}.input_layernorm", # falcon7b
|
||||
"h.{bid}.input_layernorm", # bloom
|
||||
"transformer.h.{bid}.ln_mlp", # falcon40b
|
||||
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe
|
||||
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe jamba
|
||||
"layers.{bid}.attention_norm", # llama-pth
|
||||
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
|
||||
"model.layers.{bid}.ln1", # yi
|
||||
|
@ -241,6 +241,7 @@ class TensorNameMap:
|
|||
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
|
||||
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
|
||||
"transformer.layers.{bid}.ffn_norm", # openelm
|
||||
"model.layers.{bid}.pre_ff_layernorm.weight", # jamba
|
||||
),
|
||||
|
||||
# Post feed-forward norm
|
||||
|
@ -293,6 +294,7 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.residual_mlp.w3", # arctic
|
||||
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
|
||||
"transformer.h.{bid}.mlp.c_fc_1", # exaone
|
||||
"model.layers.{bid}.feed_forward.up_proj", # jamba
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_UP_EXP: (
|
||||
|
@ -325,6 +327,7 @@ class TensorNameMap:
|
|||
"transformer.h.{bid}.mlp.linear_1", # refact
|
||||
"model.layers.{bid}.residual_mlp.w1", # arctic
|
||||
"transformer.h.{bid}.mlp.c_fc_0", # exaone
|
||||
"model.layers.{bid}.feed_forward.gate_proj", # jamba
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_EXP: (
|
||||
|
@ -365,6 +368,7 @@ class TensorNameMap:
|
|||
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
|
||||
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
|
||||
"model.layers.h.{bid}.mlp.c_proj", # exaone
|
||||
"model.layers.{bid}.feed_forward.down_proj", # jamba
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_EXP: (
|
||||
|
@ -413,11 +417,13 @@ class TensorNameMap:
|
|||
MODEL_TENSOR.SSM_IN: (
|
||||
"model.layers.{bid}.in_proj",
|
||||
"backbone.layers.{bid}.mixer.in_proj",
|
||||
"model.layers.{bid}.mamba.in_proj", # jamba
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_CONV1D: (
|
||||
"model.layers.{bid}.conv1d",
|
||||
"backbone.layers.{bid}.mixer.conv1d",
|
||||
"model.layers.{bid}.mamba.conv1d", # jamba
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_X: (
|
||||
|
@ -428,25 +434,30 @@ class TensorNameMap:
|
|||
MODEL_TENSOR.SSM_DT: (
|
||||
"model.layers.{bid}.dt_proj",
|
||||
"backbone.layers.{bid}.mixer.dt_proj",
|
||||
"model.layers.{bid}.mamba.dt_proj", # jamba
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_A: (
|
||||
"model.layers.{bid}.A_log",
|
||||
"backbone.layers.{bid}.mixer.A_log",
|
||||
"model.layers.{bid}.mamba.A_log", # jamba
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_D: (
|
||||
"model.layers.{bid}.D",
|
||||
"backbone.layers.{bid}.mixer.D",
|
||||
"model.layers.{bid}.mamba.D", # jamba
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_NORM: (
|
||||
"backbone.layers.{bid}.mixer.norm", # mamba2
|
||||
"model.layers.{bid}.mamba.norm", # jamba
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_OUT: (
|
||||
"model.layers.{bid}.out_proj",
|
||||
"backbone.layers.{bid}.mixer.out_proj",
|
||||
"model.layers.{bid}.mamba.out_proj", # jamba
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_W1: (
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue