llama: dbrx: rename tensor to actual meaning. Fix normalization in graph. Permute expert tensors to the llama.cpp layout
commit 35dce3e145
parent 8e22688401

4 changed files with 42 additions and 41 deletions
convert-hf-to-gguf.py

@@ -1523,18 +1523,21 @@ class DbrxModel(Model):
         n_embd = self.hparams["d_model"]
 
         # Specific behavior for experts tensors: reshape to 3D and add suffix .weight
-        exp_tensor_names = {"ffn.experts.mlp.v1": (n_embd, n_ff, n_expert),  # LLM_TENSOR_FFN_GATE_EXPS
-                            "ffn.experts.mlp.w1": (n_embd, n_ff, n_expert),  # LLM_TENSOR_FFN_DOWN_EXPS
-                            "ffn.experts.mlp.w2": (n_ff, n_embd, n_expert)}  # LLM_TENSOR_FFN_UP_EXPS
+        exp_tensor_names = {"ffn.experts.mlp.v1": (2, 1, 3),  # LLM_TENSOR_FFN_GATE_EXPS (n_embd, n_ff, n_expert)
+                            "ffn.experts.mlp.w2": (1, 2, 3),  # LLM_TENSOR_FFN_DOWN_EXPS (n_ff, n_embd, n_expert)
+                            "ffn.experts.mlp.w1": (2, 1, 3)}  # LLM_TENSOR_FFN_UP_EXPS   (n_embd, n_ff, n_expert)
         experts = False
         for exp_tensor_name in exp_tensor_names.keys():
             if name.find(exp_tensor_name) != -1 and name.find(".weight") == -1:
                 experts = True
-                expert_reshape = exp_tensor_names[exp_tensor_name][::-1]
+                expert_permute = exp_tensor_names[exp_tensor_name][::-1]
                 break
 
         old_dtype = data_torch.dtype
 
+        if experts:
+            data_torch = data_torch.view(n_expert, n_ff, n_embd)
+
         # convert any unsupported data types to float32
         if data_torch.dtype not in (torch.float16, torch.float32):
            data_torch = data_torch.to(torch.float32)
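For context, a minimal sketch (illustrative sizes and variable names, not DBRX's real dimensions or the converter's exact tuples) of the view-to-3D step this hunk adds, plus an axis permutation; note that torch/numpy axis arguments are 0-indexed:

    import torch

    # Illustrative sizes only; the real values come from self.hparams.
    n_expert, n_ff, n_embd = 4, 16, 8

    # DBRX checkpoints store each expert group as one flat 2D matrix.
    flat = torch.randn(n_expert * n_ff, n_embd)

    # View as 3D: one (n_ff, n_embd) slice per expert (no data movement).
    w3d = flat.view(n_expert, n_ff, n_embd)

    # Permute axes so the layout matches what the llama.cpp MoE graph
    # expects; axes are 0-indexed here.
    w_perm = w3d.permute(2, 1, 0)
    print(w_perm.shape)  # torch.Size([8, 16, 4]) -> (n_embd, n_ff, n_expert)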
@@ -1557,7 +1560,7 @@ class DbrxModel(Model):
 
             # Reshape experts tensors from 2D to 3D as expected by GeLU
             if experts and n_dims == 2:
-                data = data.reshape(expert_reshape)
+                data = data.transpose(expert_permute)
                 n_dims = len(data.shape)
 
             # if f32 desired, convert any float16 to float32
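Why this hunk swaps reshape for transpose: reshape only reinterprets the buffer in the same element order, while transpose actually permutes the axes and moves elements. A minimal illustration:

    import numpy as np

    a = np.arange(6).reshape(2, 3)
    # reshape keeps the element order and just changes the shape:
    print(a.reshape(3, 2))    # [[0 1] [2 3] [4 5]]
    # transpose moves elements so the axes are truly swapped:
    print(a.transpose(1, 0))  # [[0 3] [1 4] [2 5]]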
gguf-py/gguf/constants.py

@@ -646,9 +646,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
         MODEL_TENSOR.OUTPUT,
-        MODEL_TENSOR.ATTN_QKV,
         MODEL_TENSOR.ATTN_NORM,
         MODEL_TENSOR.ATTN_NORM_2,
+        MODEL_TENSOR.ATTN_QKV,
         MODEL_TENSOR.ATTN_OUT,
         MODEL_TENSOR.FFN_GATE_INP,
         MODEL_TENSOR.FFN_GATE_EXP,
         MODEL_TENSOR.FFN_DOWN_EXP,
gguf-py/gguf/tensor_mapping.py

@@ -101,8 +101,7 @@ class TensorNameMap:
 
         # Attention norm 2
         MODEL_TENSOR.ATTN_NORM_2: (
-            "transformer.h.{bid}.ln_attn",                     # falcon40b
-            "transformer.blocks.{bid}.norm_attn_norm.norm_2",  # dbrx
+            "transformer.h.{bid}.ln_attn",  # falcon40b
         ),
 
         # Attention query-key-value
@@ -155,23 +154,24 @@ class TensorNameMap:
 
         # Attention output
         MODEL_TENSOR.ATTN_OUT: (
-            "gpt_neox.layers.{bid}.attention.dense",                     # gptneox
-            "transformer.h.{bid}.attn.c_proj",                           # gpt2 refact qwen
-            "transformer.blocks.{bid}.attn.out_proj",                    # mpt
-            "transformer.h.{bid}.self_attention.dense",                  # falcon
-            "h.{bid}.self_attention.dense",                              # bloom
-            "model.layers.{bid}.self_attn.o_proj",                       # llama-hf
-            "layers.{bid}.attention.wo",                                 # llama-pth
-            "encoder.layer.{bid}.attention.output.dense",                # bert
-            "transformer.h.{bid}.attn.out_proj",                         # gpt-j
-            "language_model.encoder.layers.{bid}.self_attention.dense",  # persimmon
-            "model.layers.{bid}.self_attn.dense",                        # persimmon
-            "h.{bid}.attn.c_proj",                                       # gpt2
-            "transformer.h.{bid}.mixer.out_proj",                        # phi2
-            "model.layers.layers.{bid}.self_attn.o_proj",                # plamo
-            "model.layers.{bid}.attention.wo",                           # internlm2
-            "encoder.layers.{bid}.attn.out_proj",                        # nomic-bert
-            "transformer.decoder_layer.{bid}.multi_head_attention.linear"# Grok
+            "gpt_neox.layers.{bid}.attention.dense",                        # gptneox
+            "transformer.h.{bid}.attn.c_proj",                              # gpt2 refact qwen
+            "transformer.blocks.{bid}.attn.out_proj",                       # mpt
+            "transformer.h.{bid}.self_attention.dense",                     # falcon
+            "h.{bid}.self_attention.dense",                                 # bloom
+            "model.layers.{bid}.self_attn.o_proj",                          # llama-hf
+            "layers.{bid}.attention.wo",                                    # llama-pth
+            "encoder.layer.{bid}.attention.output.dense",                   # bert
+            "transformer.h.{bid}.attn.out_proj",                            # gpt-j
+            "language_model.encoder.layers.{bid}.self_attention.dense",     # persimmon
+            "model.layers.{bid}.self_attn.dense",                           # persimmon
+            "h.{bid}.attn.c_proj",                                          # gpt2
+            "transformer.h.{bid}.mixer.out_proj",                           # phi2
+            "model.layers.layers.{bid}.self_attn.o_proj",                   # plamo
+            "model.layers.{bid}.attention.wo",                              # internlm2
+            "encoder.layers.{bid}.attn.out_proj",                           # nomic-bert
+            "transformer.decoder_layer.{bid}.multi_head_attention.linear",  # Grok
+            "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj",        # dbrx
         ),
 
         # Attention output norm
@@ -306,10 +306,10 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.LAYER_OUT_NORM: (
-            "encoder.layer.{bid}.output.LayerNorm",                   # bert
-            "encoder.layers.{bid}.norm2",                             # nomic-bert
-            "transformer.decoder_layer.{bid}.rms_norm_3",             # Grok
-            "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj",  # dbrx
+            "encoder.layer.{bid}.output.LayerNorm",            # bert
+            "encoder.layers.{bid}.norm2",                      # nomic-bert
+            "transformer.decoder_layer.{bid}.rms_norm_3",      # Grok
+            "transformer.blocks.{bid}.norm_attn_norm.norm_2",  # dbrx
         ),
 
         MODEL_TENSOR.SSM_IN: (
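How these tables are consumed: a simplified, hypothetical version of the lookup that gguf-py's TensorNameMap performs (the real class also handles suffixes and arch-specific prefixes); "{bid}" is the block index in the checkpoint name:

    templates = {
        "transformer.blocks.{bid}.norm_attn_norm.norm_2": "blk.{bid}.layer_output_norm",
    }

    def map_name(hf_name: str, n_blocks: int) -> str | None:
        # Try each block index until the checkpoint name matches a template.
        for bid in range(n_blocks):
            for src, dst in templates.items():
                if hf_name == src.format(bid=bid):
                    return dst.format(bid=bid)
        return None

    print(map_name("transformer.blocks.3.norm_attn_norm.norm_2", 40))
    # -> blk.3.layer_output_norm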
llama.cpp
@@ -938,7 +938,7 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_OUTPUT,         "output" },
             { LLM_TENSOR_ATTN_QKV,       "blk.%d.attn_qkv" },
             { LLM_TENSOR_ATTN_NORM,      "blk.%d.attn_norm" },
             { LLM_TENSOR_ATTN_NORM_2,    "blk.%d.attn_norm_2" },
             { LLM_TENSOR_ATTN_OUT,       "blk.%d.attn_output" },
             { LLM_TENSOR_FFN_GATE_INP,   "blk.%d.ffn_gate_inp" },
             { LLM_TENSOR_FFN_GATE_EXPS,  "blk.%d.ffn_gate_exps" },
             { LLM_TENSOR_FFN_DOWN_EXPS,  "blk.%d.ffn_down_exps" },
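The strings above are printf-style templates; llama.cpp's tn() helper formats the block index into them and appends a suffix. A rough Python equivalent (names here are illustrative, not the C++ API):

    LLM_TENSOR_NAMES = {
        "ATTN_QKV":  "blk.%d.attn_qkv",
        "ATTN_NORM": "blk.%d.attn_norm",
    }

    def tn(tensor: str, suffix: str, bid: int) -> str:
        # "blk.%d.attn_qkv" % 3 -> "blk.3.attn_qkv", then append ".weight"
        return LLM_TENSOR_NAMES[tensor] % bid + "." + suffix

    print(tn("ATTN_QKV", "weight", 3))  # blk.3.attn_qkv.weight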
@@ -4687,16 +4687,16 @@ static bool llm_load_tensors(
                     auto & layer = model.layers[i];
 
                     layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd});
                     layer.attn_norm_2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd});
 
                     layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
                     layer.wo   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
 
                     layer.ffn_gate_inp  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert});
                     layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert});
                     layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert});
                     layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff, n_expert});
 
-                    layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd, n_embd});
+                    layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
                 }
             } break;
         case LLM_ARCH_BAICHUAN:
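The {n_embd, n_embd} to {n_embd} change reflects that a layer norm's learnable weight is a 1D vector applied elementwise per channel, not a square matrix. A quick sketch of that shape (illustrative sizes):

    import torch

    n_embd = 8
    x = torch.randn(3, n_embd)
    w = torch.ones(n_embd)  # norm weight: one scale per embedding channel

    mean = x.mean(-1, keepdim=True)
    var = x.var(-1, unbiased=False, keepdim=True)
    y = (x - mean) / torch.sqrt(var + 1e-5) * w
    print(y.shape)  # torch.Size([3, 8])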
@@ -7132,7 +7132,6 @@ struct llm_build_context {
                 struct ggml_tensor * Vcur = nullptr;
 
                 cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
-                cur = ggml_norm(ctx0, cur, hparams.f_norm_eps);
                 cb(cur, "wqkv", il);
 
                 cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
@@ -7161,10 +7160,9 @@ struct llm_build_context {
                 cb(Kcur, "Kcur", il);
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
-                        model.layers[il].layer_out_norm, model.layers[il].bo,
+                        model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
 
-                cur = ggml_norm(ctx0, cur, hparams.f_norm_eps);
             }
 
             if (il == n_layer - 1) {
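What the corrected arguments mean: the (wo, bo) slot in llm_build_kv is the attention output projection, and the old code was handing it the tensor loaded as layer_out_norm. A minimal sketch of the projection itself (illustrative sizes; DBRX has no output bias, hence NULL):

    import torch

    n_embd, n_tokens = 8, 4
    ctx_vecs = torch.randn(n_tokens, n_embd)  # concatenated attention heads
    wo = torch.randn(n_embd, n_embd)          # blk.%d.attn_output weight

    attn_out = ctx_vecs @ wo.T                # bo would be added here if present
    print(attn_out.shape)                     # torch.Size([4, 8])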
@@ -7181,11 +7179,6 @@ struct llm_build_context {
             // feed-forward network
             // MoE branch
             {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].attn_norm_2, NULL,
-                        LLM_NORM, cb, il);
-                cb(cur, "ffn_norm", il);
-
                 ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
                 cb(logits, "ffn_moe_logits", il);
 
@@ -7243,6 +7236,11 @@ struct llm_build_context {
                     cur = moe_out;
                 }
 
+                cur = llm_build_norm(ctx0, cur, hparams,
+                        model.layers[il].layer_out_norm, NULL,
+                        LLM_NORM, cb, il);
+                cb(cur, "layer_out_norm", il);
+
                 cur = ggml_add(ctx0, cur, ffn_inp);
                 cb(cur, "ffn_out", il);
 
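Net effect of the last two graph hunks, sketched in PyTorch (nn.Linear stands in for the full routed expert mixture; names are illustrative): the per-layer output norm is now applied to the MoE result before the residual add, and the extra attn_norm_2 normalization of the branch input is gone.

    import torch
    import torch.nn as nn

    n_embd = 8
    ffn_inp = torch.randn(2, n_embd)       # residual stream entering the MoE branch
    moe = nn.Linear(n_embd, n_embd)        # stand-in for the expert mixture
    layer_out_norm = nn.LayerNorm(n_embd)  # blk.%d.layer_output_norm

    cur = moe(ffn_inp)                     # branch input is no longer pre-normalized
    cur = layer_out_norm(cur)              # normalize the MoE output...
    out = cur + ffn_inp                    # ...then add the residual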