diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 156ac0c53..cb1f01549 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2401,7 +2401,16 @@ class DeepseekV2Model(Model): def set_gguf_parameters(self): super().set_gguf_parameters() hparams = self.hparams + + self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"]) + self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"]) + self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length(hparams["v_head_dim"]) + self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) + self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) + self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: @@ -2410,10 +2419,6 @@ class DeepseekV2Model(Model): self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"]) - self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) - self.gguf_writer.add_value_length(hparams["v_head_dim"]) - self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) - _experts: list[dict[str, Tensor]] | None = None def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 0a732022d..9b6e56847 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -37,11 +37,14 @@ class Keys: CONTEXT_LENGTH = "{arch}.context_length" EMBEDDING_LENGTH = "{arch}.embedding_length" BLOCK_COUNT = "{arch}.block_count" + LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count" FEED_FORWARD_LENGTH = "{arch}.feed_forward_length" + EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length" USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual" TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout" EXPERT_COUNT = "{arch}.expert_count" EXPERT_USED_COUNT = "{arch}.expert_used_count" + EXPERT_SHARED_COUNT = "{arch}.expert_shared_count" POOLING_TYPE = "{arch}.pooling_type" LOGIT_SCALE = "{arch}.logit_scale" @@ -55,6 +58,8 @@ class Keys: LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon" LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon" CAUSAL = "{arch}.attention.causal" + Q_LORA_RANK = "{arch}.attention.q_lora_rank" + KV_LORA_RANK = "{arch}.attention.kv_lora_rank" class Rope: DIMENSION_COUNT = "{arch}.rope.dimension_count" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index d5e323a52..e82f4e9ab 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -376,9 +376,15 @@ class GGUFWriter: def add_block_count(self, length: int) -> None: self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length) + def add_leading_dense_block_count(self, length: int) -> None: + self.add_uint32(Keys.LLM.LEADING_DENSE_BLOCK_COUNT.format(arch=self.arch), length) + def add_feed_forward_length(self, length: int) -> None: self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length) + def add_expert_feed_forward_length(self, length: int) -> None: + self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length) + def add_parallel_residual(self, use: bool) -> None: self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use) @@ -409,6 +415,9 @@ class GGUFWriter: def add_expert_used_count(self, count: int) -> None: self.add_uint32(Keys.LLM.EXPERT_USED_COUNT.format(arch=self.arch), count) + def add_expert_shared_count(self, count: int) -> None: + self.add_uint32(Keys.LLM.EXPERT_SHARED_COUNT.format(arch=self.arch), count) + def add_layer_norm_eps(self, value: float) -> None: self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value) @@ -418,6 +427,12 @@ class GGUFWriter: def add_causal_attention(self, value: bool) -> None: self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value) + def add_q_lora_rank(self, length: int) -> None: + self.add_uint32(Keys.Attention.Q_LORA_RANK.format(arch=self.arch), length) + + def add_kv_lora_rank(self, length: int) -> None: + self.add_uint32(Keys.Attention.KV_LORA_RANK.format(arch=self.arch), length) + def add_pooling_type(self, value: PoolingType) -> None: self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value) diff --git a/llama.cpp b/llama.cpp index 1da020b28..560fc7acf 100644 --- a/llama.cpp +++ b/llama.cpp @@ -288,11 +288,14 @@ enum llm_kv { LLM_KV_CONTEXT_LENGTH, LLM_KV_EMBEDDING_LENGTH, LLM_KV_BLOCK_COUNT, + LLM_KV_LEADING_DENSE_BLOCK_COUNT, LLM_KV_FEED_FORWARD_LENGTH, + LLM_KV_EXPERT_FEED_FORWARD_LENGTH, LLM_KV_USE_PARALLEL_RESIDUAL, LLM_KV_TENSOR_DATA_LAYOUT, LLM_KV_EXPERT_COUNT, LLM_KV_EXPERT_USED_COUNT, + LLM_KV_EXPERT_SHARED_COUNT, LLM_KV_POOLING_TYPE, LLM_KV_LOGIT_SCALE, @@ -305,6 +308,8 @@ enum llm_kv { LLM_KV_ATTENTION_LAYERNORM_EPS, LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, LLM_KV_ATTENTION_CAUSAL, + LLM_KV_ATTENTION_Q_LORA_RANK, + LLM_KV_ATTENTION_KV_LORA_RANK, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_FREQ_BASE, @@ -365,11 +370,14 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, { LLM_KV_BLOCK_COUNT, "%s.block_count" }, + { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" }, { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, + { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" }, { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, + { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" }, { LLM_KV_POOLING_TYPE , "%s.pooling_type" }, { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, @@ -382,6 +390,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" }, + { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" }, + { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, @@ -1803,6 +1813,12 @@ struct llama_hparams { uint32_t n_expert_used = 0; uint32_t n_vocab_type = 0; // for BERT-style token types + uint32_t n_leading_dense_layer = 0; + uint32_t n_lora_q = 0; + uint32_t n_lora_kv = 0; + uint32_t n_expert_ff = 0; + uint32_t n_expert_shared = 0; + float f_norm_eps; float f_norm_rms_eps; @@ -1842,6 +1858,12 @@ struct llama_hparams { if (this->n_expert != other.n_expert) return true; if (this->n_expert_used != other.n_expert_used) return true; + if (this->n_leading_dense_layer != other.n_leading_dense_layer) return true; + if (this->n_lora_q != other.n_lora_q) return true; + if (this->n_lora_kv != other.n_lora_kv) return true; + if (this->n_expert_ff != other.n_expert_ff) return true; + if (this->n_expert_shared != other.n_expert_shared) return true; + if (this->rope_finetuned != other.rope_finetuned) return true; if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true; @@ -4306,6 +4328,12 @@ static void llm_load_hparams( case LLM_ARCH_DEEPSEEK2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_leading_dense_layer); + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_expert_ff); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + model.type = e_model::MODEL_UNKNOWN; } break; default: (void)0; @@ -6107,16 +6135,12 @@ static bool llm_load_tensors( } break; case LLM_ARCH_DEEPSEEK2: { - // TODO maybe move some of these to hparams - const uint32_t n_shared_experts = 2; - const uint32_t moe_intermediate_size = 1536; - const uint32_t q_lora_rank = 1536; - const uint32_t kv_lora_rank = 512; - const uint32_t first_k_dense_replace = 1; - // kept original names of these parameters from HF transformers code for clarity const uint32_t qk_rope_head_dim = hparams.n_rot; const uint32_t qk_nope_head_dim = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t q_lora_rank = hparams.n_lora_q; + const uint32_t kv_lora_rank = hparams.n_lora_kv; + const uint32_t moe_intermediate_size = hparams.n_expert_ff; model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -6144,7 +6168,7 @@ static bool llm_load_tensors( layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - if ((uint32_t) i < first_k_dense_replace) { + if ((uint32_t) i < hparams.n_leading_dense_layer) { layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); @@ -6160,9 +6184,9 @@ static bool llm_load_tensors( layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, moe_intermediate_size, n_expert}); // Shared expert branch - layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, moe_intermediate_size * n_shared_experts}); - layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { moe_intermediate_size * n_shared_experts, n_embd}); - layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, moe_intermediate_size * n_shared_experts}); + layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, moe_intermediate_size * hparams.n_expert_shared}); + layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { moe_intermediate_size * hparams.n_expert_shared, n_embd}); + layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, moe_intermediate_size * hparams.n_expert_shared}); } } } break; @@ -10893,13 +10917,10 @@ struct llm_build_context { // mutable variable, needed during the last layer of the computation to skip unused tokens int32_t n_tokens = this->n_tokens; - // TODO maybe move some of these to hparams - const uint32_t first_k_dense_replace = 1; - const uint32_t kv_lora_rank = 512; - // kept original names of these parameters from HF transformers code for clarity const uint32_t qk_rope_head_dim = hparams.n_rot; const uint32_t qk_nope_head_dim = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; struct ggml_tensor * cur; struct ggml_tensor * inpL; @@ -11022,7 +11043,7 @@ struct llm_build_context { struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); - if ((uint32_t) il < first_k_dense_replace) { + if ((uint32_t) il < hparams.n_leading_dense_layer) { cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il);