llama : support Jamba

This commit is contained in:
Francis Couture-Harpin 2024-05-24 19:27:27 -04:00
parent 7e13f19fb5
commit cbc743e600
5 changed files with 606 additions and 123 deletions

View file

@ -2300,7 +2300,7 @@ class MambaModel(Model):
self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
self.gguf_writer.add_block_count(self.hparams["n_layer"])
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_ssm_conv_kernel(d_conv)
self.gguf_writer.add_ssm_inner_size(d_inner)
self.gguf_writer.add_ssm_state_size(d_state)
@ -2346,6 +2346,107 @@ class MambaModel(Model):
)
@Model.register("JambaForCausalLM")
class JambaModel(Model):
model_arch = gguf.MODEL_ARCH.JAMBA
def get_vocab_base_pre(self, tokenizer) -> str:
del tokenizer # unused
return "gpt-2"
def set_gguf_parameters(self):
d_model = self.find_hparam(["hidden_size", "mamba_d_model"])
d_conv = self.find_hparam(["mamba_d_conv"], optional=True) or 4
d_inner = self.hparams["mamba_expand"] * d_model
d_state = self.find_hparam(["mamba_d_state"], optional=True) or 16
# ceiling division
# ref: https://stackoverflow.com/a/17511341/22827863
# ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
dt_rank = self.find_hparam(["mamba_dt_rank"], optional=True) or -(d_model // -16)
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-6
n_kv_head = self.hparams["num_key_value_heads"]
attn_offset = self.hparams["attn_layer_offset"]
attn_period = self.hparams["attn_layer_period"]
n_kv_vec = [0 for _ in range(attn_offset)] + [
n_kv_head if (i - attn_offset) % attn_period == 0 else 0 for i in range(attn_offset, self.block_count)
]
self.gguf_writer.add_name(self.dir_model.name)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(n_kv_vec)
self.gguf_writer.add_ssm_conv_kernel(d_conv)
self.gguf_writer.add_ssm_inner_size(d_inner)
self.gguf_writer.add_ssm_state_size(d_state)
self.gguf_writer.add_ssm_time_step_rank(dt_rank)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_expert_count(self.hparams["num_experts"])
self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
self.gguf_writer.add_file_type(self.ftype)
_experts: list[dict[str, Tensor]] | None = None
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# process the experts separately
if ".feed_forward.experts." in name:
n_experts = self.hparams["num_experts"]
assert bid is not None
if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]
self._experts[bid][name] = data_torch
if len(self._experts[bid]) >= n_experts * 3:
# merge the experts into a single 3d tensor
for wid in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []
for xid in range(n_experts):
ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{wid}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]
data_torch = torch.stack(datas, dim=0)
# using the same merged name as qwen2moe
merged_name = f"model.layers.{bid}.mlp.experts.{wid}.weight"
new_name = self.map_tensor_name(merged_name)
yield new_name, data_torch
return
new_name = self.map_tensor_name(name)
if name.endswith(".A_log"):
logger.debug("A_log --> A ==> " + new_name)
data_torch = -torch.exp(data_torch)
yield new_name, data_torch
# same as Mamba
def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
del n_dims # unused
return bid is not None and new_name in (
self.format_tensor_name(n, bid, ".weight" if name.endswith(".weight") else "") for n in [
gguf.MODEL_TENSOR.SSM_CONV1D,
gguf.MODEL_TENSOR.SSM_X,
gguf.MODEL_TENSOR.SSM_DT,
gguf.MODEL_TENSOR.SSM_A,
gguf.MODEL_TENSOR.SSM_D,
]
)
@Model.register("CohereForCausalLM")
class CommandR2Model(Model):
model_arch = gguf.MODEL_ARCH.COMMAND_R