diff --git a/tests/test-flash-attention.cpp b/tests/test-flash-attention.cpp index d4457a53e..1f779b0d4 100644 --- a/tests/test-flash-attention.cpp +++ b/tests/test-flash-attention.cpp @@ -207,7 +207,9 @@ struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * a struct ggml_tensor* kq = ggml_mul_mat(ctx0, model.k, model.q); kq = ggml_scale_inplace(ctx0, kq, 1.0f / sqrtf((float)model.q->ne[0])); kq = ggml_soft_max(ctx0, kq); - kq = ggml_mul_mat(ctx0, model.v, kq); + kq = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, model.v)), kq); + kq = ggml_permute (ctx0, kq, 0, 2, 1, 3); + //kq = ggml_cont_2d (ctx0, kq, model.q->ne[0] * model.q->ne[2], model.q->ne[1]); ggml_build_forward_expand(gf, kq); }