ggml : add ggml_flash_attn_ext API

This commit is contained in:
Georgi Gerganov 2024-01-18 17:42:55 +02:00
parent ad19812cda
commit a1c004ef2e
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
6 changed files with 456 additions and 38 deletions

View file

@@ -1959,6 +1959,35 @@ kernel void kernel_leaky_relu_f32(
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
}
// Entry point for the F16 "flash attention ext" compute kernel.
//
// This is a placeholder: the signature is declared but the body is empty, so
// dispatching it performs no reads and leaves dst unmodified.
//
// Arguments, as declared below:
//   q, k, v, mask   read-only half-precision device buffers
//                   (presumably the query/key/value tensors and the attention
//                   mask - TODO confirm against the host-side dispatch)
//   dst             float device buffer; nothing writes to it yet
//   ne00..ne03      int64 values,  nb00..nb03  uint64 values
//   ne0..ne3        int64 values,  nb0..nb3    uint64 values
//                   NOTE(review): these look like ggml-style element counts
//                   (ne*) and byte strides (nb*) for the first source tensor
//                   and for dst respectively - confirm before relying on it.
//                   No shape information for k, v or mask is passed in.
//   scale           float scalar (presumably the scaling applied to the Q*K^T
//                   scores - TODO confirm)
//   tgpig           position of this threadgroup in the grid
//   tpitg           position of this thread within its threadgroup
//   ntg             number of threads per threadgroup
kernel void kernel_flash_attn_ext_f16(
device const half * q,
device const half * k,
device const half * v,
device const half * mask,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
constant float & scale,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
// TODO: implement
// Until this is filled in, none of the arguments above are used and dst is
// never written.
}
kernel void kernel_cpy_f16_f16(
device const half * src0,
device half * dst,