fix kernel
This commit is contained in:
parent
3b0f74b428
commit
b1479dfbc5
2 changed files with 56 additions and 49 deletions
|
@ -201,7 +201,7 @@ struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * a
|
|||
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
|
||||
if(!model.naive_attn) {
|
||||
struct ggml_tensor* result = ggml_flash_attn_ext(ctx0, model.q, model.k, model.v, model.msk, 1.0f / sqrtf(model.q->ne[0]));
|
||||
struct ggml_tensor* result = ggml_flash_attn_ext(ctx0, model.q, model.k, model.v, nullptr, 1.0f / sqrtf(model.q->ne[0]));
|
||||
ggml_build_forward_expand(gf, result);
|
||||
} else {
|
||||
struct ggml_tensor* kq = ggml_mul_mat(ctx0, model.k, model.q);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue