feat(convert): Full pass at hparam conversion

Branch: BambaArchitecture
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

parent 246dfdba65
commit e3525e9e50

3 changed files with 131 additions and 19 deletions
@@ -3007,6 +3007,14 @@ class MambaModel(Model):
 class Mamba2Model(Model):
     model_arch = gguf.MODEL_ARCH.MAMBA2

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # n_groups and d_inner are used during reshaping
+        self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
+        self.n_group = self.find_hparam(["n_groups"], optional=True) or 1
+        self.d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
+
     def set_vocab(self):
         vocab_size = self.hparams["vocab_size"]
         # Round vocab size to next multiple of 16
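The cached values above rely on find_hparam's `optional=True` plus an `or <default>` fallback, so a checkpoint that only defines hidden_size still yields usable sizes. A minimal sketch of that fallback (not part of the commit; the helper and config dict below are stand-ins for the real Model.find_hparam):

    # Stand-in for Model.find_hparam: return the first matching key, else None when optional.
    def find_hparam(hparams: dict, keys: list[str], optional: bool = False):
        for k in keys:
            if k in hparams:
                return hparams[k]
        if optional:
            return None
        raise KeyError(f"could not find any of: {keys}")

    hparams = {"hidden_size": 2048}  # illustrative config with no n_groups / intermediate_size
    d_model = find_hparam(hparams, ["hidden_size", "d_model", "dim"])                                # 2048
    n_group = find_hparam(hparams, ["n_groups"], optional=True) or 1                                 # falls back to 1
    d_inner = find_hparam(hparams, ["intermediate_size", "d_inner"], optional=True) or 2 * d_model   # 4096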
@@ -3028,30 +3036,27 @@ class Mamba2Model(Model):
         self._set_vocab_builtin("gpt-neox", vocab_size)

     def set_gguf_parameters(self):
-        d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
-        d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
-        d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
-        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
-        head_dim = self.find_hparam(["head_dim"], optional=True) or 64
-        n_group = self.find_hparam(["n_groups"], optional=True) or 1
+        d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
+        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
+        head_dim = self.find_hparam(["head_dim"], optional=True) or 64

         rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5

         # Fail early for models which don't have a block expansion factor of 2
         # TODO: does this really matter?
-        # assert d_inner == 2 * d_model
-        assert d_inner % head_dim == 0
+        assert self.d_inner == 2 * self.d_model
+        assert self.d_inner % head_dim == 0

         self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
-        self.gguf_writer.add_embedding_length(d_model)
+        self.gguf_writer.add_embedding_length(self.d_model)
         self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
         self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
         self.gguf_writer.add_block_count(self.block_count)
         self.gguf_writer.add_ssm_conv_kernel(d_conv)
-        self.gguf_writer.add_ssm_inner_size(d_inner)
+        self.gguf_writer.add_ssm_inner_size(self.d_inner)
         self.gguf_writer.add_ssm_state_size(d_state)
-        self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim)
-        self.gguf_writer.add_ssm_group_count(n_group)
+        self.gguf_writer.add_ssm_time_step_rank(self.d_inner // head_dim)
+        self.gguf_writer.add_ssm_group_count(self.n_group)
         self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
         self.gguf_writer.add_file_type(self.ftype)
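With those defaults, the metadata written for a hypothetical Mamba2 checkpoint that only sets hidden_size works out as below (numbers are illustrative, not taken from the commit):

    d_model  = 2048
    d_inner  = 2 * d_model   # 4096, the fallback when intermediate_size is absent
    head_dim = 64            # default when head_dim is missing
    d_state  = 128           # default when state_size / d_state is missing

    assert d_inner == 2 * d_model
    assert d_inner % head_dim == 0
    time_step_rank = d_inner // head_dim   # 64, the value passed to add_ssm_time_step_rank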
@@ -3083,10 +3088,7 @@ class Mamba2Model(Model):
             return data_torch.reshape((*data_torch.shape, 1))

         elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
-            d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
-            d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
-            n_group = self.hparams.get("n_groups", 1)
-            return data_torch.reshape((n_group, d_inner // n_group))
+            return data_torch.reshape((self.n_group, self.d_inner // self.n_group))

         return data_torch.squeeze()
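The SSM_NORM branch now reuses the cached self.n_group / self.d_inner instead of re-deriving them; the reshape itself just splits the flat norm weight into one row per group. A small torch sketch with invented sizes (not from the commit):

    import torch

    n_group, d_inner = 8, 4096                    # illustrative sizes
    ssm_norm = torch.ones(d_inner)                # flat (d_inner,) weight as stored in the checkpoint
    grouped = ssm_norm.reshape((n_group, d_inner // n_group))
    print(grouped.shape)                          # torch.Size([8, 512])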
@@ -3099,6 +3101,11 @@ class JambaModel(Model):
     model_arch = gguf.MODEL_ARCH.JAMBA

     def __init__(self, *args, **kwargs):
+
+        # Hybrid mamba models use a prefix for the mamba-specific params.
+        # TODO: Extend this if the prefix(es) need to be configurable
+        self.hparam_prefixes = ["mamba"]
+
         super().__init__(*args, **kwargs)

         # Determine if this is using Mamba or Mamba2
@@ -3130,14 +3137,73 @@ class JambaModel(Model):
             if i not in self._attn_layers
         ]

+        # n_group and d_inner are used during reshape_tensors for mamaba2
+        self.d_model = self.find_hparam(["hidden_size", "d_model"])
+        self.n_group = self.find_hparam(["n_groups"])
+        self.d_inner = self.find_hparam(["expand"]) * self.d_model
+
+    def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
+        prefixed = []
+        for pfx in self.hparam_prefixes:
+            prefixed.extend(
+                "_".join([pfx, k])
+                for k in keys
+            )
+        keys = list(keys) + prefixed
+        return super().find_hparam(keys, *args, **kwargs)
+
     def set_vocab(self):
         self._mamba_model_class.set_vocab(self)

     def set_gguf_parameters(self):
         # Set the mamba-type parameters
         self._mamba_model_class.set_gguf_parameters(self)

+        # TODO: All the rest!
+        ## General Params ##
+        self.gguf_writer.add_embedding_length(self.d_model)
+        self.gguf_writer.add_mamba_version(self._mamba_version)
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
+        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+
+        ## Mamba mixer params ##
+        self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
+        self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
+        self.gguf_writer.add_ssm_group_count(self.n_group)
+        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["time_step_rank", "dt_rank"]))
+        self.gguf_writer.add_ssm_inner_size(self.d_inner)
+        self.gguf_writer.add_ssm_head_count(self.find_hparam(["n_heads"]))
+        self.gguf_writer.add_ssm_head_dim(d_head := self.find_hparam(["d_head"]))
+        self.gguf_writer.add_ssm_conv_bias(self.find_hparam(["conv_bias"], optional=True) or False)
+        self.gguf_writer.add_ssm_proj_bias(self.find_hparam(["proj_bias"], optional=True) or False)
+        self.gguf_writer.add_ssm_chunk_size(self.find_hparam(["chunk_size"]))
+        # TODO: I think this will always be true if available?
+        # "use_mamba_kernels": true,
+
+        ## Attention params ##
+        self.gguf_writer.add_attn_layer_indices(self._attn_layers)
+        self.gguf_writer.add_rope_dimension_count(self.hparams["attn_rotary_emb"])
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads", "n_head_kv"]))
+
+        ## Feed Forward Params ##
+        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
+
+        ## Validation ##
+        assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
+        assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"
+        # TODO: Support MoE FFN configurations
+        # "num_experts"
+        # "num_experts_per_tok"
+        # "expert_layer_offset"
+        # "expert_layer_period"
+        assert self.hparams.get("num_experts") in [None, 1], "MoE not currently supported"
+
+        ## UNUSED?? ##
+        # "tie_word_embeddings" <-- Implied by presence of output weights
+        # "router_aux_loss_coef" <-- Only used if outputting router logits
+        # "num_logits_to_keep" <-- Always only keep final token logits
+        # "output_router_logits" <-- Never output router logits since only doing generate
+        # "use_cache" <-- KV Cache always enabled
+        # "sliding_window" <-- Used for flash attention in transformers
+
     def modify_tensors(
         self, data_torch: Tensor, name: str, bid: int | None
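The overridden find_hparam widens every lookup with the configured prefixes, so call sites written against plain keys such as "d_conv" also match the "mamba_"-prefixed variants used by hybrid configs. A standalone sketch of that expansion (hypothetical config values; this is not the class method itself):

    def expand_keys(keys: list[str], prefixes: list[str]) -> list[str]:
        prefixed = ["_".join([pfx, k]) for pfx in prefixes for k in keys]
        return list(keys) + prefixed

    hparams = {"mamba_d_conv": 4, "hidden_size": 4096}        # illustrative hybrid config
    keys = expand_keys(["conv_kernel", "d_conv"], ["mamba"])
    # keys == ["conv_kernel", "d_conv", "mamba_conv_kernel", "mamba_d_conv"]
    d_conv = next(hparams[k] for k in keys if k in hparams)   # -> 4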
@@ -3157,6 +3223,22 @@ class JambaModel(Model):
             yield name, data_torch


+    def reshape_tensors(
+        self,
+        data_torch: Tensor,
+        new_name: str, bid: int | None,
+    ) -> Tensor:
+        if bid in self._ssm_layers:
+            return self._mamba_model_class.reshape_tensors(
+                self, data_torch, new_name, bid
+            )
+        elif bid in self._attn_layers:
+            return self._transformer_model_class.reshape_tensors(
+                self, data_torch, new_name, bid
+            )
+        return data_torch
+
+
 @Model.register("CohereForCausalLM")
 class CommandR2Model(Model):
     model_arch = gguf.MODEL_ARCH.COMMAND_R
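reshape_tensors dispatches on the block id: blocks listed in self._attn_layers get the transformer handling, everything else gets the mamba handling, mirroring how self._ssm_layers is built from the attention indices in __init__. A toy sketch of that partition and routing (layer counts invented for illustration):

    n_blocks = 8
    attn_layers = [2, 5]                                        # illustrative attention block ids
    ssm_layers = [i for i in range(n_blocks) if i not in attn_layers]

    def route(bid: int) -> str:
        if bid in ssm_layers:
            return "mamba"
        elif bid in attn_layers:
            return "attention"
        return "pass-through"

    print([route(i) for i in range(n_blocks)])
    # ['mamba', 'mamba', 'attention', 'mamba', 'mamba', 'attention', 'mamba', 'mamba']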
@@ -151,6 +151,15 @@ class Keys:
         TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
         GROUP_COUNT = "{arch}.ssm.group_count"
         DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
+        HEAD_COUNT = "{arch}.ssm.head_count"
+        HEAD_DIM = "{arch}.ssm.head_dim"
+        CHUNK_SIZE = "{arch}.ssm.chunk_size"
+        CONV_BIAS = "{arch}.ssm.conv_bias"
+        PROJ_BIAS = "{arch}.ssm.proj_bias"
+
+    class HybridMamba:
+        MAMBA_VERSION = "{arch}.mamba.version"
+        ATTN_LAYER_INDICES = "{arch}.attn.layers"

     class WKV:
         HEAD_SIZE = "{arch}.wkv.head_size"
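Each key template is formatted with the architecture name when written, so one constant yields per-arch metadata names. For example, with an illustrative arch string of "jamba":

    HEAD_COUNT = "{arch}.ssm.head_count"
    ATTN_LAYER_INDICES = "{arch}.attn.layers"

    print(HEAD_COUNT.format(arch="jamba"))            # jamba.ssm.head_count
    print(ATTN_LAYER_INDICES.format(arch="jamba"))    # jamba.attn.layers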
@@ -790,6 +790,27 @@ class GGUFWriter:
     def add_ssm_dt_b_c_rms(self, value: bool) -> None:
         self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)

+    def add_ssm_head_count(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.HEAD_COUNT.format(arch=self.arch), value)
+
+    def add_ssm_head_dim(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.HEAD_DIM.format(arch=self.arch), value)
+
+    def add_ssm_chunk_size(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.CHUNK_SIZE.format(arch=self.arch), value)
+
+    def add_ssm_conv_bias(self, value: bool) -> None:
+        self.add_bool(Keys.SSM.CONV_BIAS.format(arch=self.arch), value)
+
+    def add_ssm_proj_bias(self, value: bool) -> None:
+        self.add_bool(Keys.SSM.PROJ_BIAS.format(arch=self.arch), value)
+
+    def add_mamba_version(self, value: str) -> None:
+        self.add_string(Keys.HybridMamba.MAMBA_VERSION.format(arch=self.arch), value)
+
+    def add_attn_layer_indices(self, values: list[int]) -> None:
+        self.add_array(Keys.HybridMamba.ATTN_LAYER_INDICES.format(arch=self.arch), values)
+
     def add_tokenizer_model(self, model: str) -> None:
         self.add_string(Keys.Tokenizer.MODEL, model)
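From a converter's side the new helpers behave like the existing add_ssm_* methods: each writes one typed KV pair under the arch-qualified key. A hedged usage sketch (the arch name, output path, and values are illustrative, and the helpers only exist on this branch; the rest of the usual GGUFWriter workflow for headers, KV data, and tensors is unchanged):

    import gguf

    writer = gguf.GGUFWriter("out.gguf", "jamba")   # arch "jamba" chosen for illustration
    writer.add_ssm_head_count(64)                   # uint32 -> "jamba.ssm.head_count"
    writer.add_ssm_head_dim(64)                     # uint32 -> "jamba.ssm.head_dim"
    writer.add_ssm_chunk_size(256)                  # uint32 -> "jamba.ssm.chunk_size"
    writer.add_ssm_conv_bias(True)                  # bool   -> "jamba.ssm.conv_bias"
    writer.add_ssm_proj_bias(False)                 # bool   -> "jamba.ssm.proj_bias"
    writer.add_mamba_version("2")                   # string -> "jamba.mamba.version"
    writer.add_attn_layer_indices([2, 5])           # array  -> "jamba.attn.layers"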