fix kernel

This commit is contained in:
FSSRepo 2024-01-31 12:28:48 -05:00
parent 3b0f74b428
commit b1479dfbc5
2 changed files with 56 additions and 49 deletions

View file

@ -201,7 +201,7 @@ struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * a
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
if(!model.naive_attn) {
struct ggml_tensor* result = ggml_flash_attn_ext(ctx0, model.q, model.k, model.v, model.msk, 1.0f / sqrtf(model.q->ne[0]));
struct ggml_tensor* result = ggml_flash_attn_ext(ctx0, model.q, model.k, model.v, nullptr, 1.0f / sqrtf(model.q->ne[0]));
ggml_build_forward_expand(gf, result);
} else {
struct ggml_tensor* kq = ggml_mul_mat(ctx0, model.k, model.q);