convert : add support for DeepSeek V3 model

commit 0061955a06, parent a45433ba20
5 changed files with 45 additions and 0 deletions
convert_hf_to_gguf.py

@@ -687,6 +687,9 @@ class Model:
         if chkhsh == "d4c8f286ea6b520b3d495c4455483cfa2302c0cfcd4be05d781b6a8a0a7cdaf1":
             # ref: https://huggingface.co/Infinigence/Megrez-3B-Instruct
             res = "megrez"
+        if chkhsh == "877081d19cf6996e2c4ff0e1236341e9b7bde288f5311a56a937f0afbbb3aeb5":
+            # ref: https://huggingface.co/deepseek-ai/DeepSeek-V3
+            res = "deepseek-v3"
 
         if res is None:
             logger.warning("\n")
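For context: `chkhsh` is the SHA-256 of the token-id list produced by tokenizing a fixed probe string, so any change in a model's pre-tokenizer behaviour shows up as a different hash. A minimal sketch of the derivation (the probe text is a long fixed string defined in the script and only stubbed out here, so this snippet will not reproduce the hash as written):

    from hashlib import sha256
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V3", trust_remote_code=True)
    chktxt = "..."  # stands in for the long probe string defined in convert_hf_to_gguf.py
    chkhsh = sha256(str(tokenizer.encode(chktxt)).encode()).hexdigest()
    # with the real probe string, DeepSeek-V3 yields
    # 877081d19cf6996e2c4ff0e1236341e9b7bde288f5311a56a937f0afbbb3aeb5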
@@ -3831,6 +3834,7 @@ class DeepseekModel(Model):
 
 
 @Model.register("DeepseekV2ForCausalLM")
+@Model.register("DeepseekV3ForCausalLM")
 class DeepseekV2Model(Model):
     model_arch = gguf.MODEL_ARCH.DEEPSEEK2
 
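Registering `DeepseekV3ForCausalLM` on the existing `DeepseekV2Model` class means V3 checkpoints are converted through the DEEPSEEK2 architecture; everything V3-specific is carried by the new metadata keys and tensor rules below. The decorator is a plain class registry; a simplified, hypothetical sketch of the pattern (not the script's exact code):

    _model_classes: dict[str, type] = {}

    def register(*names: str):
        def wrapper(modelcls: type) -> type:
            for name in names:
                _model_classes[name] = modelcls  # key = HF config `architectures` entry
            return modelcls
        return wrapper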
@@ -3852,6 +3856,15 @@ class DeepseekV2Model(Model):
         self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
         self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
         self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
+        self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
+
+        if hparams["scoring_func"] == "sigmoid":
+            self.gguf_writer.add_expert_weights_func(gguf.ExpertWeightsFuncType.SIGMOID)
+        elif hparams["scoring_func"] == "softmax":
+            self.gguf_writer.add_expert_weights_func(gguf.ExpertWeightsFuncType.SOFTMAX)
+        else:
+            raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}")
+
         self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
 
         if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
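These hyperparameters describe DeepSeek-V3's router: gate logits are squashed with a sigmoid rather than a softmax, a learned per-expert bias (`e_score_correction`, stored below as `expert_weights_b`) is added only when choosing the top-k experts, the surviving weights are renormalized when `norm_topk_prob` is set, and the result is multiplied by `routed_scaling_factor`. A condensed sketch of that selection, based on the DeepSeek-V3 reference implementation (group-limited routing omitted; this is not code from this diff):

    import numpy as np

    def route(logits: np.ndarray, bias: np.ndarray, top_k: int,
              routed_scaling_factor: float, norm_topk_prob: bool = True):
        scores = 1.0 / (1.0 + np.exp(-logits))       # scoring_func == "sigmoid"
        chosen = np.argsort(scores + bias)[-top_k:]  # bias steers selection only...
        weights = scores[chosen]                     # ...not the mixing weights
        if norm_topk_prob:                           # -> expert_weights_norm
            weights = weights / weights.sum()
        return chosen, weights * routed_scaling_factor  # -> expert_weights_scale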
@@ -3864,6 +3877,16 @@ class DeepseekV2Model(Model):
     _experts: list[dict[str, Tensor]] | None = None
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # rename e_score_correction_bias tensors
+        if name.endswith("e_score_correction_bias"):
+            name = name.replace("e_score_correction_bias", "e_score_correction.bias")
+
+        # skip Multi-Token Prediction (MTP) layers
+        block_count = self.hparams["num_hidden_layers"]
+        match = re.match(r"model.layers.(\d+)", name)
+        if match and int(match.group(1)) >= block_count:
+            return []
+
         # process the experts separately
         if name.find("mlp.experts") != -1:
             n_experts = self.hparams["n_routed_experts"]
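The MTP check works off layer indices: DeepSeek-V3's config reports `num_hidden_layers = 61`, and its single Multi-Token Prediction head is stored one past the last regular block as `model.layers.61.*`, so any index `>= block_count` can be dropped. A quick illustration (tensor names are hypothetical):

    import re

    block_count = 61  # num_hidden_layers from DeepSeek-V3's config.json
    for name in ("model.layers.60.mlp.gate.weight",       # regular block
                 "model.layers.61.embed_tokens.weight"):  # MTP layer
        match = re.match(r"model.layers.(\d+)", name)
        print(name, "->", "skip" if match and int(match.group(1)) >= block_count else "keep")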
convert_hf_to_gguf_update.py

@@ -107,6 +107,7 @@ models = [
     {"name": "roberta-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sentence-transformers/stsb-roberta-base"},
     {"name": "gigachat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct"},
     {"name": "megrez", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Infinigence/Megrez-3B-Instruct"},
+    {"name": "deepseek-v3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3"},
 ]
 
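Listing the repo here is what keeps the hash chain maintainable: run with a Hugging Face token (`python3 convert_hf_to_gguf_update.py <hf_token>`), the update script downloads each listed tokenizer, recomputes the hashes, and regenerates the `chkhsh` if-chain in `convert_hf_to_gguf.py` — which is where the `deepseek-v3` branch above comes from.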
gguf-py/gguf/constants.py

@@ -102,6 +102,8 @@ class Keys:
         EXPERT_USED_COUNT        = "{arch}.expert_used_count"
         EXPERT_SHARED_COUNT      = "{arch}.expert_shared_count"
         EXPERT_WEIGHTS_SCALE     = "{arch}.expert_weights_scale"
+        EXPERT_WEIGHTS_NORM      = "{arch}.expert_weights_norm"
+        EXPERT_WEIGHTS_FUNC      = "{arch}.expert_weights_func"
         POOLING_TYPE             = "{arch}.pooling_type"
         LOGIT_SCALE              = "{arch}.logit_scale"
         DECODER_START_TOKEN_ID   = "{arch}.decoder_start_token_id"
@@ -312,6 +314,7 @@ class MODEL_TENSOR(IntEnum):
     FFN_GATE_SHEXP       = auto()
     FFN_DOWN_SHEXP       = auto()
     FFN_UP_SHEXP         = auto()
+    FFN_EXPERT_WEIGHTS_B = auto()
     ATTN_Q_NORM          = auto()
     ATTN_K_NORM          = auto()
     LAYER_OUT_NORM       = auto()
@@ -496,6 +499,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.FFN_GATE_EXP:         "blk.{bid}.ffn_gate_exps",
     MODEL_TENSOR.FFN_DOWN_EXP:         "blk.{bid}.ffn_down_exps",
     MODEL_TENSOR.FFN_UP_EXP:           "blk.{bid}.ffn_up_exps",
+    MODEL_TENSOR.FFN_EXPERT_WEIGHTS_B: "blk.{bid}.expert_weights_b",
     MODEL_TENSOR.LAYER_OUT_NORM:       "blk.{bid}.layer_output_norm",
     MODEL_TENSOR.SSM_IN:               "blk.{bid}.ssm_in",
     MODEL_TENSOR.SSM_CONV1D:           "blk.{bid}.ssm_conv1d",
@@ -1276,6 +1280,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_GATE_SHEXP,
         MODEL_TENSOR.FFN_DOWN_SHEXP,
         MODEL_TENSOR.FFN_UP_SHEXP,
+        MODEL_TENSOR.FFN_EXPERT_WEIGHTS_B,
     ],
     MODEL_ARCH.CHATGLM : [
         MODEL_TENSOR.TOKEN_EMBD,
@@ -1576,6 +1581,11 @@ class GGMLQuantizationType(IntEnum):
     TQ2_0   = 35
 
 
+class ExpertWeightsFuncType(IntEnum):
+    SOFTMAX = 1
+    SIGMOID = 2
+
+
 # TODO: add GGMLFileType from ggml_ftype in ggml.h
 
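The enum values are what actually land in the file: `add_expert_weights_func` below stores them as a plain `uint32`, so a reader maps the integer back through the enum. A tiny round-trip sketch (assuming the in-tree gguf-py package from this commit):

    from gguf.constants import ExpertWeightsFuncType

    func = ExpertWeightsFuncType(2)  # value as read back from GGUF metadata
    assert func is ExpertWeightsFuncType.SIGMOID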
gguf-py/gguf/gguf_writer.py

@@ -26,6 +26,7 @@ from .constants import (
     RopeScalingType,
     PoolingType,
     TokenType,
+    ExpertWeightsFuncType,
 )
 
 from .quants import quant_shape_from_byte_shape
@@ -715,6 +716,12 @@ class GGUFWriter:
     def add_expert_weights_scale(self, value: float) -> None:
         self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
 
+    def add_expert_weights_norm(self, value: bool) -> None:
+        self.add_bool(Keys.LLM.EXPERT_WEIGHTS_NORM.format(arch=self.arch), value)
+
+    def add_expert_weights_func(self, value: ExpertWeightsFuncType) -> None:
+        self.add_uint32(Keys.LLM.EXPERT_WEIGHTS_FUNC.format(arch=self.arch), value.value)
+
     def add_swin_norm(self, value: bool) -> None:
         self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)
 
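Taken together, a DeepSeek-V3 conversion emits the new metadata roughly like this — a minimal sketch using the in-tree gguf-py, with an illustrative file name and values, and with header/tensor writing omitted:

    import gguf

    writer = gguf.GGUFWriter("deepseek-v3.gguf", "deepseek2")
    writer.add_expert_weights_norm(True)                                # norm_topk_prob
    writer.add_expert_weights_func(gguf.ExpertWeightsFuncType.SIGMOID)  # scoring_func
    writer.add_expert_weights_scale(2.5)                                # routed_scaling_factor
    # a real conversion adds the remaining metadata and tensors before writing the file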
gguf-py/gguf/tensor_mapping.py

@@ -276,6 +276,10 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.shared_expert_gate", # qwen2moe
         ),
 
+        MODEL_TENSOR.FFN_EXPERT_WEIGHTS_B: (
+            "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3
+        ),
+
         # Feed-forward up
         MODEL_TENSOR.FFN_UP: (
             "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
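This mapping closes the loop with the `modify_tensors` rename above: `TensorNameMap` matches base names and re-attaches known suffixes such as `.weight` and `.bias`, which is why the fused HF name ending in `_bias` had to become a `.bias` suffix first. A sketch of the resolution (block count and layer index are hypothetical; names reflect the tree as of this commit):

    from gguf.constants import MODEL_ARCH
    from gguf.tensor_mapping import get_tensor_name_map

    name_map = get_tensor_name_map(MODEL_ARCH.DEEPSEEK2, 61)
    print(name_map.get_name("model.layers.3.mlp.gate.e_score_correction.bias",
                            try_suffixes=(".weight", ".bias")))
    # -> blk.3.expert_weights_b.bias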