From e51778de5e67d867cc39802536cb995b245b73e2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 5 Apr 2024 16:10:15 +0300 Subject: [PATCH] metal : switch to parallel reduce --- ggml-metal.m | 16 +++- ggml-metal.metal | 196 +++++++++++++++++++++++++---------------------- 2 files changed, 119 insertions(+), 93 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 0f405b112..6106bc7e3 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2615,13 +2615,23 @@ static enum ggml_status ggml_metal_graph_compute( // simdgroups per threadgroup (a.k.a. warps) // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsg = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); + //const int64_t nsg = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); + const int64_t nsg = 8; - const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); + // require power of 2 + //{ + // int64_t nsgm = 1; + // while (nsgm < nsg) { + // nsgm *= 2; + // } + // GGML_ASSERT(nsg == nsgm); + //} + + const size_t smem = (nqptg*(ne00 + nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:smem atIndex:0]; + [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; } diff --git a/ggml-metal.metal b/ggml-metal.metal index 533b9fef6..404bd16e0 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2457,6 +2457,8 @@ template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>; +#define HALF_MAX_HALF half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. + template // head size, queries per threadgroup, cache items per threadgroup kernel void kernel_flash_attn_ext_vec_f16( device const char * q, @@ -2500,6 +2502,7 @@ kernel void kernel_flash_attn_ext_vec_f16( const short iq1 = tgpig[0]*Q; const short D4 = D/4; + const short D8 = D/8; const short NW = N_SIMDWIDTH; const short SH = (C + Q); // shared memory per simdgroup in (half) @@ -2510,6 +2513,7 @@ kernel void kernel_flash_attn_ext_vec_f16( threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*SH + 1*D); // same as above but in half4 + threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + Q*T); // scratch buffer for the results // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) half4 lo[Q][D4]; @@ -2545,7 +2549,7 @@ kernel void kernel_flash_attn_ext_vec_f16( { half S[Q] = { [0 ... Q-1] = 0.0h }; - half M[Q] = { [0 ... Q-1] = -INFINITY }; + half M[Q] = { [0 ... Q-1] = -HALF_MAX_HALF }; // assume K and V are same shape const short ne22 = ne12; @@ -2571,21 +2575,21 @@ kernel void kernel_flash_attn_ext_vec_f16( const short iv3 = iq3 / rv3; // load the queries from shared memory into local memory - half4 mq[Q][D4]; + simdgroup_half8x8 mq[Q][D8]; for (short j = 0; j < Q; ++j) { - for (short i = tiisg; i < D4; i += NW) { - //simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T); - mq[j][i] = sq4[j*T4 + i]; + for (short i = 0; i < D8; ++i) { + simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T); } } // pointer to the mask - device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31); + //device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31); + device const half * mp = (device const half *) (mask + iq1*nb31); // prepare diagonal scale matrix - //simdgroup_half8x8 mscale(scale); - half mscale(scale); + simdgroup_half8x8 mscale(scale); + //half mscale(scale); // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns @@ -2595,55 +2599,83 @@ kernel void kernel_flash_attn_ext_vec_f16( break; } + // Q*K^T + //{ + // for (short cc = 0; cc < C/4; ++cc) { + // half4 mqk[Q]; + // for (short j = 0; j < Q; ++j) { + // mqk[j] = 0.0h; + // } + + // device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); + + // for (short i = tiisg; i < D4; i += NW) { + // half4x4 mk; + // mk[0] = pk4[i + 0*(nb11/8)]; + // mk[1] = pk4[i + 1*(nb11/8)]; + // mk[2] = pk4[i + 2*(nb11/8)]; + // mk[3] = pk4[i + 3*(nb11/8)]; + + // for (short j = 0; j < Q; ++j) { + // mqk[j] += mq[j][i] * mk; + // } + // } + + // // reduce the results from the threads in the simdgroup + // simdgroup_barrier(mem_flags::mem_none); + + // for (short i = NW/2; i > 0; i /= 2) { + // if (tiisg < i) { + // for (short j = 0; j < Q; ++j) { + // mqk[j] += simd_shuffle_down(mqk[j], i); + // } + // } + + // simdgroup_barrier(mem_flags::mem_none); + // } + + // // mqk = mqk*scale + mask + // if (tiisg == 0) { + // for (short j = 0; j < Q; ++j) { + // half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc]; + // mqk[j] = mqk[j]*mscale + mm; + + // ss4[j*T4 + cc] = mqk[j]; + // } + // } + // } + //} + // Q*K^T { - for (short cc = 0; cc < C/4; ++cc) { - half4 mqk[Q]; + for (short cc = 0; cc < C/8; ++cc) { + simdgroup_half8x8 mqk[Q]; for (short j = 0; j < Q; ++j) { - mqk[j] = 0.0h; + mqk[j] = make_filled_simdgroup_matrix(0.h); } - device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); + device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); - for (short i = tiisg; i < D4; i += NW) { - half4x4 mk; - mk[0] = pk4[i + 0*(nb11/8)]; - mk[1] = pk4[i + 1*(nb11/8)]; - mk[2] = pk4[i + 2*(nb11/8)]; - mk[3] = pk4[i + 3*(nb11/8)]; + for (short i = 0; i < D8; ++i) { + simdgroup_half8x8 mk; + simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose for (short j = 0; j < Q; ++j) { - mqk[j] += mq[j][i] * mk; + simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]); } } - // reduce the results from the threads in the simdgroup - simdgroup_barrier(mem_flags::mem_none); - - for (short i = NW/2; i > 0; i /= 2) { - if (tiisg < i) { - for (short j = 0; j < Q; ++j) { - mqk[j] += simd_shuffle_down(mqk[j], i); - } - } - - simdgroup_barrier(mem_flags::mem_none); - } - // mqk = mqk*scale + mask - if (tiisg == 0) { - for (short j = 0; j < Q; ++j) { - half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc]; - mqk[j] = mqk[j]*mscale + mm; + for (short j = 0; j < Q; ++j) { + simdgroup_half8x8 mm; + simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false); + simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm); - ss4[j*T4 + cc] = mqk[j]; - } + simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false); } } } - simdgroup_barrier(mem_flags::mem_threadgroup); - // online softmax half ms[Q]; @@ -2655,8 +2687,8 @@ kernel void kernel_flash_attn_ext_vec_f16( M[j] = simd_max(max(M[j], s)); - ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); - const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); + ms[j] = exp(m - M[j]); + const half vs = exp(s - M[j]); S[j] = S[j]*ms[j] + simd_sum(vs); @@ -2706,75 +2738,59 @@ kernel void kernel_flash_attn_ext_vec_f16( } } - // reduce the warps sequentially - for (short sg = 1; sg < nsg; ++sg) { - half S = { 0.0h }; - half M = { -INFINITY }; + threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup_barrier(mem_flags::mem_threadgroup); - - // each simdgroup stores its output to shared memory, reusing sq - if (sgitg == sg) { - for (short j = 0; j < Q; ++j) { - for (short i = tiisg; i < D4; i += NW) { - //simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); - sq4[j*T4 + i] = lo[j][i]; - } - } + // store results to shared memory + for (short j = 0; j < Q; ++j) { + for (short i = tiisg; i < D4; i += NW) { + sr4[i] = lo[j][i]; } + } - threadgroup_barrier(mem_flags::mem_threadgroup); + // parallel reduce + for (short r = nsg/2; r > 0; r >>= 1) { + if (sgitg < r) { + if (tiisg == 0) { + for (short j = 0; j < Q; ++j) { + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + r*SH + 0]; - // the first simdgroup accumulates the results from the other simdgroups - if (sgitg == 0) { - for (short j = 0; j < Q; ++j) { - const half S0 = ss[j*T + 0]; - const half S1 = ss[j*T + sg*SH + 0]; + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + r*SH + 1]; - const half M0 = ss[j*T + 1]; - const half M1 = ss[j*T + sg*SH + 1]; + const half M = max(M0, M1); - M = max(M0, M1); + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); - const half ms0 = M0 == -INFINITY ? 0.0h : exp(M0 - M); - const half ms1 = M1 == -INFINITY ? 0.0h : exp(M1 - M); + const half S = S0*ms0 + S1*ms1; - S = S0*ms0 + S1*ms1; - - if (tiisg == 0) { ss[j*T + 0] = S; ss[j*T + 1] = M; - ss[j*T + C + j ] = ms0; - ss[j*T + C + j + sg*SH] = ms1; + ss[j*T + C + j ] = ms0; + ss[j*T + C + j + r*SH] = ms1; } } + } - // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (sgitg < r) { for (short j = 0; j < Q; ++j) { - for (short i = tiisg; i < D4; i += NW) { - half4 t = sq4[j*T4 + i]; - half ms0 = ss[j*T + C + j]; - half ms1 = ss[j*T + C + j + sg*SH]; + const half ms0 = ss[j*T + C + j]; + const half ms1 = ss[j*T + C + j + r*SH]; - lo[j][i] = lo[j][i]*ms0 + t*ms1; + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (short i = tiisg; i < D4; i += NW) { + sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1; } } } - } - // store result to shared memory (reuse sq) - if (sgitg == 0) { - for (short j = 0; j < Q; ++j) { - for (short i = tiisg; i < D4; i += NW) { - //simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); - sq4[j*T4 + i] = lo[j][i]; - } - } + threadgroup_barrier(mem_flags::mem_threadgroup); } - threadgroup_barrier(mem_flags::mem_threadgroup); - device float4 * dst4 = (device float4 *) dst; // final rescale with 1/S and store to global memory @@ -2783,7 +2799,7 @@ kernel void kernel_flash_attn_ext_vec_f16( const half S = ss[j*T + 0]; for (short i = tiisg; i < D4; i += NW) { - dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sr4[i]/S; } } }