llama : use n_embd_head_v instead of n_embd_head_k when reshaping kqv
parent 9afdffe70e
commit f15e933fb1
1 changed file with 2 additions and 2 deletions
@@ -6655,7 +6655,7 @@ static struct ggml_tensor * llm_build_kqv(
             ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
         }
 
-        cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens);
+        cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
     } else {
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
         cb(kq, "kq", il);
@@ -6700,7 +6700,7 @@ static struct ggml_tensor * llm_build_kqv(
         struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
        cb(kqv_merged, "kqv_merged", il);
 
-        cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
+        cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens);
         cb(cur, "kqv_merged_cont", il);
     }
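For context (not part of the commit): in ggml, ggml_mul_mat(ctx, a, b) yields a tensor whose first dimension is a's second dimension, so the per-head attention output kqv = V^T @ KQ is n_embd_head_v elements wide, not n_embd_head_k. The two sizes coincide for most models, which is why the old code happened to work, but they can differ, and then only n_embd_head_v gives the correct merged width. The following minimal sketch walks through that shape arithmetic; the dimension values are made up for illustration, and only the ggml calls themselves mirror the patched code.

// sketch: why the merged attention output is n_embd_head_v*n_head wide
// (illustrative, not llama.cpp code; dimensions below are hypothetical)
#include <assert.h>
#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // hypothetical sizes for a model where the K and V head sizes differ
    const int n_embd_head_v = 128;
    const int n_head        = 8;
    const int n_tokens      = 4;
    const int n_kv          = 16;

    // kq: attention scores, one [n_kv, n_tokens] matrix per head
    struct ggml_tensor * kq = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_kv, n_tokens, n_head);

    // v, stored transposed: [n_kv, n_embd_head_v] per head
    struct ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_kv, n_embd_head_v, n_head);

    // kqv = v^T @ kq -> [n_embd_head_v, n_tokens, n_head]: the first dim of
    // the attention output comes from V's head size, not K's
    struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
    assert(kqv->ne[0] == n_embd_head_v);

    // merge heads: permute to [n_embd_head_v, n_head, n_tokens], then flatten
    // the first two dims -- hence n_embd_head_v*n_head, as in the patch
    struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
    struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens);

    printf("merged: %lld x %lld\n", (long long) cur->ne[0], (long long) cur->ne[1]);

    ggml_free(ctx);
    return 0;
}

With n_embd_head_k left in place, ggml_cont_2d would fail its element-count check on any model where the two head sizes differ, since n_embd_head_k*n_head*n_tokens would no longer match the number of elements in kqv.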