llama: dbrx: rename tensor to actual meaning. Fix normalization in graph. Permute expert tensors to the llama.cpp layout

This commit is contained in:
Pierrick HYMBERT 2024-04-08 14:02:08 +02:00
parent 8e22688401
commit 35dce3e145
4 changed files with 42 additions and 41 deletions

View file

@ -1523,18 +1523,21 @@ class DbrxModel(Model):
n_embd = self.hparams["d_model"] n_embd = self.hparams["d_model"]
# Specific behavior for experts tensors: reshape to 3D and add suffix .weight # Specific behavior for experts tensors: reshape to 3D and add suffix .weight
exp_tensor_names = {"ffn.experts.mlp.v1": (n_embd, n_ff, n_expert), # LLM_TENSOR_FFN_GATE_EXPS exp_tensor_names = {"ffn.experts.mlp.v1": (2, 1, 3), # LLM_TENSOR_FFN_GATE_EXPS(n_embd, n_ff, n_expert)
"ffn.experts.mlp.w1": (n_embd, n_ff, n_expert), # LLM_TENSOR_FFN_DOWN_EXPS "ffn.experts.mlp.w2": (1, 2, 3), # LLM_TENSOR_FFN_DOWN_EXPS(n_ff, n_embd, n_expert)
"ffn.experts.mlp.w2": (n_ff, n_embd, n_expert)} # LLM_TENSOR_FFN_UP_EXPS "ffn.experts.mlp.w1": (2, 1, 3)} # LLM_TENSOR_FFN_UP_EXPS (n_embd, n_ff, n_expert)
experts = False experts = False
for exp_tensor_name in exp_tensor_names.keys(): for exp_tensor_name in exp_tensor_names.keys():
if name.find(exp_tensor_name) != -1 and name.find(".weight") == -1: if name.find(exp_tensor_name) != -1 and name.find(".weight") == -1:
experts = True experts = True
expert_reshape = exp_tensor_names[exp_tensor_name][::-1] expert_permute = exp_tensor_names[exp_tensor_name][::-1]
break break
old_dtype = data_torch.dtype old_dtype = data_torch.dtype
if experts:
data_torch = data_torch.view(n_expert, n_ff, n_embd)
# convert any unsupported data types to float32 # convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32): if data_torch.dtype not in (torch.float16, torch.float32):
data_torch = data_torch.to(torch.float32) data_torch = data_torch.to(torch.float32)
@ -1557,7 +1560,7 @@ class DbrxModel(Model):
# Reshape experts tensors from 2D to 3D as expected by GeLU # Reshape experts tensors from 2D to 3D as expected by GeLU
if experts and n_dims == 2: if experts and n_dims == 2:
data = data.reshape(expert_reshape) data = data.transpose(expert_permute)
n_dims = len(data.shape) n_dims = len(data.shape)
# if f32 desired, convert any float16 to float32 # if f32 desired, convert any float16 to float32

View file

@ -646,9 +646,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT, MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_NORM, MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_NORM_2, MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE_INP, MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP, MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_DOWN_EXP,

View file

@ -102,7 +102,6 @@ class TensorNameMap:
# Attention norm 2 # Attention norm 2
MODEL_TENSOR.ATTN_NORM_2: ( MODEL_TENSOR.ATTN_NORM_2: (
"transformer.h.{bid}.ln_attn", # falcon40b "transformer.h.{bid}.ln_attn", # falcon40b
"transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx
), ),
# Attention query-key-value # Attention query-key-value
@ -171,7 +170,8 @@ class TensorNameMap:
"model.layers.layers.{bid}.self_attn.o_proj", # plamo "model.layers.layers.{bid}.self_attn.o_proj", # plamo
"model.layers.{bid}.attention.wo", # internlm2 "model.layers.{bid}.attention.wo", # internlm2
"encoder.layers.{bid}.attn.out_proj", # nomic-bert "encoder.layers.{bid}.attn.out_proj", # nomic-bert
"transformer.decoder_layer.{bid}.multi_head_attention.linear"# Grok "transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok
"transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
), ),
# Attention output norm # Attention output norm
@ -309,7 +309,7 @@ class TensorNameMap:
"encoder.layer.{bid}.output.LayerNorm", # bert "encoder.layer.{bid}.output.LayerNorm", # bert
"encoder.layers.{bid}.norm2", # nomic-bert "encoder.layers.{bid}.norm2", # nomic-bert
"transformer.decoder_layer.{bid}.rms_norm_3", # Grok "transformer.decoder_layer.{bid}.rms_norm_3", # Grok
"transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx "transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx
), ),
MODEL_TENSOR.SSM_IN: ( MODEL_TENSOR.SSM_IN: (

View file

@ -938,7 +938,7 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_OUTPUT, "output" }, { LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
@ -4687,16 +4687,16 @@ static bool llm_load_tensors(
auto & layer = model.layers[i]; auto & layer = model.layers[i];
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
layer.attn_norm_2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2,"weight", i), {n_embd});
layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
layer.wo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS,"weight", i), {n_embd, n_ff, n_expert}); layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS,"weight", i), {n_embd, n_ff, n_expert});
layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS,"weight", i), {n_ff, n_embd, n_expert}); layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS,"weight", i), {n_ff, n_embd, n_expert});
layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert});
layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd, n_embd}); layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
} }
} break; } break;
case LLM_ARCH_BAICHUAN: case LLM_ARCH_BAICHUAN:
@ -7132,7 +7132,6 @@ struct llm_build_context {
struct ggml_tensor * Vcur = nullptr; struct ggml_tensor * Vcur = nullptr;
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
cur = ggml_norm(ctx0, cur, hparams.f_norm_eps);
cb(cur, "wqkv", il); cb(cur, "wqkv", il);
cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
@ -7161,10 +7160,9 @@ struct llm_build_context {
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].layer_out_norm, model.layers[il].bo, model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cur = ggml_norm(ctx0, cur, hparams.f_norm_eps);
} }
if (il == n_layer - 1) { if (il == n_layer - 1) {
@ -7181,11 +7179,6 @@ struct llm_build_context {
// feed-forward network // feed-forward network
// MoE branch // MoE branch
{ {
cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].attn_norm_2, NULL,
LLM_NORM, cb, il);
cb(cur, "ffn_norm", il);
ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts] ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
cb(logits, "ffn_moe_logits", il); cb(logits, "ffn_moe_logits", il);
@ -7243,6 +7236,11 @@ struct llm_build_context {
cur = moe_out; cur = moe_out;
} }
cur = llm_build_norm(ctx0, cur, hparams,
model.layers[il].layer_out_norm, NULL,
LLM_NORM, cb, il);
cb(cur, "layer_out_norm", il);
cur = ggml_add(ctx0, cur, ffn_inp); cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);