llama: dbrx: fix k scale

This commit is contained in:
Pierrick HYMBERT 2024-04-08 10:43:36 +02:00
parent 71f9e479aa
commit 52c6276e12

View file

@ -7128,6 +7128,10 @@ struct llm_build_context {
// self-attention // self-attention
{ {
struct ggml_tensor * Qcur = nullptr;
struct ggml_tensor * Kcur = nullptr;
struct ggml_tensor * Vcur = nullptr;
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
cb(cur, "wqkv", il); cb(cur, "wqkv", il);
@ -7136,28 +7140,32 @@ struct llm_build_context {
cb(cur, "wqkv_clamped", il); cb(cur, "wqkv_clamped", il);
} }
struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Qcur = ggml_rope_custom( Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
ext_factor, attn_factor, beta_fast, beta_slow
); );
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
Kcur = ggml_rope_custom( Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
ext_factor, attn_factor, beta_fast, beta_slow
); );
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
cb(Vcur, "Vcur", il);
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].layer_out_norm, NULL, model.layers[il].layer_out_norm, NULL,
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens,kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
} }
if (il == n_layer - 1) { if (il == n_layer - 1) {
@ -14770,7 +14778,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
// the pairs of head values are offset by n_rot/2 // the pairs of head values are offset by n_rot/2
case LLM_ARCH_FALCON: case LLM_ARCH_FALCON:
case LLM_ARCH_GROK: case LLM_ARCH_GROK:
case LLM_ARCH_DBRX: // FIXME REVIEW @ggerganov I am not sure what to put here case LLM_ARCH_DBRX:
case LLM_ARCH_PERSIMMON: case LLM_ARCH_PERSIMMON:
case LLM_ARCH_BERT: case LLM_ARCH_BERT:
case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT: