ggml : fix GQA support in ggml_flash_attn_ext

This commit is contained in:
Georgi Gerganov 2024-01-19 20:06:26 +02:00
parent a1c004ef2e
commit fa7ebcca99
3 changed files with 23 additions and 12 deletions

View file

@@ -1960,10 +1960,10 @@ kernel void kernel_leaky_relu_f32(
}
kernel void kernel_flash_attn_ext_f16(
device const half * q,
device const half * k,
device const half * v,
device const half * mask,
device const half * q,
device const half * k,
device const half * v,
device const float * mask,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,