feat(jamba): First pass at GGUF conversion for Jamba models

There are likely still some missing hparams, but the tensor mapping should be correct

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

parent 1ee6c482d0, commit 9a68f7537b
3 changed files with 110 additions and 3 deletions
@@ -516,7 +516,10 @@ class Model:
 
         from transformers import AutoTokenizer
         tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
-        vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))
+        vocab_size = max(
+            self.hparams.get("vocab_size", len(tokenizer.vocab)),
+            len(tokenizer.vocab)
+        )
         assert max(tokenizer.vocab.values()) < vocab_size
 
         tokpre = self.get_vocab_base_pre(tokenizer)
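As an aside on the hunk above, here is a tiny standalone sketch (with made-up sizes, not taken from any real checkpoint) of why the vocab size is clamped to the tokenizer's size:

# Standalone sketch with hypothetical numbers: if config.json under-reports
# vocab_size, the original assert on token ids would fail; taking the max
# with len(tokenizer.vocab) keeps every token id in range.
hparams_vocab_size = 65024      # hypothetical value from config.json
tokenizer_vocab_len = 65536     # hypothetical len(tokenizer.vocab)

vocab_size = max(hparams_vocab_size, tokenizer_vocab_len)
assert tokenizer_vocab_len - 1 < vocab_size  # highest token id stays in range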
@@ -3036,7 +3039,7 @@ class Mamba2Model(Model):
 
         # Fail early for models which don't have a block expansion factor of 2
         # TODO: does this really matter?
-        assert d_inner == 2 * d_model
+        # assert d_inner == 2 * d_model
         assert d_inner % head_dim == 0
 
         self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
@@ -3088,6 +3091,72 @@ class Mamba2Model(Model):
         return data_torch.squeeze()
 
 
+@Model.register("JambaForCausalLM")
+class JambaModel(Model):
+    """Jamba is a hybrid SSM + Attention model and can support either Mamba or
+    Mamba2 style SSMs
+    """
+    model_arch = gguf.MODEL_ARCH.JAMBA
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # Determine if this is using Mamba or Mamba2
+        self._mamba_version = self.hparams.get("mamba_version", "v1")
+        self._mamba_model_class: type[Model] = {
+            "v1": MambaModel,
+            "v2": Mamba2Model,
+        }.get(self._mamba_version, Model)
+        assert (
+            self._mamba_model_class is not Model
+        ), f"Unsupported mamba_version: {self._mamba_version}"
+
+        # Use Llama conversion for attention / FF / MoE
+        self._transformer_model_class: type[Model] = LlamaModel
+
+        # Lists of which layers use ssm vs attention
+        self._attn_layers = self.hparams.get("attn_layer_indices", [])
+        if not self._attn_layers:
+            attn_period = self.hparams.get("attn_layer_period")
+            assert attn_period, "Didn't find attn_layer_indices or attn_layer_period"
+            attn_offset = self.hparams.get("attn_layer_offset")
+            assert attn_offset is not None, "No attention layer offset set with attn_layer_period"
+            self._attn_layers = [
+                i for i in range(self.block_count)
+                if i % attn_period == attn_offset
+            ]
+        self._ssm_layers = [
+            i for i in range(self.block_count)
+            if i not in self._attn_layers
+        ]
+
+    def set_vocab(self):
+        self._mamba_model_class.set_vocab(self)
+
+    def set_gguf_parameters(self):
+        # Set the mamba-type parameters
+        self._mamba_model_class.set_gguf_parameters(self)
+
+        # TODO: All the rest!
+
+    def modify_tensors(
+        self, data_torch: Tensor, name: str, bid: int | None
+    ) -> Iterable[tuple[str, Tensor]]:
+        # Determine whether this is a mamba layer or an attention layer
+        if bid in self._ssm_layers:
+            for mamba_new_name, data_torch in self._mamba_model_class.modify_tensors(
+                self, data_torch, name, bid
+            ):
+                yield mamba_new_name, data_torch
+        elif bid in self._attn_layers:
+            for llama_new_name, data_torch in self._transformer_model_class.modify_tensors(
+                self, data_torch, name, bid
+            ):
+                yield llama_new_name, data_torch
+        else:
+            yield name, data_torch
+
+
 @Model.register("CohereForCausalLM")
 class CommandR2Model(Model):
     model_arch = gguf.MODEL_ARCH.COMMAND_R
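A quick illustration of the layer-interleaving logic in JambaModel.__init__ above. This is a standalone sketch, not part of the diff; the concrete numbers (32 blocks, period 8, offset 4) are assumptions chosen to resemble a typical Jamba config rather than values taken from this commit:

# Standalone sketch of the attn_layer_period / attn_layer_offset fallback.
# Block count, period, and offset below are illustrative assumptions.
block_count = 32
attn_period = 8
attn_offset = 4

attn_layers = [i for i in range(block_count) if i % attn_period == attn_offset]
ssm_layers = [i for i in range(block_count) if i not in attn_layers]

print(attn_layers)       # [4, 12, 20, 28] -> converted via the Llama-style path
print(len(ssm_layers))   # 28 remaining blocks -> converted via the Mamba-style path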
@@ -241,6 +241,7 @@ class MODEL_ARCH(IntEnum):
     RWKV6 = auto()
     MAMBA = auto()
     MAMBA2 = auto()
+    JAMBA = auto()
     XVERSE = auto()
     COMMAND_R = auto()
     DBRX = auto()

@@ -405,6 +406,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.RWKV6: "rwkv6",
     MODEL_ARCH.MAMBA: "mamba",
     MODEL_ARCH.MAMBA2: "mamba2",
+    MODEL_ARCH.JAMBA: "jamba",
     MODEL_ARCH.XVERSE: "xverse",
     MODEL_ARCH.COMMAND_R: "command-r",
     MODEL_ARCH.DBRX: "dbrx",

@@ -1035,6 +1037,31 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.SSM_NORM,
         MODEL_TENSOR.SSM_OUT,
     ],
+    MODEL_ARCH.JAMBA: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.SSM_IN,
+        MODEL_TENSOR.SSM_CONV1D,
+        MODEL_TENSOR.SSM_DT,
+        MODEL_TENSOR.SSM_A,
+        MODEL_TENSOR.SSM_D,
+        MODEL_TENSOR.SSM_NORM,
+        MODEL_TENSOR.SSM_OUT,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+    ],
     MODEL_ARCH.XVERSE: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
@@ -101,7 +101,7 @@ class TensorNameMap:
         "transformer.h.{bid}.input_layernorm",                  # falcon7b
         "h.{bid}.input_layernorm",                              # bloom
         "transformer.h.{bid}.ln_mlp",                           # falcon40b
-        "model.layers.{bid}.input_layernorm",                   # llama-hf nemotron olmoe
+        "model.layers.{bid}.input_layernorm",                   # llama-hf nemotron olmoe jamba
         "layers.{bid}.attention_norm",                          # llama-pth
         "language_model.encoder.layers.{bid}.input_layernorm",  # persimmon
         "model.layers.{bid}.ln1",                               # yi

@@ -241,6 +241,7 @@ class TensorNameMap:
         "transformer.decoder_layer.{bid}.rms_norm_2",     # Grok
         "encoder.layers.{bid}.post_attention_layernorm",  # chatglm
         "transformer.layers.{bid}.ffn_norm",              # openelm
+        "model.layers.{bid}.pre_ff_layernorm.weight",     # jamba
     ),
 
     # Post feed-forward norm

@@ -293,6 +294,7 @@ class TensorNameMap:
         "model.layers.{bid}.residual_mlp.w3",       # arctic
         "encoder.layers.{bid}.mlp.dense_h_to_4h",   # chatglm
         "transformer.h.{bid}.mlp.c_fc_1",           # exaone
+        "model.layers.{bid}.feed_forward.up_proj",  # jamba
     ),
 
     MODEL_TENSOR.FFN_UP_EXP: (

@@ -325,6 +327,7 @@ class TensorNameMap:
         "transformer.h.{bid}.mlp.linear_1",           # refact
         "model.layers.{bid}.residual_mlp.w1",         # arctic
         "transformer.h.{bid}.mlp.c_fc_0",             # exaone
+        "model.layers.{bid}.feed_forward.gate_proj",  # jamba
     ),
 
     MODEL_TENSOR.FFN_GATE_EXP: (

@@ -365,6 +368,7 @@ class TensorNameMap:
         "encoder.layer.{bid}.mlp.down_layer",         # jina-bert-v2
         "encoder.layers.{bid}.mlp.dense_4h_to_h",     # chatglm
         "model.layers.h.{bid}.mlp.c_proj",            # exaone
+        "model.layers.{bid}.feed_forward.down_proj",  # jamba
     ),
 
     MODEL_TENSOR.FFN_DOWN_EXP: (

@@ -413,11 +417,13 @@ class TensorNameMap:
     MODEL_TENSOR.SSM_IN: (
         "model.layers.{bid}.in_proj",
         "backbone.layers.{bid}.mixer.in_proj",
+        "model.layers.{bid}.mamba.in_proj",  # jamba
     ),
 
     MODEL_TENSOR.SSM_CONV1D: (
         "model.layers.{bid}.conv1d",
         "backbone.layers.{bid}.mixer.conv1d",
+        "model.layers.{bid}.mamba.conv1d",   # jamba
     ),
 
     MODEL_TENSOR.SSM_X: (

@@ -428,25 +434,30 @@ class TensorNameMap:
     MODEL_TENSOR.SSM_DT: (
         "model.layers.{bid}.dt_proj",
         "backbone.layers.{bid}.mixer.dt_proj",
+        "model.layers.{bid}.mamba.dt_proj",  # jamba
     ),
 
     MODEL_TENSOR.SSM_A: (
         "model.layers.{bid}.A_log",
         "backbone.layers.{bid}.mixer.A_log",
+        "model.layers.{bid}.mamba.A_log",    # jamba
     ),
 
     MODEL_TENSOR.SSM_D: (
         "model.layers.{bid}.D",
         "backbone.layers.{bid}.mixer.D",
+        "model.layers.{bid}.mamba.D",        # jamba
     ),
 
     MODEL_TENSOR.SSM_NORM: (
         "backbone.layers.{bid}.mixer.norm",  # mamba2
+        "model.layers.{bid}.mamba.norm",     # jamba
     ),
 
     MODEL_TENSOR.SSM_OUT: (
         "model.layers.{bid}.out_proj",
         "backbone.layers.{bid}.mixer.out_proj",
+        "model.layers.{bid}.mamba.out_proj", # jamba
     ),
 
     MODEL_TENSOR.TIME_MIX_W1: (