From 0afe47fa5fdda0ff9191ca70241a9fe88364d8cc Mon Sep 17 00:00:00 2001
From: FSSRepo
Date: Wed, 31 Jan 2024 15:43:42 -0500
Subject: [PATCH] fix naive implementation

---
 tests/test-flash-attention.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/test-flash-attention.cpp b/tests/test-flash-attention.cpp
index d4457a53e..1f779b0d4 100644
--- a/tests/test-flash-attention.cpp
+++ b/tests/test-flash-attention.cpp
@@ -207,7 +207,9 @@ struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * a
     struct ggml_tensor* kq = ggml_mul_mat(ctx0, model.k, model.q);
     kq = ggml_scale_inplace(ctx0, kq, 1.0f / sqrtf((float)model.q->ne[0]));
     kq = ggml_soft_max(ctx0, kq);
-    kq = ggml_mul_mat(ctx0, model.v, kq);
+    kq = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, model.v)), kq);
+    kq = ggml_permute (ctx0, kq, 0, 2, 1, 3);
+    //kq = ggml_cont_2d (ctx0, kq, model.q->ne[0] * model.q->ne[2], model.q->ne[1]);
     ggml_build_forward_expand(gf, kq);
 }