llama: dbrx: rename tensor to actual meaning. Fix normalization in graph. Permute expert tensors to the llama.cpp layout
This commit is contained in:
parent 8e22688401
commit 35dce3e145

4 changed files with 42 additions and 41 deletions
convert-hf-to-gguf.py
@@ -1523,18 +1523,21 @@ class DbrxModel(Model):
         n_embd = self.hparams["d_model"]
 
         # Specific behavior for experts tensors: reshape to 3D and add suffix .weight
-        exp_tensor_names = {"ffn.experts.mlp.v1": (n_embd, n_ff, n_expert),  # LLM_TENSOR_FFN_GATE_EXPS
-                            "ffn.experts.mlp.w1": (n_embd, n_ff, n_expert),  # LLM_TENSOR_FFN_DOWN_EXPS
-                            "ffn.experts.mlp.w2": (n_ff, n_embd, n_expert)}  # LLM_TENSOR_FFN_UP_EXPS
+        exp_tensor_names = {"ffn.experts.mlp.v1": (2, 1, 3),  # LLM_TENSOR_FFN_GATE_EXPS(n_embd, n_ff, n_expert)
+                            "ffn.experts.mlp.w2": (1, 2, 3),  # LLM_TENSOR_FFN_DOWN_EXPS(n_ff, n_embd, n_expert)
+                            "ffn.experts.mlp.w1": (2, 1, 3)}  # LLM_TENSOR_FFN_UP_EXPS(n_embd, n_ff, n_expert)
         experts = False
         for exp_tensor_name in exp_tensor_names.keys():
             if name.find(exp_tensor_name) != -1 and name.find(".weight") == -1:
                 experts = True
-                expert_reshape = exp_tensor_names[exp_tensor_name][::-1]
+                expert_permute = exp_tensor_names[exp_tensor_name][::-1]
                 break
 
         old_dtype = data_torch.dtype
 
+        if experts:
+            data_torch = data_torch.view(n_expert, n_ff, n_embd)
+
         # convert any unsupported data types to float32
         if data_torch.dtype not in (torch.float16, torch.float32):
             data_torch = data_torch.to(torch.float32)
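A note on the new branch: DBRX ships each layer's experts fused into a single 2-D matrix, so the converter now gives every expert tensor an explicit 3-D view, one slice per expert, before any dtype handling. A minimal sketch of that step with toy sizes (the dimension names match the script; the values and the fused input shape are assumptions for illustration):

import torch

# toy sizes standing in for DBRX's real d_model / ffn dimensions
n_expert, n_ff, n_embd = 4, 8, 6

# assumption: the HF checkpoint fuses the experts along the first dimension
data_torch = torch.randn(n_expert * n_ff, n_embd)

# the converter's step: a copy-free 3-D view, one (n_ff, n_embd) slice per expert
data_torch = data_torch.view(n_expert, n_ff, n_embd)
print(data_torch.shape)  # torch.Size([4, 8, 6])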
@@ -1557,7 +1560,7 @@ class DbrxModel(Model):
 
         # Reshape experts tensors from 2D to 3D as expected by GeLU
         if experts and n_dims == 2:
-            data = data.reshape(expert_reshape)
+            data = data.transpose(expert_permute)
             n_dims = len(data.shape)
 
         # if f32 desired, convert any float16 to float32
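The dict values are axis orders, and the comments give each target shape in ggml's ne order (fastest dimension first); read back-to-front, ne = (n_ff, n_embd, n_expert) for ffn_down_exps is a row-major tensor of shape (n_expert, n_embd, n_ff). A hedged sketch of the down-projection permute under that reading, using 0-based axes as torch expects (an illustration of the intended layout change, not the script verbatim):

import torch

n_expert, n_ff, n_embd = 4, 8, 6  # toy sizes

w2 = torch.randn(n_expert, n_ff, n_embd)  # the 3-D view from the previous hunk

# target ne = (n_ff, n_embd, n_expert) means row-major (n_expert, n_embd, n_ff),
# i.e. swap the last two axes and make the result contiguous for serialization
w2_ggml = w2.permute(0, 2, 1).contiguous()
print(w2_ggml.shape)  # torch.Size([4, 6, 8])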
gguf-py/gguf/constants.py
@@ -646,9 +646,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
         MODEL_TENSOR.OUTPUT,
-        MODEL_TENSOR.ATTN_QKV,
         MODEL_TENSOR.ATTN_NORM,
-        MODEL_TENSOR.ATTN_NORM_2,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
         MODEL_TENSOR.FFN_GATE_INP,
         MODEL_TENSOR.FFN_GATE_EXP,
         MODEL_TENSOR.FFN_DOWN_EXP,
gguf-py/gguf/tensor_mapping.py
@@ -102,7 +102,6 @@ class TensorNameMap:
         # Attention norm 2
         MODEL_TENSOR.ATTN_NORM_2: (
             "transformer.h.{bid}.ln_attn",                     # falcon40b
-            "transformer.blocks.{bid}.norm_attn_norm.norm_2",  # dbrx
         ),
 
         # Attention query-key-value
@@ -171,7 +170,8 @@ class TensorNameMap:
             "model.layers.layers.{bid}.self_attn.o_proj",                   # plamo
             "model.layers.{bid}.attention.wo",                              # internlm2
             "encoder.layers.{bid}.attn.out_proj",                           # nomic-bert
-            "transformer.decoder_layer.{bid}.multi_head_attention.linear"   # Grok
+            "transformer.decoder_layer.{bid}.multi_head_attention.linear",  # Grok
+            "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj",        # dbrx
         ),
 
         # Attention output norm
@@ -309,7 +309,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.output.LayerNorm",                   # bert
             "encoder.layers.{bid}.norm2",                             # nomic-bert
             "transformer.decoder_layer.{bid}.rms_norm_3",             # Grok
-            "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj",  # dbrx
+            "transformer.blocks.{bid}.norm_attn_norm.norm_2",         # dbrx
         ),
 
         MODEL_TENSOR.SSM_IN: (
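Net effect of the three mapping hunks: the two DBRX tensors inside norm_attn_norm now land on GGUF names that match what they actually are; the attention output projection maps to attn_output, and norm_2 maps to layer_out_norm. An illustrative summary of the resulting renames (a plain dict for demonstration, not the gguf-py API):

# illustration only: the renames implied by the updated TensorNameMap entries
dbrx_renames = {
    # out_proj was treated as a norm (layer_out_norm); it is the attention output projection
    "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj": "blk.{bid}.attn_output",
    # norm_2 was treated as attn_norm_2; it is the post-attention norm (layer_out_norm)
    "transformer.blocks.{bid}.norm_attn_norm.norm_2": "blk.{bid}.layer_out_norm",
}
for hf_name, gguf_name in dbrx_renames.items():
    print(hf_name.format(bid=0), "->", gguf_name.format(bid=0))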
llama.cpp
@@ -938,7 +938,7 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
     { LLM_TENSOR_OUTPUT,        "output" },
     { LLM_TENSOR_ATTN_QKV,      "blk.%d.attn_qkv" },
     { LLM_TENSOR_ATTN_NORM,     "blk.%d.attn_norm" },
-    { LLM_TENSOR_ATTN_NORM_2,   "blk.%d.attn_norm_2" },
+    { LLM_TENSOR_ATTN_OUT,      "blk.%d.attn_output" },
     { LLM_TENSOR_FFN_GATE_INP,  "blk.%d.ffn_gate_inp" },
     { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
     { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
@@ -4687,16 +4687,16 @@ static bool llm_load_tensors(
                 auto & layer = model.layers[i];
 
                 layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                layer.attn_norm_2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd});
 
                 layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
+                layer.wo   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
 
                 layer.ffn_gate_inp  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert});
                 layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert});
                 layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert});
                 layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff, n_expert});
 
-                layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd, n_embd});
+                layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
             }
         } break;
     case LLM_ARCH_BAICHUAN:
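Two shape details worth spelling out: the dimension lists passed to create_tensor are ggml ne arrays with ne[0] fastest-moving, so {n_embd, n_ff, n_expert} is exactly the row-major (n_expert, n_ff, n_embd) view the converter produces; and layer_out_norm shrinks from {n_embd, n_embd} to {n_embd} because a LayerNorm weight is a per-channel vector, not a matrix. A quick torch check of both facts, with toy sizes:

import torch

n_expert, n_ff, n_embd = 4, 8, 6  # toy sizes

gate = torch.randn(n_expert, n_ff, n_embd)  # converter-side expert layout
print(tuple(reversed(gate.shape)))          # (6, 8, 4) == ggml ne {n_embd, n_ff, n_expert}

ln = torch.nn.LayerNorm(n_embd)
print(tuple(ln.weight.shape))               # (6,): one weight per channel, hence {n_embd}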
@@ -7132,7 +7132,6 @@ struct llm_build_context {
                 struct ggml_tensor * Vcur = nullptr;
 
                 cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
-                cur = ggml_norm(ctx0, cur, hparams.f_norm_eps);
                 cb(cur, "wqkv", il);
 
                 cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
@@ -7161,10 +7160,9 @@ struct llm_build_context {
                 cb(Kcur, "Kcur", il);
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
-                        model.layers[il].layer_out_norm, model.layers[il].bo,
+                        model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-
-                cur = ggml_norm(ctx0, cur, hparams.f_norm_eps);
             }
 
             if (il == n_layer - 1) {
@@ -7181,11 +7179,6 @@ struct llm_build_context {
                 // feed-forward network
                 // MoE branch
                 {
-                    cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                            model.layers[il].attn_norm_2, NULL,
-                            LLM_NORM, cb, il);
-                    cb(cur, "ffn_norm", il);
-
                     ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
                     cb(logits, "ffn_moe_logits", il);
@@ -7243,6 +7236,11 @@ struct llm_build_context {
                     cur = moe_out;
                 }
 
+                cur = llm_build_norm(ctx0, cur, hparams,
+                        model.layers[il].layer_out_norm, NULL,
+                        LLM_NORM, cb, il);
+                cb(cur, "layer_out_norm", il);
+
                 cur = ggml_add(ctx0, cur, ffn_inp);
                 cb(cur, "ffn_out", il);
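Taken together, the graph hunks replace every anonymous ggml_norm with a named weight: the bare norms after the QKV matmul and after attention are gone, the attention output now flows through wo (with no fused norm and no bias), and norm_2, loaded as layer_out_norm, is applied once between the MoE output and the residual add. A minimal torch sketch of the per-layer dataflow under one plausible reading of the diff (the residual wiring outside these hunks is not shown; it is assumed here that ffn_inp is the post-attention residual and the experts consume it, and attn/moe are identity stand-ins, not llama.cpp API):

import torch

def dbrx_block_sketch(x, attn_norm, layer_out_norm, attn, moe):
    # attention half: attn_norm -> fused QKV / attention / wo, then residual;
    # no extra ggml_norm around the QKV product any more
    h = x + attn(attn_norm(x))
    # MoE half: experts run on h (the old attn_norm_2 pre-FFN norm is removed);
    # layer_out_norm is applied to the expert output before the residual add
    return h + layer_out_norm(moe(h))

# toy usage
n_embd = 6
x = torch.randn(2, n_embd)
out = dbrx_block_sketch(x, torch.nn.LayerNorm(n_embd), torch.nn.LayerNorm(n_embd),
                        attn=lambda t: t, moe=lambda t: t)
print(out.shape)  # torch.Size([2, 6])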