From b57af0c9dd7e92c2767b165a035fb434c5465ca3 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 5 Apr 2024 17:47:01 +0300
Subject: [PATCH] metal : initial FA vec kernel

---
 ggml-metal.m     |   2 +-
 ggml-metal.metal | 205 +++++++++++++++++++----------------------------
 2 files changed, 82 insertions(+), 125 deletions(-)

diff --git a/ggml-metal.m b/ggml-metal.m
index 204ccea1b..2680cf21c 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -2603,7 +2603,7 @@ static enum ggml_status ggml_metal_graph_compute(

                         // simdgroups per threadgroup (a.k.a. warps)
                         // for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
-                        const int64_t nsgt = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
+                        const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));

                         int64_t nsg = 1;
                         while (nsg <= nsgt) {
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 63a5a175d..ca0f57a96 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -2459,7 +2459,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f

 #define HALF_MAX_HALF half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.

-template<int64_t D, int64_t Q, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
+template<int64_t D, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
 kernel void kernel_flash_attn_ext_vec_f16(
         device const char * q,
         device const char * k,
@@ -2499,12 +2499,12 @@ kernel void kernel_flash_attn_ext_vec_f16(

     const short iq3 = tgpig[2];
     const short iq2 = tgpig[1];
-    const short iq1 = tgpig[0]*Q;
+    const short iq1 = tgpig[0];

     const short D4 = D/4;
     const short D8 = D/8;
     const short NW = N_SIMDWIDTH;
-    const short SH = (C + Q); // shared memory per simdgroup in (half)
+    const short SH = (C + 1); // shared memory per simdgroup in (half)

     const short T  = D + nsg*SH; // shared memory size per query in (half)
     const short T4 = T/4;        // shared memory size per query in (half4)
@@ -2513,43 +2513,37 @@ kernel void kernel_flash_attn_ext_vec_f16(
     threadgroup half4 * sq4 = (threadgroup half4 *) (shared +            0*D); // same as above but in half4
     threadgroup half  * ss  = (threadgroup half  *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
     threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*SH + 1*D); // same as above but in half4
-    threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D  + Q*T); // scratch buffer for the results
+    threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D  + 1*T); // scratch buffer for the results

     // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
-    half4 lo[Q][D4/NW];
+    half4 lo[D4/NW];

     // load heads from Q to shared memory
-    for (short j = sgitg; j < Q; j += nsg) {
-        device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
+    device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));

-        for (short i = tiisg; i < D4; i += NW) {
-            if (iq1 + j < ne01) {
-                sq4[j*T4 + i] = (half4) q4[i];
-            } else {
-                sq4[j*T4 + i] = 0.0h;
-            }
+    for (short i = tiisg; i < D4; i += NW) {
+        if (iq1 < ne01) {
+            sq4[i] = (half4) q4[i];
+        } else {
+            sq4[i] = 0.0h;
         }
     }

     // zero out lo
-    for (short j = 0; j < Q; ++j) {
-        for (short i = tiisg; i < D4; i += NW) {
-            lo[j][i/NW] = 0.0h;
-        }
+    for (short i = tiisg; i < D4; i += NW) {
+        lo[i/NW] = 0.0h;
     }

     // zero out shared memory SH
-    for (short j = 0; j < Q; ++j) {
-        for (short i = tiisg; i < SH/4; i += NW) {
-            ss4[j*T4 + i] = 0.0h;
-        }
+    for (short i = tiisg; i < SH/4; i += NW) {
+        ss4[i] = 0.0h;
     }

     threadgroup_barrier(mem_flags::mem_threadgroup);

     {
-        half S[Q] = { [0 ... Q-1] = 0.0h };
-        half M[Q] = { [0 ... Q-1] = -HALF_MAX_HALF };
+        half S = { 0.0h };
+        half M = { -HALF_MAX_HALF };

         // assume K and V are same shape
         const short ne22 = ne12;
@@ -2575,21 +2569,16 @@ kernel void kernel_flash_attn_ext_vec_f16(
         const short iv3 = iq3 / rv3;

         // load the queries from shared memory into local memory
-        half4 mq[Q][D4];
+        half4 mq[D4];

-        for (short j = 0; j < Q; ++j) {
-            for (short ii = 0; ii < D4; ii += NW) {
-                short i = ii + tiisg;
-                mq[j][i] = sq4[j*T4 + i];
-            }
+        for (short ii = 0; ii < D4; ii += NW) {
+            short i = ii + tiisg;
+            mq[i] = sq4[i];
         }

         // pointer to the mask
         device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);

-        // prepare diagonal scale matrix
-        half mscale(scale);
-
         // loop over the KV cache
         // each simdgroup handles blocks of Q rows and C columns
         for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
@@ -2600,8 +2589,9 @@ kernel void kernel_flash_attn_ext_vec_f16(

             // Q*K^T
             {
+#pragma unroll
                 for (short cc = 0; cc < C/4; ++cc) {
-                    half4 mqk[Q] = { [0 ... Q-1] = 0.0h };
+                    half4 mqk = { 0.0h };

                     device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));

@@ -2615,100 +2605,81 @@ kernel void kernel_flash_attn_ext_vec_f16(
                         mk[2] = pk4[i + 2*(nb11/8)];
                         mk[3] = pk4[i + 3*(nb11/8)];

-                        for (short j = 0; j < Q; ++j) {
-                            mqk[j] += mq[j][i] * mk;
-                        }
+                        mqk += mq[i] * mk;
                     }

                     // reduce the results from the threads in the simdgroup
-                    for (short j = 0; j < Q; ++j) {
-                        mqk[j] += simd_shuffle_down(mqk[j], 16);
-                        mqk[j] += simd_shuffle_down(mqk[j],  8);
-                        mqk[j] += simd_shuffle_down(mqk[j],  4);
-                        mqk[j] += simd_shuffle_down(mqk[j],  2);
-                        mqk[j] += simd_shuffle_down(mqk[j],  1);
-                    }
+                    mqk += simd_shuffle_down(mqk, 16);
+                    mqk += simd_shuffle_down(mqk,  8);
+                    mqk += simd_shuffle_down(mqk,  4);
+                    mqk += simd_shuffle_down(mqk,  2);
+                    mqk += simd_shuffle_down(mqk,  1);

                     // mqk = mqk*scale + mask
                     if (tiisg == 0) {
-                        for (short j = 0; j < Q; ++j) {
-                            half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc];
-                            mqk[j] = mqk[j]*mscale + mm;
+                        half4 mm = mp4[ic/4 + cc];
+                        mqk = mqk*scale + mm;

-                            ss4[j*T4 + cc] = mqk[j];
-                        }
+                        ss4[cc] = mqk;
                     }
                 }
             }

             // online softmax
-            half ms[Q];
-
-            for (short j = 0; j < Q; ++j) {
+            {
                 const short p = tiisg;

-                const half m = M[j];
-                const half s = ss[j*T + p];
+                const half m = M;
+                const half s = ss[p];

-                M[j] = simd_max(max(M[j], s));
+                M = simd_max(max(M, s));

-                ms[j] = exp(m - M[j]);
-                const half vs = exp(s - M[j]);
+                const half ms = exp(m - M);
+                const half vs = exp(s - M);

-                S[j] = S[j]*ms[j] + simd_sum(vs);
+                S = S*ms + simd_sum(vs);

                 // the P matrix from the paper (Q rows, C columns)
-                ss[j*T + p] = vs;
-            }
-
-            // create a QxQ diagonal matrix for rescaling the output
-            if (tiisg < Q) {
-                ss[tiisg*T + C + tiisg] = ms[tiisg];
-            }
-
-            // O = diag(ms)*O
-            for (short j = 0; j < Q; ++j) {
-                half mm(ss[j*T + C + j]);
+                ss[p] = vs;

+                // O = diag(ms)*O
 #pragma unroll
                 for (short ii = 0; ii < D4; ii += NW) {
                     const short i = ii + tiisg;
-                    lo[j][i/NW] = lo[j][i/NW]*mm;
+                    lo[i/NW] *= ms;
                 }
             }

             // O = O + (Q*K^T)*V
             {
 #pragma unroll
-                for (short cc = 0; cc < C; ++cc) {
-                    device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23));
+                for (short cc = 0; cc < C/4; ++cc) {
+                    device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23));

 #pragma unroll
                     for (short ii = 0; ii < D4; ii += NW) {
-                        short i = ii + tiisg;
-                        for (short j = 0; j < Q; ++j) {
-                            lo[j][i/NW] += pv4[i]*ss[j*T + cc];
-                        }
+                        const short i = ii + tiisg;
+                        lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
+                        lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
+                        lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
+                        lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
                     }
                 }
             }
+
         }

         // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
-        for (short j = 0; j < Q; ++j) {
-            if (tiisg == 0) {
-                ss[j*T + 0] = S[j];
-                ss[j*T + 1] = M[j];
-            }
+        if (tiisg == 0) {
+            ss[0] = S;
+            ss[1] = M;
         }
     }

     // store results to shared memory
-    for (short j = 0; j < Q; ++j) {
-        for (short ii = 0; ii < D4; ii += NW) {
-            short i = ii + tiisg;
-            sr4[i] = lo[j][ii/NW];
-        }
+    for (short ii = 0; ii < D4; ii += NW) {
+        short i = ii + tiisg;
+        sr4[i] = lo[ii/NW];
     }

     threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -2716,41 +2687,28 @@ kernel void kernel_flash_attn_ext_vec_f16(
     // parallel reduce
     for (short r = nsg/2; r > 0; r >>= 1) {
         if (sgitg < r) {
+            const half S0 = ss[       0];
+            const half S1 = ss[r*SH + 0];
+
+            const half M0 = ss[       1];
+            const half M1 = ss[r*SH + 1];
+
+            const half M = max(M0, M1);
+
+            const half ms0 = exp(M0 - M);
+            const half ms1 = exp(M1 - M);
+
+            const half S = S0*ms0 + S1*ms1;
+
             if (tiisg == 0) {
-                for (short j = 0; j < Q; ++j) {
-                    const half S0 = ss[j*T + 0];
-                    const half S1 = ss[j*T + r*SH + 0];
-
-                    const half M0 = ss[j*T + 1];
-                    const half M1 = ss[j*T + r*SH + 1];
-
-                    const half M = max(M0, M1);
-
-                    const half ms0 = exp(M0 - M);
-                    const half ms1 = exp(M1 - M);
-
-                    const half S = S0*ms0 + S1*ms1;
-
-                    ss[j*T + 0] = S;
-                    ss[j*T + 1] = M;
-
-                    ss[j*T + C + j       ] = ms0;
-                    ss[j*T + C + j + r*SH] = ms1;
-                }
+                ss[0] = S;
+                ss[1] = M;
             }
-        }

-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        if (sgitg < r) {
-            for (short j = 0; j < Q; ++j) {
-                const half ms0 = ss[j*T + C + j];
-                const half ms1 = ss[j*T + C + j + r*SH];
-
-                // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
-                for (short i = tiisg; i < D4; i += NW) {
-                    sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
-                }
+            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+            for (short ii = 0; ii < D4; ii += NW) {
+                short i = ii + tiisg;
+                sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
             }
         }

@@ -2761,18 +2719,17 @@ kernel void kernel_flash_attn_ext_vec_f16(

     // final rescale with 1/S and store to global memory
     if (sgitg == 0) {
-        for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
-            const half S = ss[j*T + 0];
+        const half S = ss[0];

-            for (short i = tiisg; i < D4; i += NW) {
-                dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sr4[i]/S;
-            }
+        for (short ii = 0; ii < D4; ii += NW) {
+            short i = ii + tiisg;
+            dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
         }
     }
 }

-template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 1, 32>;
-template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 1, 32>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 32>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 32>;

 kernel void kernel_cpy_f16_f16(
         device const half * src0,