llama : add Deepseek MoE v1 & GigaChat models (#10827)
* Add deepseek v1 arch & gigachat template * improve template code * add readme * delete comments * remove comment * fix format * lint llama.cpp * fix order of deepseek and deepseek2, move gigachat temlate to the end of func * fix order of deepseek and deepseek2 in constants; mark shared exp as deepseek arch need * remove comments * move deepseek above deepseek2 * change placement of gigachat chat template
This commit is contained in:
parent
87cf323cef
commit
a0974156f3
7 changed files with 423 additions and 3 deletions
|
@ -664,6 +664,9 @@ class Model:
|
|||
if chkhsh == "8b5a93ed704057481f240da0be7e7dca721d7f8f4755263b6807227a2cbeae65":
|
||||
# ref: https://huggingface.co/sentence-transformers/stsb-roberta-base
|
||||
res = "roberta-bpe"
|
||||
if chkhsh == "ad851be1dba641f2e3711822f816db2c265f788b37c63b4e1aeacb9ee92de8eb":
|
||||
# ref: https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct
|
||||
res = "gigachat"
|
||||
|
||||
if res is None:
|
||||
logger.warning("\n")
|
||||
|
@ -3427,6 +3430,97 @@ class ArcticModel(Model):
|
|||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@Model.register("DeepseekForCausalLM")
|
||||
class DeepseekModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.DEEPSEEK
|
||||
|
||||
def set_vocab(self):
|
||||
try:
|
||||
self._set_vocab_sentencepiece()
|
||||
except FileNotFoundError:
|
||||
self._set_vocab_gpt2()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
hparams = self.hparams
|
||||
if "head_dim" in hparams:
|
||||
rope_dim = hparams["head_dim"]
|
||||
else:
|
||||
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
|
||||
|
||||
self.gguf_writer.add_rope_dimension_count(rope_dim)
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
|
||||
self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
|
||||
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
|
||||
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
|
||||
self.gguf_writer.add_expert_weights_scale(1.0)
|
||||
self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
|
||||
self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
|
||||
|
||||
_experts: list[dict[str, Tensor]] | None = None
|
||||
|
||||
@staticmethod
|
||||
def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
|
||||
if n_head_kv is not None and n_head != n_head_kv:
|
||||
n_head = n_head_kv
|
||||
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
|
||||
.swapaxes(1, 2)
|
||||
.reshape(weights.shape))
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
n_head = self.hparams["num_attention_heads"]
|
||||
n_kv_head = self.hparams.get("num_key_value_heads")
|
||||
|
||||
if name.endswith(("q_proj.weight", "q_proj.bias")):
|
||||
data_torch = DeepseekModel.permute(data_torch, n_head, n_head)
|
||||
if name.endswith(("k_proj.weight", "k_proj.bias")):
|
||||
data_torch = DeepseekModel.permute(data_torch, n_head, n_kv_head)
|
||||
|
||||
# process the experts separately
|
||||
if name.find("mlp.experts") != -1:
|
||||
n_experts = self.hparams["n_routed_experts"]
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
self._experts = [{} for _ in range(self.block_count)]
|
||||
|
||||
self._experts[bid][name] = data_torch
|
||||
|
||||
if len(self._experts[bid]) >= n_experts * 3:
|
||||
tensors: list[tuple[str, Tensor]] = []
|
||||
|
||||
# merge the experts into a single 3d tensor
|
||||
for w_name in ["down_proj", "gate_proj", "up_proj"]:
|
||||
datas: list[Tensor] = []
|
||||
|
||||
for xid in range(n_experts):
|
||||
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
|
||||
datas.append(self._experts[bid][ename])
|
||||
del self._experts[bid][ename]
|
||||
|
||||
data_torch = torch.stack(datas, dim=0)
|
||||
|
||||
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
|
||||
|
||||
new_name = self.map_tensor_name(merged_name)
|
||||
|
||||
tensors.append((new_name, data_torch))
|
||||
return tensors
|
||||
else:
|
||||
return []
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
def prepare_tensors(self):
|
||||
super().prepare_tensors()
|
||||
|
||||
if self._experts is not None:
|
||||
# flatten `list[dict[str, Tensor]]` into `list[str]`
|
||||
experts = [k for d in self._experts for k in d.keys()]
|
||||
if len(experts) > 0:
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@Model.register("DeepseekV2ForCausalLM")
|
||||
class DeepseekV2Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue