feat(jamba): First pass at GGUF conversion for Jamba models

There are likely still some missing hparams, but the tensor mapping should
be correct

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Gabe Goodhart 2024-11-26 14:29:24 -07:00
parent 1ee6c482d0
commit 9a68f7537b
3 changed files with 110 additions and 3 deletions


@@ -516,7 +516,10 @@ class Model:
         from transformers import AutoTokenizer
         tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
-        vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))
+        vocab_size = max(
+            self.hparams.get("vocab_size", len(tokenizer.vocab)),
+            len(tokenizer.vocab)
+        )
         assert max(tokenizer.vocab.values()) < vocab_size

         tokpre = self.get_vocab_base_pre(tokenizer)
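The switch to max() above is what keeps the assert on the following line from firing when a tokenizer ships more entries than config.json reports. A minimal sketch of the difference, using hypothetical sizes rather than values from any particular model:

    # Hypothetical sizes, for illustration only
    hparams_vocab_size = 65024       # what the config reports
    tokenizer_vocab_len = 65536      # what the tokenizer actually holds

    old_vocab_size = hparams_vocab_size
    new_vocab_size = max(hparams_vocab_size, tokenizer_vocab_len)

    max_token_id = tokenizer_vocab_len - 1
    print(max_token_id < old_vocab_size)  # False -> the old assert would fail
    print(max_token_id < new_vocab_size)  # True  -> conversion proceeds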
@@ -3036,7 +3039,7 @@ class Mamba2Model(Model):
         # Fail early for models which don't have a block expansion factor of 2
         # TODO: does this really matter?
-        assert d_inner == 2 * d_model
+        # assert d_inner == 2 * d_model
         assert d_inner % head_dim == 0

         self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
@@ -3088,6 +3091,72 @@ class Mamba2Model(Model):
         return data_torch.squeeze()


+@Model.register("JambaForCausalLM")
+class JambaModel(Model):
+    """Jamba is a hybrid SSM + Attention model and can support either Mamba or
+    Mamba2 style SSMs
+    """
+    model_arch = gguf.MODEL_ARCH.JAMBA
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # Determine if this is using Mamba or Mamba2
+        self._mamba_version = self.hparams.get("mamba_version", "v1")
+        self._mamba_model_class: type[Model] = {
+            "v1": MambaModel,
+            "v2": Mamba2Model,
+        }.get(self._mamba_version, Model)
+        assert (
+            self._mamba_model_class is not Model
+        ), f"Unsupported mamba_version: {self._mamba_version}"
+
+        # Use Llama conversion for attention / FF / MoE
+        self._transformer_model_class: type[Model] = LlamaModel
+
+        # Lists of which layers use SSM vs attention
+        self._attn_layers = self.hparams.get("attn_layer_indices", [])
+        if not self._attn_layers:
+            attn_period = self.hparams.get("attn_layer_period")
+            assert attn_period, "Didn't find attn_layer_indices or attn_layer_period"
+            attn_offset = self.hparams.get("attn_layer_offset")
+            assert attn_offset is not None, "No attention layer offset set with attn_layer_period"
+            self._attn_layers = [
+                i for i in range(self.block_count)
+                if i % attn_period == attn_offset
+            ]
+        self._ssm_layers = [
+            i for i in range(self.block_count)
+            if i not in self._attn_layers
+        ]
+
+    def set_vocab(self):
+        self._mamba_model_class.set_vocab(self)
+
+    def set_gguf_parameters(self):
+        # Set the mamba-type parameters
+        self._mamba_model_class.set_gguf_parameters(self)
+        # TODO: All the rest!
+
+    def modify_tensors(
+        self, data_torch: Tensor, name: str, bid: int | None
+    ) -> Iterable[tuple[str, Tensor]]:
+        # Determine whether this is a mamba layer or an attention layer
+        if bid in self._ssm_layers:
+            for mamba_new_name, data_torch in self._mamba_model_class.modify_tensors(
+                self, data_torch, name, bid
+            ):
+                yield mamba_new_name, data_torch
+        elif bid in self._attn_layers:
+            for llama_new_name, data_torch in self._transformer_model_class.modify_tensors(
+                self, data_torch, name, bid
+            ):
+                yield llama_new_name, data_torch
+        else:
+            yield name, data_torch
+
+
 @Model.register("CohereForCausalLM")
 class CommandR2Model(Model):
     model_arch = gguf.MODEL_ARCH.COMMAND_R

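For a sense of how JambaModel.__init__ splits the stack, here is a standalone sketch of the same index arithmetic. The hparam values (attn_layer_period=8, attn_layer_offset=4, 32 blocks) are hypothetical stand-ins, not values read from any released checkpoint:

    # Hypothetical hparams; the real ones come from the model's config.json
    block_count = 32
    attn_layer_period = 8
    attn_layer_offset = 4

    attn_layers = [i for i in range(block_count) if i % attn_layer_period == attn_layer_offset]
    ssm_layers = [i for i in range(block_count) if i not in attn_layers]

    print(attn_layers)  # [4, 12, 20, 28] -> converted with the Llama (attention) path
    print(ssm_layers)   # the remaining 28 blocks -> converted with the Mamba/Mamba2 path

modify_tensors then routes each tensor by its block id: blocks in ssm_layers reuse the Mamba conversion, blocks in attn_layers reuse the Llama conversion, and tensors without a block id (embeddings, output norm) fall through unchanged.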

@@ -241,6 +241,7 @@ class MODEL_ARCH(IntEnum):
     RWKV6 = auto()
     MAMBA = auto()
     MAMBA2 = auto()
+    JAMBA = auto()
     XVERSE = auto()
     COMMAND_R = auto()
     DBRX = auto()
@@ -405,6 +406,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.RWKV6: "rwkv6",
     MODEL_ARCH.MAMBA: "mamba",
     MODEL_ARCH.MAMBA2: "mamba2",
+    MODEL_ARCH.JAMBA: "jamba",
     MODEL_ARCH.XVERSE: "xverse",
     MODEL_ARCH.COMMAND_R: "command-r",
     MODEL_ARCH.DBRX: "dbrx",
@@ -1035,6 +1037,31 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.SSM_NORM,
         MODEL_TENSOR.SSM_OUT,
     ],
+    MODEL_ARCH.JAMBA: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.SSM_IN,
+        MODEL_TENSOR.SSM_CONV1D,
+        MODEL_TENSOR.SSM_DT,
+        MODEL_TENSOR.SSM_A,
+        MODEL_TENSOR.SSM_D,
+        MODEL_TENSOR.SSM_NORM,
+        MODEL_TENSOR.SSM_OUT,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+    ],
     MODEL_ARCH.XVERSE: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,

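The three registrations above (the MODEL_ARCH member, its "jamba" string name, and the per-arch tensor list) are what the rest of gguf-py keys off. A minimal sanity check, assuming a checkout of gguf-py that includes this change:

    import gguf

    print(gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.JAMBA])   # "jamba"

    # GGUF base names of every tensor type the jamba arch declares
    for t in gguf.MODEL_TENSORS[gguf.MODEL_ARCH.JAMBA]:
        print(gguf.TENSOR_NAMES[t])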

@@ -101,7 +101,7 @@ class TensorNameMap:
             "transformer.h.{bid}.input_layernorm", # falcon7b
             "h.{bid}.input_layernorm", # bloom
             "transformer.h.{bid}.ln_mlp", # falcon40b
-            "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe
+            "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe jamba
             "layers.{bid}.attention_norm", # llama-pth
             "language_model.encoder.layers.{bid}.input_layernorm", # persimmon
             "model.layers.{bid}.ln1", # yi
@@ -241,6 +241,7 @@ class TensorNameMap:
             "transformer.decoder_layer.{bid}.rms_norm_2", # Grok
             "encoder.layers.{bid}.post_attention_layernorm", # chatglm
             "transformer.layers.{bid}.ffn_norm", # openelm
+            "model.layers.{bid}.pre_ff_layernorm.weight", # jamba
         ),

         # Post feed-forward norm
@@ -293,6 +294,7 @@ class TensorNameMap:
             "model.layers.{bid}.residual_mlp.w3", # arctic
             "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
             "transformer.h.{bid}.mlp.c_fc_1", # exaone
+            "model.layers.{bid}.feed_forward.up_proj", # jamba
         ),

         MODEL_TENSOR.FFN_UP_EXP: (
@@ -325,6 +327,7 @@ class TensorNameMap:
             "transformer.h.{bid}.mlp.linear_1", # refact
             "model.layers.{bid}.residual_mlp.w1", # arctic
             "transformer.h.{bid}.mlp.c_fc_0", # exaone
+            "model.layers.{bid}.feed_forward.gate_proj", # jamba
         ),

         MODEL_TENSOR.FFN_GATE_EXP: (
@@ -365,6 +368,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
             "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
             "model.layers.h.{bid}.mlp.c_proj", # exaone
+            "model.layers.{bid}.feed_forward.down_proj", # jamba
         ),

         MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -413,11 +417,13 @@ class TensorNameMap:
         MODEL_TENSOR.SSM_IN: (
             "model.layers.{bid}.in_proj",
             "backbone.layers.{bid}.mixer.in_proj",
+            "model.layers.{bid}.mamba.in_proj", # jamba
         ),

         MODEL_TENSOR.SSM_CONV1D: (
             "model.layers.{bid}.conv1d",
             "backbone.layers.{bid}.mixer.conv1d",
+            "model.layers.{bid}.mamba.conv1d", # jamba
         ),

         MODEL_TENSOR.SSM_X: (
@@ -428,25 +434,30 @@ class TensorNameMap:
         MODEL_TENSOR.SSM_DT: (
             "model.layers.{bid}.dt_proj",
             "backbone.layers.{bid}.mixer.dt_proj",
+            "model.layers.{bid}.mamba.dt_proj", # jamba
         ),

         MODEL_TENSOR.SSM_A: (
             "model.layers.{bid}.A_log",
             "backbone.layers.{bid}.mixer.A_log",
+            "model.layers.{bid}.mamba.A_log", # jamba
         ),

         MODEL_TENSOR.SSM_D: (
             "model.layers.{bid}.D",
             "backbone.layers.{bid}.mixer.D",
+            "model.layers.{bid}.mamba.D", # jamba
         ),

         MODEL_TENSOR.SSM_NORM: (
             "backbone.layers.{bid}.mixer.norm", # mamba2
+            "model.layers.{bid}.mamba.norm", # jamba
         ),

         MODEL_TENSOR.SSM_OUT: (
             "model.layers.{bid}.out_proj",
             "backbone.layers.{bid}.mixer.out_proj",
+            "model.layers.{bid}.mamba.out_proj", # jamba
         ),

         MODEL_TENSOR.TIME_MIX_W1: (
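Taken together, the new aliases let TensorNameMap translate Jamba's HF tensor names (both the mamba.* and feed_forward.* variants) into their GGUF-side names. A small sketch of that lookup, assuming a gguf-py checkout that includes this change; block id 0 is just an arbitrary example:

    import gguf

    tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.JAMBA, n_blocks=1)

    # A Mamba-block tensor as named in the HF checkpoint...
    name = tmap.get_name("model.layers.0.mamba.in_proj.weight", try_suffixes=(".weight", ".bias"))
    # ...should come back under its GGUF name, e.g. "blk.0.ssm_in.weight"
    print(name)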