ggml : fix num dimensions in ggml_flash_attn_ext

This commit is contained in:
Georgi Gerganov 2024-04-22 12:50:26 +03:00
parent a39217d428
commit cb76d747d1
No known key found for this signature in database
GPG key ID: BF970631944C16B7

2
ggml.c
View file

@@ -6321,7 +6321,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
// permute(0, 2, 1, 3)
int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
-struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, ne);
+struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
float params[] = { scale };
ggml_set_op_params(result, params, sizeof(params));