fix naive implementation

This commit is contained in:
FSSRepo 2024-01-31 15:43:42 -05:00
parent b1479dfbc5
commit 0afe47fa5f

View file

@ -207,7 +207,9 @@ struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * a
struct ggml_tensor* kq = ggml_mul_mat(ctx0, model.k, model.q); struct ggml_tensor* kq = ggml_mul_mat(ctx0, model.k, model.q);
kq = ggml_scale_inplace(ctx0, kq, 1.0f / sqrtf((float)model.q->ne[0])); kq = ggml_scale_inplace(ctx0, kq, 1.0f / sqrtf((float)model.q->ne[0]));
kq = ggml_soft_max(ctx0, kq); kq = ggml_soft_max(ctx0, kq);
kq = ggml_mul_mat(ctx0, model.v, kq); kq = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, model.v)), kq);
kq = ggml_permute (ctx0, kq, 0, 2, 1, 3);
//kq = ggml_cont_2d (ctx0, kq, model.q->ne[0] * model.q->ne[2], model.q->ne[1]);
ggml_build_forward_expand(gf, kq); ggml_build_forward_expand(gf, kq);
} }