metal : add BS=1 kernel for flash attention (wip)

commit d15898481a
parent 89961dea87
Author: Georgi Gerganov
Date:   2024-04-05 13:29:48 +03:00
2 changed files with 501 additions and 32 deletions

ggml-metal.m

@@ -179,6 +179,12 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -613,12 +619,18 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80, flash_attn_ext_vec_f16_h80, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112, flash_attn_ext_vec_f16_h112, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
@@ -2509,19 +2521,36 @@ static enum ggml_status ggml_metal_graph_compute(

                 id<MTLComputePipelineState> pipeline = nil;

-                switch (ne00) {
-                    case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
-                    case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
-                    case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
-                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
-                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
-                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
-                    default:
-                        {
-                            GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
-                            GGML_METAL_LOG_ERROR("add template specialization for this size\n");
-                            GGML_ASSERT(false && "add template specialization for this size");
-                        }
-                }
+                if (ne01 > 1) {
+                    switch (ne00) {
+                        case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
+                        case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
+                        case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
+                        case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
+                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
+                        case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+                        default:
+                            {
+                                GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+                                GGML_ASSERT(false && "add template specialization for this size");
+                            }
+                    }
+                } else {
+                    switch (ne00) {
+                        case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64 ].pipeline; break;
+                        case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80 ].pipeline; break;
+                        case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96 ].pipeline; break;
+                        case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112].pipeline; break;
+                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
+                        case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                        default:
+                            {
+                                GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+                                GGML_ASSERT(false && "add template specialization for this size");
+                            }
+                    }
+                }

                 // TODO: extend if necessary
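Here ne01 is the number of query rows in the batch: prompt processing runs with ne01 > 1 and keeps the existing simdgroup-matrix pipelines, while single-token decoding (the BS=1 case from the commit title) has ne01 == 1 and is routed to the new kernel_flash_attn_ext_vec_f16 pipelines, which operate on half4 vectors instead of 8x8 matrix fragments.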
@@ -2555,24 +2584,48 @@ static enum ggml_status ggml_metal_graph_compute(
                 [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
                 [encoder setBytes:&scale length:sizeof( float) atIndex:27];

-                const int64_t nqptg = 8; // queries per threadgroup    !! sync with kernel template arguments !!
-                const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
-
-                GGML_ASSERT(nqptg <= 32);
-                GGML_ASSERT(nqptg % 8 == 0);
-                GGML_ASSERT(ncpsg % 32 == 0);
-
-                // simdgroups per threadgroup (a.k.a. warps)
-                // for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
-                const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4;
-
-                const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);
-
-                //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
-                GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
-
-                [encoder setThreadgroupMemoryLength:smem atIndex:0];
-
-                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+                // half8x8 kernel
+                if (ne01 > 1) {
+                    const int64_t nqptg = 8; // queries per threadgroup    !! sync with kernel template arguments !!
+                    const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+
+                    GGML_ASSERT(nqptg <= 32);
+                    GGML_ASSERT(nqptg % 8 == 0);
+                    GGML_ASSERT(ncpsg % 32 == 0);
+
+                    // simdgroups per threadgroup (a.k.a. warps)
+                    // for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
+                    const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4;
+
+                    const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);
+
+                    //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
+                    GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+
+                    [encoder setThreadgroupMemoryLength:smem atIndex:0];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+                } else {
+                    // half1x4 kernel
+                    const int64_t nqptg = 1; // queries per threadgroup    !! sync with kernel template arguments !!
+                    const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+
+                    GGML_ASSERT(nqptg <= 32);
+                    GGML_ASSERT(nqptg % 1 == 0);
+                    GGML_ASSERT(ncpsg % 32 == 0);
+
+                    // simdgroups per threadgroup (a.k.a. warps)
+                    // for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
+                    //const int64_t nsg = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
+                    const int64_t nsg = 1;
+
+                    const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);
+
+                    //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
+                    GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+
+                    [encoder setThreadgroupMemoryLength:smem atIndex:0];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+                }
             } break;
         case GGML_OP_DUP:
         case GGML_OP_CPY:
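For a concrete sense of the sizing in the new half1x4 branch, take an assumed head size ne00 = 128 with the hard-coded nqptg = 1, ncpsg = 32 and nsg = 1 (sizeof(float)/2 = 2 is the size of a half in bytes):

\[
\mathrm{smem} = \mathrm{nqptg}\cdot\bigl(\mathrm{ne00} + \mathrm{nsg}\cdot(\mathrm{ncpsg} + \mathrm{nqptg})\bigr)\cdot 2
              = 1\cdot\bigl(128 + 1\cdot(32 + 1)\bigr)\cdot 2 = 322 \text{ bytes},
\]

comfortably below maxThreadgroupMemoryLength (typically 32 KiB on recent Apple GPUs). The dispatch then launches ne01 = 1 threadgroup along the first grid axis and ne02 x ne03 along the other two, each threadgroup being a single 32-thread simdgroup.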

ggml-metal.metal
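
Before the diff, a sketch of what the added kernel computes (the same online-softmax scheme as the existing half8x8 kernel, per the FlashAttention paper). Each simdgroup walks the KV cache in chunks of C = 32 positions while keeping a running row maximum M and normalizer S. Writing s_c = scale * (q . k_c) + mask_c for a position c in the current chunk, the updates implemented by the M[j], S[j] and lo[j] code below are:

\[
M' = \max\bigl(M,\ \max_c s_c\bigr), \qquad
S' = S\,e^{M-M'} + \sum_c e^{s_c - M'}, \qquad
O' = O\,e^{M-M'} + \sum_c e^{s_c - M'}\,v_c,
\]

with the final store writing dst = O/S.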

@@ -2457,6 +2457,422 @@ template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f
 template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>;
 template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>;
+
+template<int64_t D, int64_t Q, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
+kernel void kernel_flash_attn_ext_vec_f16(
+        device const char * q,
+        device const char * k,
+        device const char * v,
+        device const char * mask,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant int64_t & ne31,
+        constant uint64_t & nb31,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant float & scale,
+        threadgroup half * shared [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short nsg = ntg.y; // number of simdgroups
+
+    const short iq3 = tgpig[2];
+    const short iq2 = tgpig[1];
+    const short iq1 = tgpig[0]*Q;
+
+    const short D4 = D/4;
+    const short NW = N_SIMDWIDTH;
+    const short SH = (C + Q); // shared memory per simdgroup in (half)
+
+    const short T  = D + nsg*SH; // shared memory size per query in (half)
+    const short T4 = T/4;        // shared memory size per query in (half4)
+
+    threadgroup half  * sq  = (threadgroup half  *) (shared + 0*D);            // holds the query data
+    threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D);            // same as above but in half4
+    threadgroup half  * ss  = (threadgroup half  *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+    threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*SH + 1*D); // same as above but in half4
+
+    // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
+    half4 lo[Q][D4];
+
+    // load heads from Q to shared memory
+    for (short j = sgitg; j < Q; j += nsg) {
+        device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
+
+        for (short i = tiisg; i < D4; i += NW) {
+            if (iq1 + j < ne01) {
+                sq4[j*T4 + i] = (half4) q4[i];
+            } else {
+                sq4[j*T4 + i] = 0.0h;
+            }
+        }
+    }
+
+    // zero out lo
+    for (short j = 0; j < Q; ++j) {
+        for (short i = 0; i < D4; ++i) {
+            lo[j][i] = 0.0h;
+        }
+    }
+
+    // zero out shared memory SH
+    for (short j = 0; j < Q; ++j) {
+        for (short i = tiisg; i < SH/4; i += NW) {
+            ss4[j*T4 + i] = 0.0h;
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    {
+        half S[Q] = { [0 ... Q-1] = 0.0h };
+        half M[Q] = { [0 ... Q-1] = -INFINITY };
+
+        // assume K and V are same shape
+        const short ne22 = ne12;
+        const short ne23 = ne13;
+
+        const uint nb21 = nb11;
+        const uint nb22 = nb12;
+        const uint nb23 = nb13;
+
+        // broadcast
+        const short rk2 = ne02/ne12;
+        const short rk3 = ne03/ne13;
+
+        const short rv2 = ne02/ne22;
+        const short rv3 = ne03/ne23;
+
+        // k indices
+        const short ik2 = iq2 / rk2;
+        const short ik3 = iq3 / rk3;
+
+        // v indices
+        const short iv2 = iq2 / rv2;
+        const short iv3 = iq3 / rv3;
+
+        // load the queries from shared memory into local memory
+        half4 mq[Q][D4];
+
+        for (short j = 0; j < Q; ++j) {
+            for (short i = tiisg; i < D4; i += NW) {
+                //simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T);
+                mq[j][i] = sq4[j*T4 + i];
+            }
+        }
+
+        // pointer to the mask
+        device const half * mp = (device const half *) (mask + iq1*nb31);
+
+        // prepare diagonal scale matrix
+        //simdgroup_half8x8 mscale(scale);
+        half mscale(scale);
+
+        // loop over the KV cache
+        // each simdgroup handles blocks of Q rows and C columns
+        for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+            const int ic = ic0 + C*sgitg;
+            if (ic >= ne11) {
+                break;
+            }
+            // Q*K^T
+            {
+                for (short cc = 0; cc < C; ++cc) {
+                    half mqk[Q];
+                    for (short j = 0; j < Q; ++j) {
+                        mqk[j] = 0.0h;
+                    }
+
+                    //device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
+                    device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+                    for (short i = tiisg; i < D4; i += NW) {
+                        //simdgroup_half8x8 mk;
+                        half4 mk;
+                        //simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
+                        mk = pk4[i];
+
+                        for (short j = 0; j < Q; ++j) {
+                            //simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]);
+                            mqk[j] += dot(mq[j][i], mk);
+                        }
+                    }
+
+                    // reduce the results from the threads in the simdgroup
+                    simdgroup_barrier(mem_flags::mem_none);
+
+                    for (short i = NW/2; i > 0; i /= 2) {
+                        if (tiisg < i) {
+                            for (short j = 0; j < Q; ++j) {
+                                mqk[j] += simd_shuffle_down(mqk[j], i);
+                            }
+                        }
+
+                        simdgroup_barrier(mem_flags::mem_none);
+                    }
+
+                    // mqk = mqk*scale + mask
+                    if (tiisg == 0) {
+                        for (short j = 0; j < Q; ++j) {
+                            //simdgroup_half8x8 mm;
+                            //simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false);
+                            //simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm);
+                            //simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false);
+                            half mm = mp[j*(nb31/sizeof(half)) + ic + cc];
+                            mqk[j] = mqk[j]*mscale + mm;
+
+                            ss[j*T + cc] = mqk[j];
+                        }
+                    }
+                }
+            }
+
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+            // used to detect blocks full of -INF
+            half smax = -INFINITY;
+
+            // online softmax
+            if (C == 32) {
+                half ms[Q];
+
+                for (short j = 0; j < Q; ++j) {
+                    const short p = tiisg;
+
+                    const half m = M[j];
+                    const half s = ss[j*T + p];
+
+                    smax = simd_max(max(smax, s));
+                    M[j] = simd_max(max(M[j], s));
+
+                    ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);
+                    const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
+
+                    S[j] = S[j]*ms[j] + simd_sum(vs);
+
+                    // the P matrix from the paper (Q rows, C columns)
+                    ss[j*T + p] = vs;
+                }
+
+                // create a QxQ diagonal matrix for rescaling the output
+                if (tiisg < Q) {
+                    ss[tiisg*T + C + tiisg] = ms[tiisg];
+                }
+            } else {
+                half ms[Q];
+
+                for (short j = 0; j < Q; ++j) {
+                    const half m = M[j];
+
+                    for (short p = tiisg; p < C; p += NW) {
+                        const half s = ss[j*T + p];
+
+                        smax = max(smax, s);
+                        M[j] = max(M[j], s);
+                    }
+
+                    smax = simd_max(smax);
+                    M[j] = simd_max(M[j]);
+
+                    ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);
+
+                    // local sum
+                    half ls = 0.0h;
+
+                    for (short p = tiisg; p < C; p += NW) {
+                        const half s = ss[j*T + p];
+
+                        const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
+
+                        ls += vs;
+
+                        // the P matrix from the paper (Q rows, C columns)
+                        ss[j*T + p] = vs;
+                    }
+
+                    S[j] = S[j]*ms[j] + simd_sum(ls);
+                }
+
+                // create a QxQ diagonal matrix for rescaling the output
+                if (tiisg < Q) {
+                    ss[tiisg*T + C + tiisg] = ms[tiisg];
+                }
+            }
+
+            // skip -INF blocks
+            if (smax == -INFINITY) {
+                continue;
+            }
+
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+
+            // O = diag(ms)*O
+            for (short j = 0; j < Q; ++j) {
+                //simdgroup_half8x8 mm;
+                //simdgroup_load(mm, ss + 8*j*T + C + 8*j, T, 0, false);
+                half mm(ss[j*T + C + j]);
+
+                for (short i = tiisg; i < D4; i += NW) {
+                    //simdgroup_multiply(lo[j][i], mm, lo[j][i]);
+                    lo[j][i] = lo[j][i]*mm;
+                }
+            }
+
+            // O = O + (Q*K^T)*V
+            {
+                for (short cc = 0; cc < C; ++cc) {
+                    //device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
+                    device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+                    //for (short i = 0; i < D8; ++i) {
+                    //    simdgroup_half8x8 mk;
+                    //    simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
+
+                    //    for (short j = 0; j < Q8; ++j) {
+                    //        simdgroup_half8x8 mv;
+                    //        simdgroup_load(mv, ss + 8*j*T + 8*cc, T, 0, false);
+
+                    //        simdgroup_multiply_accumulate(lo[j][i], mv, mk, lo[j][i]);
+                    //    }
+                    //}
+
+                    for (short i = tiisg; i < D4; i += NW) {
+                        half4 mk = pv4[i];
+
+                        for (short j = 0; j < Q; ++j) {
+                            lo[j][i] += mk*ss[j*T + cc];
+                        }
+                    }
+                }
+            }
+        }
+
+        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+        for (short j = 0; j < Q; ++j) {
+            if (tiisg == 0) {
+                ss[j*T + 0] = S[j];
+                ss[j*T + 1] = M[j];
+            }
+        }
+    }
+
+    // reduce the warps sequentially
+    //for (short sg = 1; sg < nsg; ++sg) {
+    //    half S = { 0.0h };
+    //    half M = { -INFINITY };
+
+    //    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    //    // each simdgroup stores its output to shared memory, reusing sq
+    //    if (sgitg == sg) {
+    //        for (short j = 0; j < Q8; ++j) {
+    //            for (short i = 0; i < D8; ++i) {
+    //                simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false);
+    //            }
+    //        }
+    //    }
+
+    //    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    //    // the first simdgroup accumulates the results from the other simdgroups
+    //    if (sgitg == 0) {
+    //        for (short j = 0; j < Q; ++j) {
+    //            const half S0 = ss[j*T + 0];
+    //            const half S1 = ss[j*T + sg*SH + 0];
+
+    //            const half M0 = ss[j*T + 1];
+    //            const half M1 = ss[j*T + sg*SH + 1];
+
+    //            M = max(M0, M1);
+
+    //            const half ms0 = M0 == -INFINITY ? 0.0h : exp(M0 - M);
+    //            const half ms1 = M1 == -INFINITY ? 0.0h : exp(M1 - M);
+
+    //            S = S0*ms0 + S1*ms1;
+
+    //            if (tiisg == 0) {
+    //                ss[j*T + 0] = S;
+    //                ss[j*T + 1] = M;
+
+    //                ss[j*T + C + j        ] = ms0;
+    //                ss[j*T + C + j + sg*SH] = ms1;
+    //            }
+    //        }
+
+    //        // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+    //        for (short j = 0; j < Q8; ++j) {
+    //            simdgroup_half8x8 t;
+    //            simdgroup_half8x8 ms0;
+    //            simdgroup_half8x8 ms1;
+
+    //            simdgroup_load(ms0, ss + 8*j*T + C + 8*j,         T, 0, false);
+    //            simdgroup_load(ms1, ss + 8*j*T + C + 8*j + sg*SH, T, 0, false);
+
+    //            for (short i = 0; i < D8; ++i) {
+    //                simdgroup_load    (t, sq + 8*j*T + i*8, T, 0, false);
+    //                simdgroup_multiply(t, ms1, t);
+
+    //                simdgroup_multiply_accumulate(lo[j][i], ms0, lo[j][i], t);
+    //            }
+    //        }
+    //    }
+    //}
+
+    // store result to shared memory (reuse sq)
+    if (sgitg == 0) {
+        for (short j = 0; j < Q; ++j) {
+            for (short i = tiisg; i < D4; i += NW) {
+                //simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false);
+                sq4[j*T4 + i] = lo[j][i];
+            }
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    device float4 * dst4 = (device float4 *) dst;
+
+    // final rescale with 1/S and store to global memory
+    if (sgitg == 0) {
+        for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
+            const half S = ss[j*T + 0];
+
+            for (short i = tiisg; i < D4; i += NW) {
+                dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
+            }
+        }
+    }
+}
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<64, 1, 32>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<80, 1, 32>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<96, 1, 32>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<112, 1, 32>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 1, 32>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 1, 32>;

 kernel void kernel_cpy_f16_f16(
         device const half * src0,
         device       half * dst,
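
Two closing notes on the vec kernel. First, in the Q*K^T block each lane accumulates a partial half4 dot product in registers, and the partials are folded with a shuffle-down tree so that lane 0 ends up with the full q.k for a column. A minimal self-contained sketch of that pattern (simd_reduce_add is an illustrative name, not something from this commit):

#include <metal_stdlib>
using namespace metal;

// Illustrative only: sum a per-lane value across a 32-wide simdgroup.
// After five shuffle-down rounds (offsets 16, 8, 4, 2, 1) lane 0 holds
// the total of all 32 lanes, matching the loop in the Q*K^T step above.
static inline half simd_reduce_add(half x) {
    for (ushort offset = 16; offset > 0; offset /= 2) {
        x += simd_shuffle_down(x, offset);
    }
    return x; // meaningful in lane 0 only
}

The built-in simd_sum() used a few lines later in the softmax step presumably does the same job; the explicit loop mirrors the structure of the disabled simdgroup-matrix path. Second, the cross-simdgroup combine (the large commented-out block that merges S, M and O across warps) is still written with the half8x8 primitives of the matrix kernel, which is consistent with the host side dispatching this kernel with nsg = 1 for now; together with the commit title, that marks this path as work in progress.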