diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 4e6c0f60c..43e61b500 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -687,6 +687,9 @@ class Model: if chkhsh == "d4c8f286ea6b520b3d495c4455483cfa2302c0cfcd4be05d781b6a8a0a7cdaf1": # ref: https://huggingface.co/Infinigence/Megrez-3B-Instruct res = "megrez" + if chkhsh == "877081d19cf6996e2c4ff0e1236341e9b7bde288f5311a56a937f0afbbb3aeb5": + # ref: https://huggingface.co/deepseek-ai/DeepSeek-V3 + res = "deepseek-v3" if res is None: logger.warning("\n") @@ -3831,6 +3834,7 @@ class DeepseekModel(Model): @Model.register("DeepseekV2ForCausalLM") +@Model.register("DeepseekV3ForCausalLM") class DeepseekV2Model(Model): model_arch = gguf.MODEL_ARCH.DEEPSEEK2 @@ -3852,6 +3856,15 @@ class DeepseekV2Model(Model): self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"]) + self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"]) + + if hparams["scoring_func"] == "sigmoid": + self.gguf_writer.add_expert_weights_func(gguf.ExpertWeightsFuncType.SIGMOID) + elif hparams["scoring_func"] == "softmax": + self.gguf_writer.add_expert_weights_func(gguf.ExpertWeightsFuncType.SOFTMAX) + else: + raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}") + self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: @@ -3864,6 +3877,16 @@ class DeepseekV2Model(Model): _experts: list[dict[str, Tensor]] | None = None def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # rename e_score_correction_bias tensors + if name.endswith("e_score_correction_bias"): + name = name.replace("e_score_correction_bias", "e_score_correction.bias") + + # skip Multi-Token Prediction (MTP) layers + block_count = self.hparams["num_hidden_layers"] + match = re.match(r"model.layers.(\d+)", name) + if match and int(match.group(1)) >= block_count: + return [] + # process the experts separately if name.find("mlp.experts") != -1: n_experts = self.hparams["n_routed_experts"] diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index fea23ddb4..56edc64a7 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -107,6 +107,7 @@ models = [ {"name": "roberta-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sentence-transformers/stsb-roberta-base"}, {"name": "gigachat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct"}, {"name": "megrez", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Infinigence/Megrez-3B-Instruct"}, + {"name": "deepseek-v3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3"}, ] diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 273370370..13b77979b 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -102,6 +102,8 @@ class Keys: EXPERT_USED_COUNT = "{arch}.expert_used_count" EXPERT_SHARED_COUNT = "{arch}.expert_shared_count" EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale" + EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm" + EXPERT_WEIGHTS_FUNC = "{arch}.expert_weights_func" POOLING_TYPE = "{arch}.pooling_type" LOGIT_SCALE = "{arch}.logit_scale" DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" @@ -312,6 +314,7 @@ class MODEL_TENSOR(IntEnum): FFN_GATE_SHEXP = auto() FFN_DOWN_SHEXP = auto() FFN_UP_SHEXP = auto() + FFN_EXPERT_WEIGHTS_B = auto() ATTN_Q_NORM = auto() ATTN_K_NORM = auto() LAYER_OUT_NORM = auto() @@ -496,6 +499,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps", MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps", MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps", + MODEL_TENSOR.FFN_EXPERT_WEIGHTS_B: "blk.{bid}.expert_weights_b", MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in", MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d", @@ -1276,6 +1280,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_GATE_SHEXP, MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_EXPERT_WEIGHTS_B, ], MODEL_ARCH.CHATGLM : [ MODEL_TENSOR.TOKEN_EMBD, @@ -1576,6 +1581,11 @@ class GGMLQuantizationType(IntEnum): TQ2_0 = 35 +class ExpertWeightsFuncType(IntEnum): + SOFTMAX = 1 + SIGMOID = 2 + + # TODO: add GGMLFileType from ggml_ftype in ggml.h diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 3023b539a..a0dadeaf8 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -26,6 +26,7 @@ from .constants import ( RopeScalingType, PoolingType, TokenType, + ExpertWeightsFuncType, ) from .quants import quant_shape_from_byte_shape @@ -715,6 +716,12 @@ class GGUFWriter: def add_expert_weights_scale(self, value: float) -> None: self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value) + def add_expert_weights_norm(self, value: bool) -> None: + self.add_bool(Keys.LLM.EXPERT_WEIGHTS_NORM.format(arch=self.arch), value) + + def add_expert_weights_func(self, value: ExpertWeightsFuncType) -> None: + self.add_uint32(Keys.LLM.EXPERT_WEIGHTS_FUNC.format(arch=self.arch), value.value) + def add_swin_norm(self, value: bool) -> None: self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 7009a11d4..f48769ad1 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -276,6 +276,10 @@ class TensorNameMap: "model.layers.{bid}.mlp.shared_expert_gate", # qwen2moe ), + MODEL_TENSOR.FFN_EXPERT_WEIGHTS_B: ( + "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 + ), + # Feed-forward up MODEL_TENSOR.FFN_UP: ( "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox