metal : specialize for head size

This commit is contained in:
Georgi Gerganov 2024-01-21 12:01:55 +02:00
parent 52ae085750
commit b97325800a
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 179 additions and 122 deletions

View file

@ -1959,6 +1959,43 @@ kernel void kernel_leaky_relu_f32(
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
}
// Function type of the flash-attention kernel.
// It gives the kernel signature a name so that the explicit instantiations of the
// kernel_flash_attn_ext_f16<D> template can be written as
// `kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<...>` without repeating the
// whole parameter list. The parameter list here is expected to mirror the template's.
typedef void (flash_attn_ext_f16_t)(
// NOTE(review): names suggest the query, key, value and mask inputs, passed as raw
// byte pointers in device memory - confirm against the kernel body
device const char * q,
device const char * k,
device const char * v,
device const char * mask,
// float buffer in device memory; presumably the kernel output - TODO confirm
device float * dst,
// presumably ggml-style element counts (ne0*) and byte strides (nb0*) of q - TODO confirm
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
// presumably element counts (ne1*) and byte strides (nb1*) of k - TODO confirm
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
// presumably one element count and one byte stride of the mask - TODO confirm
constant int64_t & ne31,
constant uint64_t & nb31,
// presumably the element counts of dst - TODO confirm
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
// NOTE(review): likely the scaling factor applied to the attention scores - verify in the kernel body
constant float & scale,
// buffer of half values in the threadgroup address space
threadgroup half * shared,
// built-in thread/threadgroup indices, supplied by Metal through the attributes
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]);
template<int64_t D> // head size
kernel void kernel_flash_attn_ext_f16(
device const char * q,
device const char * k,
@ -2002,7 +2039,6 @@ kernel void kernel_flash_attn_ext_f16(
return;
}
const int64_t D = ne00;
const int64_t D4 = D/4;
// TODO: can we move this to the stack?
@ -2097,6 +2133,10 @@ kernel void kernel_flash_attn_ext_f16(
}
}
// Explicit instantiations of kernel_flash_attn_ext_f16 for head sizes 64, 80 and 128.
// Each instantiation is exported under the name given in host_name, which is the name
// host code uses to look the kernel up; a head size without an instantiation here has
// no corresponding kernel function.
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>;
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
kernel void kernel_cpy_f16_f16(
device const half * src0,
device half * dst,