persimmon : use rope over whole Qcur/Kcur

This commit is contained in:
Galunid 2023-11-28 22:08:50 +01:00
parent 3e73d31d9c
commit 3e28686d7f

View file

@ -4348,105 +4348,62 @@ struct llm_build_context {
struct ggml_tensor * tmpqkv_perm = ggml_cont(ctx0, ggml_permute(ctx0, tmpqkv, 0, 3, 1, 2)); struct ggml_tensor * tmpqkv_perm = ggml_cont(ctx0, ggml_permute(ctx0, tmpqkv, 0, 3, 1, 2));
cb(tmpqkv_perm, "tmpqkv", il); cb(tmpqkv_perm, "tmpqkv", il);
struct ggml_tensor * tmpq = ggml_view_3d( struct ggml_tensor * Qcur = ggml_view_3d(
ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens, ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens,
ggml_element_size(tmpqkv_perm) * n_embd_head, ggml_element_size(tmpqkv_perm) * n_embd_head,
ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, ggml_element_size(tmpqkv_perm) * n_embd_head * n_head,
0 0
); );
cb(tmpq, "tmpq", il); cb(Qcur, "Qcur", il);
struct ggml_tensor * tmpk = ggml_view_3d( struct ggml_tensor * Kcur = ggml_view_3d(
ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens, ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens,
ggml_element_size(tmpqkv_perm) * n_embd_head, ggml_element_size(tmpqkv_perm) * n_embd_head,
ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, ggml_element_size(tmpqkv_perm) * n_embd_head * n_head,
ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens
); );
cb(tmpk, "tmpk", il); cb(Kcur, "Kcur", il);
// Q/K Layernorm // Q/K Layernorm
tmpq = llm_build_norm(ctx0, tmpq, hparams, Qcur = llm_build_norm(ctx0, Qcur, hparams,
model.layers[il].attn_q_norm, model.layers[il].attn_q_norm,
model.layers[il].attn_q_norm_b, model.layers[il].attn_q_norm_b,
LLM_NORM, cb, il); LLM_NORM, cb, il);
cb(tmpq, "tmpq", il); cb(Qcur, "Qcur", il);
tmpk = llm_build_norm(ctx0, tmpk, hparams, Kcur = llm_build_norm(ctx0, Kcur, hparams,
model.layers[il].attn_k_norm, model.layers[il].attn_k_norm,
model.layers[il].attn_k_norm_b, model.layers[il].attn_k_norm_b,
LLM_NORM, cb, il); LLM_NORM, cb, il);
cb(tmpk, "tmpk", il); cb(Kcur, "Kcur", il);
// RoPE the first n_rot of q/k, pass the other half, and concat. // RoPE q/k through views spanning the whole head (n_rot is still passed to ggml_rope_custom); no separate pass/concat step.
struct ggml_tensor * qrot = ggml_view_3d( struct ggml_tensor * qrot = ggml_view_3d(
ctx0, tmpq, n_rot, n_head, n_tokens, ctx0, Qcur, n_embd_head, n_head, n_tokens,
ggml_element_size(tmpq) * n_embd_head, ggml_element_size(Qcur) * n_embd_head,
ggml_element_size(tmpq) * n_embd_head * n_head, ggml_element_size(Qcur) * n_embd_head * n_head,
0 0
); );
cb(qrot, "qrot", il); cb(qrot, "qrot", il);
struct ggml_tensor * krot = ggml_view_3d( struct ggml_tensor * krot = ggml_view_3d(
ctx0, tmpk, n_rot, n_head, n_tokens, ctx0, Kcur, n_embd_head, n_head, n_tokens,
ggml_element_size(tmpk) * n_embd_head, ggml_element_size(Kcur) * n_embd_head,
ggml_element_size(tmpk) * n_embd_head * n_head, ggml_element_size(Kcur) * n_embd_head * n_head,
0 0
); );
cb(krot, "krot", il); cb(krot, "krot", il);
// get the second half of tmpq, e.g tmpq[n_rot:, :, :] Qcur = ggml_rope_custom(
struct ggml_tensor * qpass = ggml_view_3d(
ctx0, tmpq, n_rot, n_head, n_tokens,
ggml_element_size(tmpq) * n_embd_head,
ggml_element_size(tmpq) * n_embd_head * n_head,
ggml_element_size(tmpq) * n_rot
);
cb(qpass, "qpass", il);
struct ggml_tensor * kpass = ggml_view_3d(
ctx0, tmpk, n_rot, n_head, n_tokens,
ggml_element_size(tmpk) * n_embd_head,
ggml_element_size(tmpk) * n_embd_head * n_head,
ggml_element_size(tmpk) * n_rot
);
cb(kpass, "kpass", il);
struct ggml_tensor * qrotated = ggml_rope_custom(
ctx0, qrot, inp_pos, n_rot, 2, 0, n_orig_ctx, ctx0, qrot, inp_pos, n_rot, 2, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
); );
cb(qrotated, "qrotated", il); cb(Qcur, "Qcur", il);
struct ggml_tensor * krotated = ggml_rope_custom( Kcur = ggml_rope_custom(
ctx0, krot, inp_pos, n_rot, 2, 0, n_orig_ctx, ctx0, krot, inp_pos, n_rot, 2, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
); );
cb(krotated, "krotated", il);
// ggml currently only supports concatenation on dim=2
// so we need to permute qrot, qpass, concat, then permute back.
qrotated = ggml_cont(ctx0, ggml_permute(ctx0, qrotated, 2, 1, 0, 3));
cb(qrotated, "qrotated", il);
krotated = ggml_cont(ctx0, ggml_permute(ctx0, krotated, 2, 1, 0, 3));
cb(krotated, "krotated", il);
qpass = ggml_cont(ctx0, ggml_permute(ctx0, qpass, 2, 1, 0, 3));
cb(qpass, "qpass", il);
kpass = ggml_cont(ctx0, ggml_permute(ctx0, kpass, 2, 1, 0, 3));
cb(kpass, "kpass", il);
struct ggml_tensor * Qcur = ggml_concat(ctx0, qrotated, qpass);
cb(Qcur, "Qcur", il);
struct ggml_tensor * Kcur = ggml_concat(ctx0, krotated, kpass);
cb(Kcur, "Kcur", il);
struct ggml_tensor * Q = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 2, 1, 0, 3));
cb(Q, "Q", il);
Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 2, 1, 0, 3));
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
struct ggml_tensor * Vcur = ggml_view_3d( struct ggml_tensor * Vcur = ggml_view_3d(
@ -4462,7 +4419,7 @@ struct llm_build_context {
// TODO: not tested, could be broken // TODO: not tested, could be broken
cur = llm_build_kqv(ctx0, hparams, kv_self, cur = llm_build_kqv(ctx0, hparams, kv_self,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Q, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }