diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 76bfedc66..f314ad20a 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -3145,12 +3145,16 @@ static void ggml_metal_encode_node( GGML_ASSERT(nqptg % 8 == 0); GGML_ASSERT(ncpsg % 32 == 0); + // 16*32*nsgmax + // the shared memory needed for the simdgroups to load the KV cache + // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG + // +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16)) + int64_t nsgmax = 2; while (true) { - // 16*32*nsgmax - the shared memory needed for the simdgroups to load the KV cache - // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG - const size_t smem = (nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg)) + 16*32*nsgmax)*(sizeof(float)/2); + const size_t smem = FATTN_SMEM(nsgmax); if (smem > device.maxThreadgroupMemoryLength) { break; } @@ -3161,13 +3165,12 @@ static void ggml_metal_encode_node( // simdgroups per threadgroup (a.k.a. warps) const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; - const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 16*32*nsg)*(sizeof(float)/2); + const size_t smem = FATTN_SMEM(nsg); //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg); GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); - - [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; - + [encoder setThreadgroupMemoryLength:smem atIndex:0]; +#undef FATTN_SMEM [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; } else { // half4x4 kernel @@ -3178,8 +3181,28 @@ static void ggml_metal_encode_node( GGML_ASSERT(nqptg % 1 == 0); GGML_ASSERT(ncpsg % 32 == 0); + // ne00 + 2*ncpsg*(nsg) + // for each query, we load it as f16 in shared memory (ne00) + // and store the attention scores (nqptg x ncpsg) as f32 + // + // 2*ne00*(nsg) + // each simdgroup has a full f32 head vector in shared mem to accumulate results + // +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16)) + + int64_t nsgmax = 2; + + while (true) { + const size_t smem = FATTN_SMEM(nsgmax); + if (smem > device.maxThreadgroupMemoryLength) { + break; + } + nsgmax *= 2; + } + nsgmax /= 2; + // simdgroups per threadgroup (a.k.a. warps) - const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); + const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))); int64_t nsg = 1; while (nsg <= nsgt) { @@ -3187,12 +3210,12 @@ static void ggml_metal_encode_node( } nsg /= 2; - const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 2*nsg*ne00)*(sizeof(float)/2); + const size_t smem = FATTN_SMEM(nsg); - //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength); + //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg); GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; - + [encoder setThreadgroupMemoryLength:smem atIndex:0]; +#undef FATTN_SMEM [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; } } break; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 1db0f4b86..224005d0a 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2876,6 +2876,7 @@ kernel void kernel_flash_attn_ext( for (short cc = 0; cc < C/8; ++cc) { simdgroup_float8x8 mqk = make_filled_simdgroup_matrix(0.h); + // this is compile-time check, so it does not have runtime overhead if (is_same::value) { // we can read directly from global memory device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); @@ -2891,6 +2892,7 @@ kernel void kernel_flash_attn_ext( device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13)); if (D16%4 == 0) { + // the head is evenly divisible by 4*16 = 64, so no need for bound checks half4x4 tmp; dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp); skv4[4*ty + tx] = tmp; @@ -3009,6 +3011,7 @@ kernel void kernel_flash_attn_ext( device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23)); if (D16%4 == 0) { + // no need for bound checks half4x4 tmp; dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp); skv4[4*ty + tx] = tmp; @@ -3233,14 +3236,14 @@ kernel void kernel_flash_attn_ext_vec( const short D16 = D/16; const short NW = N_SIMDWIDTH; const short NW4 = NW/4; - const short SH = (C + Q); // shared memory per simdgroup in (half) + const short SH = C; // shared memory per simdgroup in (half) const short T = D + 2*nsg*SH; // shared memory size per query in (half) //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 threadgroup half4x4 * sq44 = (threadgroup half4x4 *) (shared + 0*D); // same as above but in half4x4 - threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4 threadgroup float4x4 * sr44 = (threadgroup float4x4 *) (shared + 2*sgitg*D + Q*T); // scratch buffer for the results