Added initial support for DeepseekV2ForCausalLM.
This commit is contained in:
parent
e1b40ac3b9
commit
c8c353f88d
4 changed files with 482 additions and 5 deletions
|
@ -2389,6 +2389,80 @@ class JinaBertV2Model(BertModel):
|
|||
self.gguf_writer.add_add_eos_token(True)
|
||||
|
||||
|
||||
@Model.register("DeepseekV2ForCausalLM")
|
||||
class DeepseekV2Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_gpt2()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
hparams = self.hparams
|
||||
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
|
||||
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
|
||||
|
||||
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
|
||||
if self.hparams["rope_scaling"].get("type") == "yarn":
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
|
||||
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])
|
||||
|
||||
self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
|
||||
self.gguf_writer.add_value_length(hparams["v_head_dim"])
|
||||
self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
|
||||
|
||||
_experts: list[dict[str, Tensor]] | None = None
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
n_head = self.hparams["num_attention_heads"]
|
||||
n_kv_head = self.hparams.get("num_key_value_heads")
|
||||
|
||||
# process the experts separately
|
||||
if name.find("mlp.experts") != -1:
|
||||
n_experts = self.hparams["n_routed_experts"]
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
self._experts = [{} for _ in range(self.block_count)]
|
||||
|
||||
self._experts[bid][name] = data_torch
|
||||
|
||||
if len(self._experts[bid]) >= n_experts * 3:
|
||||
tensors: list[tuple[str, Tensor]] = []
|
||||
|
||||
# merge the experts into a single 3d tensor
|
||||
for w_name in ["down_proj", "gate_proj", "up_proj"]:
|
||||
datas: list[Tensor] = []
|
||||
|
||||
for xid in range(n_experts):
|
||||
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
|
||||
datas.append(self._experts[bid][ename])
|
||||
del self._experts[bid][ename]
|
||||
|
||||
data_torch = torch.stack(datas, dim=0)
|
||||
|
||||
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
|
||||
|
||||
new_name = self.map_tensor_name(merged_name)
|
||||
|
||||
tensors.append((new_name, data_torch))
|
||||
return tensors
|
||||
else:
|
||||
return []
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
def write_tensors(self):
|
||||
super().write_tensors()
|
||||
|
||||
if self._experts is not None:
|
||||
# flatten `list[dict[str, Tensor]]` into `list[str]`
|
||||
experts = [k for d in self._experts for k in d.keys()]
|
||||
if len(experts) > 0:
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
###### CONVERSION LOGIC ######
|
||||
|
||||
|
||||
|
|
|
@ -139,6 +139,7 @@ class MODEL_ARCH(IntEnum):
|
|||
COMMAND_R = auto()
|
||||
DBRX = auto()
|
||||
OLMO = auto()
|
||||
DEEPSEEK2 = auto()
|
||||
|
||||
|
||||
class MODEL_TENSOR(IntEnum):
|
||||
|
@ -181,6 +182,12 @@ class MODEL_TENSOR(IntEnum):
|
|||
SSM_A = auto()
|
||||
SSM_D = auto()
|
||||
SSM_OUT = auto()
|
||||
ATTN_Q_A = auto()
|
||||
ATTN_Q_B = auto()
|
||||
ATTN_KV_A_MQA = auto()
|
||||
ATTN_KV_B = auto()
|
||||
ATTN_Q_A_NORM = auto()
|
||||
ATTN_KV_A_NORM = auto()
|
||||
|
||||
|
||||
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
|
@ -217,6 +224,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||
MODEL_ARCH.COMMAND_R: "command-r",
|
||||
MODEL_ARCH.DBRX: "dbrx",
|
||||
MODEL_ARCH.OLMO: "olmo",
|
||||
MODEL_ARCH.DEEPSEEK2: "deepseek2",
|
||||
}
|
||||
|
||||
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
|
@ -259,6 +267,12 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
|||
MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
|
||||
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
|
||||
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
|
||||
MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a",
|
||||
MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
|
||||
MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
|
||||
MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
|
||||
MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
|
||||
MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
|
||||
}
|
||||
|
||||
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
|
@ -743,6 +757,32 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.DEEPSEEK2: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q_A,
|
||||
MODEL_TENSOR.ATTN_Q_B,
|
||||
MODEL_TENSOR.ATTN_KV_A_MQA,
|
||||
MODEL_TENSOR.ATTN_KV_B,
|
||||
MODEL_TENSOR.ATTN_Q_A_NORM,
|
||||
MODEL_TENSOR.ATTN_KV_A_NORM,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP,
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
],
|
||||
# TODO
|
||||
}
|
||||
|
||||
|
@ -779,6 +819,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
],
|
||||
MODEL_ARCH.DEEPSEEK2: [
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
],
|
||||
}
|
||||
|
||||
#
|
||||
|
|
|
@ -255,6 +255,7 @@ class TensorNameMap:
|
|||
|
||||
MODEL_TENSOR.FFN_UP_SHEXP: (
|
||||
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek2
|
||||
),
|
||||
|
||||
# AWQ-activation gate
|
||||
|
@ -283,6 +284,7 @@ class TensorNameMap:
|
|||
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP: (
|
||||
"model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek2
|
||||
),
|
||||
|
||||
# Feed-forward down
|
||||
|
@ -317,6 +319,7 @@ class TensorNameMap:
|
|||
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP: (
|
||||
"model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_NORM: (
|
||||
|
@ -380,6 +383,30 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.out_proj",
|
||||
"backbone.layers.{bid}.mixer.out_proj",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_A: (
|
||||
"model.layers.{bid}.self_attn.q_a_proj", # deepseek2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_B: (
|
||||
"model.layers.{bid}.self_attn.q_b_proj", # deepseek2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_KV_A_MQA: (
|
||||
"model.layers.{bid}.self_attn.kv_a_proj_with_mqa", # deepseek2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_KV_B: (
|
||||
"model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_A_NORM: (
|
||||
"model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_KV_A_NORM: (
|
||||
"model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2
|
||||
),
|
||||
}
|
||||
|
||||
mapping: dict[str, tuple[MODEL_TENSOR, str]]
|
||||
|
@ -398,7 +425,7 @@ class TensorNameMap:
|
|||
if tensor not in MODEL_TENSORS[arch]:
|
||||
continue
|
||||
# TODO: make this configurable
|
||||
n_experts = 60
|
||||
n_experts = 160
|
||||
for xid in range(n_experts):
|
||||
tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid)
|
||||
self.mapping[tensor_name] = (tensor, tensor_name)
|
||||
|
|
340
llama.cpp
340
llama.cpp
|
@ -110,7 +110,7 @@
|
|||
#endif
|
||||
|
||||
#define LLAMA_MAX_NODES 8192
|
||||
#define LLAMA_MAX_EXPERTS 60
|
||||
#define LLAMA_MAX_EXPERTS 160
|
||||
|
||||
//
|
||||
// logging
|
||||
|
@ -229,6 +229,7 @@ enum llm_arch {
|
|||
LLM_ARCH_COMMAND_R,
|
||||
LLM_ARCH_DBRX,
|
||||
LLM_ARCH_OLMO,
|
||||
LLM_ARCH_DEEPSEEK2,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
|
@ -266,6 +267,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|||
{ LLM_ARCH_COMMAND_R, "command-r" },
|
||||
{ LLM_ARCH_DBRX, "dbrx" },
|
||||
{ LLM_ARCH_OLMO, "olmo" },
|
||||
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
|
@ -476,6 +478,12 @@ enum llm_tensor {
|
|||
LLM_TENSOR_SSM_A,
|
||||
LLM_TENSOR_SSM_D,
|
||||
LLM_TENSOR_SSM_OUT,
|
||||
LLM_TENSOR_ATTN_Q_A,
|
||||
LLM_TENSOR_ATTN_Q_B,
|
||||
LLM_TENSOR_ATTN_KV_A_MQA,
|
||||
LLM_TENSOR_ATTN_KV_B,
|
||||
LLM_TENSOR_ATTN_Q_A_NORM,
|
||||
LLM_TENSOR_ATTN_KV_A_NORM,
|
||||
};
|
||||
|
||||
static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
|
||||
|
@ -1052,6 +1060,34 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
|||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_DEEPSEEK2,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" },
|
||||
{ LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" },
|
||||
{ LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
|
||||
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
|
||||
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
|
||||
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_UNKNOWN,
|
||||
{
|
||||
|
@ -1895,6 +1931,8 @@ struct llama_layer {
|
|||
struct ggml_tensor * attn_k_norm_b;
|
||||
struct ggml_tensor * attn_out_norm;
|
||||
struct ggml_tensor * attn_out_norm_b;
|
||||
struct ggml_tensor * attn_q_a_norm;
|
||||
struct ggml_tensor * attn_kv_a_norm;
|
||||
|
||||
// attention
|
||||
struct ggml_tensor * wq;
|
||||
|
@ -1902,6 +1940,10 @@ struct llama_layer {
|
|||
struct ggml_tensor * wv;
|
||||
struct ggml_tensor * wo;
|
||||
struct ggml_tensor * wqkv;
|
||||
struct ggml_tensor * wq_a;
|
||||
struct ggml_tensor * wq_b;
|
||||
struct ggml_tensor * wkv_a_mqa;
|
||||
struct ggml_tensor * wkv_b;
|
||||
|
||||
// attention bias
|
||||
struct ggml_tensor * bq;
|
||||
|
@ -4261,6 +4303,11 @@ static void llm_load_hparams(
|
|||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_DEEPSEEK2:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
model.type = e_model::MODEL_UNKNOWN;
|
||||
} break;
|
||||
default: (void)0;
|
||||
}
|
||||
|
||||
|
@ -4920,8 +4967,6 @@ static bool llm_load_tensors(
|
|||
throw std::runtime_error("model has expert layers but no expert layers are used");
|
||||
}
|
||||
|
||||
GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
|
||||
|
||||
ggml_context * ctx_input = ctx_map.at(model.buft_input.buft);
|
||||
ggml_context * ctx_output = ctx_map.at(model.buft_output.buft);
|
||||
ggml_context * ctx_output_split = ctx_map.at(model.buft_output.buft_matrix);
|
||||
|
@ -6060,6 +6105,67 @@ static bool llm_load_tensors(
|
|||
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_DEEPSEEK2:
|
||||
{
|
||||
// TODO maybe move some of these to hparams
|
||||
const uint32_t n_shared_experts = 2;
|
||||
const uint32_t moe_intermediate_size = 1536;
|
||||
const uint32_t q_lora_rank = 1536;
|
||||
const uint32_t kv_lora_rank = 512;
|
||||
const uint32_t first_k_dense_replace = 1;
|
||||
|
||||
// kept original names of these parameters from HF transformers code for clarity
|
||||
const uint32_t qk_rope_head_dim = hparams.n_rot;
|
||||
const uint32_t qk_nope_head_dim = hparams.n_embd_head_k - hparams.n_rot;
|
||||
|
||||
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||
|
||||
// output
|
||||
{
|
||||
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
||||
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||
ggml_context * ctx_split = ctx_for_layer_split(i);
|
||||
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||
layer.attn_q_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank});
|
||||
layer.attn_kv_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank});
|
||||
|
||||
layer.wq_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank});
|
||||
layer.wq_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.n_head * hparams.n_embd_head_k});
|
||||
layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + qk_rope_head_dim});
|
||||
layer.wkv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, hparams.n_head * (qk_nope_head_dim + hparams.n_embd_head_v)});
|
||||
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {hparams.n_head * hparams.n_embd_head_v, n_embd});
|
||||
|
||||
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||
|
||||
if ((uint32_t) i < first_k_dense_replace) {
|
||||
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
|
||||
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
|
||||
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
|
||||
} else {
|
||||
layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
|
||||
|
||||
GGML_ASSERT(hparams.n_expert > 0);
|
||||
GGML_ASSERT(hparams.n_expert_used > 0);
|
||||
|
||||
// MoE branch
|
||||
layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, moe_intermediate_size, n_expert});
|
||||
layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {moe_intermediate_size, n_embd, n_expert});
|
||||
layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, moe_intermediate_size, n_expert});
|
||||
|
||||
// Shared expert branch
|
||||
layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, moe_intermediate_size * n_shared_experts});
|
||||
layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { moe_intermediate_size * n_shared_experts, n_embd});
|
||||
layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, moe_intermediate_size * n_shared_experts});
|
||||
}
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
throw std::runtime_error("unknown architecture");
|
||||
}
|
||||
|
@ -6700,7 +6806,7 @@ static struct ggml_tensor * llm_build_kqv(
|
|||
struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
|
||||
cb(kqv_merged, "kqv_merged", il);
|
||||
|
||||
cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
|
||||
cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens);
|
||||
cb(cur, "kqv_merged_cont", il);
|
||||
}
|
||||
|
||||
|
@ -10779,6 +10885,227 @@ struct llm_build_context {
|
|||
|
||||
return gf;
|
||||
}
|
||||
|
||||
struct ggml_cgraph * build_deepseek2() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
||||
|
||||
// mutable variable, needed during the last layer of the computation to skip unused tokens
|
||||
int32_t n_tokens = this->n_tokens;
|
||||
|
||||
// TODO maybe move some of these to hparams
|
||||
const uint32_t first_k_dense_replace = 1;
|
||||
const uint32_t kv_lora_rank = 512;
|
||||
|
||||
// kept original names of these parameters from HF transformers code for clarity
|
||||
const uint32_t qk_rope_head_dim = hparams.n_rot;
|
||||
const uint32_t qk_nope_head_dim = hparams.n_embd_head_k - hparams.n_rot;
|
||||
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
// {n_embd, n_tokens}
|
||||
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
struct ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
|
||||
// norm
|
||||
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||
model.layers[il].attn_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self_attention
|
||||
{
|
||||
// {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens}
|
||||
struct ggml_tensor * q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
|
||||
cb(q, "q", il);
|
||||
|
||||
q = llm_build_norm(ctx0, q, hparams,
|
||||
model.layers[il].attn_q_a_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(q, "q", il);
|
||||
|
||||
// {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens}
|
||||
q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
|
||||
cb(q, "q", il);
|
||||
|
||||
// split into {n_head * qk_nope_head_dim, n_tokens}
|
||||
struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, qk_nope_head_dim, n_head, n_tokens, ggml_element_size(q) * hparams.n_embd_head_k, ggml_element_size(q) * hparams.n_embd_head_k * n_head, 0);
|
||||
// and {n_head * qk_rope_head_dim, n_tokens}
|
||||
struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, qk_rope_head_dim, n_head, n_tokens, ggml_element_size(q) * hparams.n_embd_head_k, ggml_element_size(q) * hparams.n_embd_head_k * n_head, ggml_element_size(q) * qk_nope_head_dim);
|
||||
|
||||
q_nope = ggml_cont(ctx0, q_nope);
|
||||
cb(q_nope, "q_nope", il);
|
||||
|
||||
q_pe = ggml_cont(ctx0, q_pe);
|
||||
cb(q_pe, "q_pe", il);
|
||||
|
||||
// {n_embd, kv_lora_rank + qk_rope_head_dim} * {n_embd, n_tokens} -> {kv_lora_rank + qk_rope_head_dim, n_tokens}
|
||||
struct ggml_tensor * compressed_kv_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
|
||||
cb(compressed_kv_pe, "compressed_kv_pe", il);
|
||||
|
||||
// split into {kv_lora_rank, n_tokens}
|
||||
struct ggml_tensor * compressed_kv = ggml_view_2d(ctx0, compressed_kv_pe, kv_lora_rank, n_tokens, compressed_kv_pe->nb[1], 0);
|
||||
// and {qk_rope_head_dim, n_tokens}
|
||||
struct ggml_tensor * k_pe = ggml_view_2d(ctx0, compressed_kv_pe, qk_rope_head_dim, n_tokens, compressed_kv_pe->nb[1], ggml_element_size(compressed_kv_pe)*kv_lora_rank);
|
||||
|
||||
k_pe = ggml_cont(ctx0, k_pe);
|
||||
cb(k_pe, "k_pe", il);
|
||||
|
||||
compressed_kv = llm_build_norm(ctx0, compressed_kv, hparams,
|
||||
model.layers[il].attn_kv_a_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(compressed_kv, "compressed_kv", il);
|
||||
|
||||
// {kv_lora_rank, n_head * (qk_nope_head_dim + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (qk_nope_head_dim + n_embd_head_v), n_tokens}
|
||||
struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, compressed_kv);
|
||||
cb(kv, "kv", il);
|
||||
|
||||
// split into {n_head * qk_nope_head_dim, n_tokens}
|
||||
struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, qk_nope_head_dim, n_head, n_tokens, ggml_element_size(kv) * (qk_nope_head_dim + hparams.n_embd_head_v), ggml_element_size(kv) * n_head * (qk_nope_head_dim + hparams.n_embd_head_v), 0);
|
||||
// and {n_head * n_embd_head_v, n_tokens}
|
||||
struct ggml_tensor * value_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, ggml_element_size(kv) * (qk_nope_head_dim + hparams.n_embd_head_v), ggml_element_size(kv) * n_head * (qk_nope_head_dim + hparams.n_embd_head_v), ggml_element_size(kv) * qk_nope_head_dim);
|
||||
|
||||
value_states = ggml_dup(ctx0, value_states);
|
||||
cb(value_states, "value_states", il);
|
||||
|
||||
value_states = ggml_reshape_2d(ctx0, value_states, hparams.n_embd_head_v * n_head, n_tokens);
|
||||
|
||||
q_pe = ggml_rope_custom(
|
||||
ctx0, ggml_reshape_3d(ctx0, q_pe, qk_rope_head_dim, n_head, n_tokens), inp_pos,
|
||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(q_pe, "q_pe", il);
|
||||
|
||||
// shared RoPE key
|
||||
k_pe = ggml_rope_custom(
|
||||
ctx0, ggml_reshape_3d(ctx0, k_pe, qk_rope_head_dim, 1, n_tokens), inp_pos,
|
||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(k_pe, "k_pe", il);
|
||||
|
||||
struct ggml_tensor * query_states = ggml_new_tensor_3d(ctx0, q_nope->type, hparams.n_embd_head_k, n_head, n_tokens);
|
||||
cb(query_states, "query_states", il);
|
||||
query_states = ggml_set_inplace(ctx0, query_states, q_nope, query_states->nb[1], query_states->nb[2], query_states->nb[3], 0);
|
||||
query_states = ggml_set_inplace(ctx0, query_states, q_pe, query_states->nb[1], query_states->nb[2], query_states->nb[3], ggml_element_size(query_states) * qk_nope_head_dim);
|
||||
|
||||
k_pe = ggml_repeat(ctx0, k_pe, q_pe);
|
||||
cb(k_pe, "k_pe", il);
|
||||
|
||||
struct ggml_tensor * key_states = ggml_new_tensor_3d(ctx0, q_nope->type, hparams.n_embd_head_k, n_head, n_tokens);
|
||||
cb(key_states, "key_states", il);
|
||||
key_states = ggml_set_inplace(ctx0, key_states, k_nope, key_states->nb[1], key_states->nb[2], key_states->nb[3], 0);
|
||||
key_states = ggml_set_inplace(ctx0, key_states, k_pe, key_states->nb[1], key_states->nb[2], key_states->nb[3], ggml_element_size(key_states) * qk_nope_head_dim);
|
||||
|
||||
// TODO see if we can avoid these operations by permuting
|
||||
// rows/columns of some model tensors during model conversion
|
||||
query_states = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, query_states, 2, hparams.n_embd_head_k / 2, n_head, n_tokens)));
|
||||
cb(query_states, "query_states", il);
|
||||
|
||||
query_states = ggml_reshape_3d(ctx0, query_states, hparams.n_embd_head_k, n_head, n_tokens);
|
||||
cb(query_states, "query_states", il);
|
||||
|
||||
key_states = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, key_states, 2, hparams.n_embd_head_k / 2, n_head, n_tokens)));
|
||||
cb(key_states, "key_states", il);
|
||||
|
||||
key_states = ggml_reshape_3d(ctx0, key_states, hparams.n_embd_head_k, n_head, n_tokens);
|
||||
cb(key_states, "key_states", il);
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
key_states, value_states, query_states, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(hparams.n_embd_head_k)), cb, il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
n_tokens = n_outputs;
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
if ((uint32_t) il < first_k_dense_replace) {
|
||||
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = llm_build_ffn(ctx0, cur,
|
||||
model.layers[il].ffn_up, NULL,
|
||||
model.layers[il].ffn_gate, NULL,
|
||||
model.layers[il].ffn_down, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
} else {
|
||||
// MoE branch
|
||||
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
ggml_tensor * moe_out =
|
||||
llm_build_moe_ffn(ctx0, cur,
|
||||
model.layers[il].ffn_gate_inp,
|
||||
model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps,
|
||||
model.layers[il].ffn_down_exps,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
cb, il);
|
||||
cb(moe_out, "ffn_moe_out", il);
|
||||
|
||||
// FFN shared expert
|
||||
{
|
||||
ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, cur,
|
||||
model.layers[il].ffn_up_shexp, NULL,
|
||||
model.layers[il].ffn_gate_shexp, NULL,
|
||||
model.layers[il].ffn_down_shexp, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||
cb(ffn_shexp, "ffn_shexp", il);
|
||||
|
||||
moe_out = ggml_add(ctx0, moe_out, ffn_shexp);
|
||||
cb(moe_out, "ffn_out", il);
|
||||
|
||||
cur = moe_out;
|
||||
}
|
||||
}
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
model.output_norm, NULL,
|
||||
LLM_NORM_RMS, cb, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
// lm_head
|
||||
cur = ggml_mul_mat(ctx0, model.output, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
return gf;
|
||||
}
|
||||
};
|
||||
|
||||
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
|
||||
|
@ -10993,6 +11320,10 @@ static struct ggml_cgraph * llama_build_graph(
|
|||
{
|
||||
result = llm.build_olmo();
|
||||
} break;
|
||||
case LLM_ARCH_DEEPSEEK2:
|
||||
{
|
||||
result = llm.build_deepseek2();
|
||||
} break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
|
@ -16008,6 +16339,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
|||
case LLM_ARCH_XVERSE:
|
||||
case LLM_ARCH_COMMAND_R:
|
||||
case LLM_ARCH_OLMO:
|
||||
case LLM_ARCH_DEEPSEEK2:
|
||||
return LLAMA_ROPE_TYPE_NORM;
|
||||
|
||||
// the pairs of head values are offset by n_rot/2
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue