From 9bd5ae09aef581b6f12f98e9ca46455d9cc3e244 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 6 Nov 2024 22:52:33 +0200 Subject: [PATCH] wip 3 --- ggml/src/ggml-metal.metal | 148 ++++++++++++++++++++++++-------------- 1 file changed, 93 insertions(+), 55 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 80e5dcc4e..cf36eaab5 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -3161,11 +3161,11 @@ kernel void kernel_flash_attn_ext( // the first simdgroup accumulates the results from the other simdgroups if (sgitg == 0) { for (short j = 0; j < Q; ++j) { - const float S0 = (s_t) ss[j*TS + 0]; - const float S1 = (s_t) ss[j*TS + sg*SH + 0]; + const float S0 = ss[j*TS + 0]; + const float S1 = ss[j*TS + sg*SH + 0]; - const float M0 = (s_t) ss[j*TS + 1]; - const float M1 = (s_t) ss[j*TS + sg*SH + 1]; + const float M0 = ss[j*TS + 1]; + const float M1 = ss[j*TS + sg*SH + 1]; M = max(M0, M1); @@ -3234,7 +3234,7 @@ kernel void kernel_flash_attn_ext( half, half4, simdgroup_half8x8, \ half, half4x4, simdgroup_half8x8, \ half, half4x4, simdgroup_half8x8, \ - half, simdgroup_half8x8, \ + half, simdgroup_half8x8, \ half, half4, simdgroup_half8x8 #else #define S_T float @@ -3243,10 +3243,10 @@ kernel void kernel_flash_attn_ext( #define S8x8_T simdgroup_float8x8 #define FA_TYPES \ - half, half4, simdgroup_half8x8, \ - half, half4x4, simdgroup_half8x8, \ - half, half4x4, simdgroup_half8x8, \ - float, simdgroup_float8x8, \ + half, half4, simdgroup_half8x8, \ + half, half4x4, simdgroup_half8x8, \ + half, half4x4, simdgroup_half8x8, \ + float, simdgroup_float8x8, \ half, half4, simdgroup_half8x8 #endif @@ -3297,11 +3297,28 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #undef FA_TYPES +#undef S8x8_T +#undef S4x4_T +#undef S4_T +#undef S_T -// NOTE: can use half instead of float precision for some extra perf -// however, by default use F32 since the op should be mostly memory bandwidth bound -// D - head size, Q - queries per threadgroup, C - cache items per threadgroup -template +template< + typename q4_t, + typename q4x4_t, + typename k4x4_t, + typename v4x4_t, + typename s_t, // attention accumulation types + typename s4_t, + typename s4x4_t, + typename o4x4_t, + typename block_q, + short nl_k, + void (*deq_k)(device const block_q *, short, thread k4x4_t &), + short nl_v, + void (*deq_v)(device const block_q *, short, thread v4x4_t &), + short D, // head size + short Q = 1, // queries per threadgroup + short C = 32> // cache items per threadgroup kernel void kernel_flash_attn_ext_vec( device const char * q, device const char * k, @@ -3350,37 +3367,39 @@ kernel void kernel_flash_attn_ext_vec( const short NW4 = NW/4; const short SH = C; // shared memory per simdgroup in (half) - const short T = D + 2*nsg*SH; // shared memory size per query in (half) + const short SF = sizeof(s_t)/sizeof(half); - //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data - threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 - threadgroup half4x4 * sq44 = (threadgroup half4x4 *) (shared + 0*D); // same as above but in half4x4 - threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention - threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4 - threadgroup float4x4 * sr44 = (threadgroup float4x4 *) (shared + 2*sgitg*D + Q*T); // scratch buffer for the results + const short T = D + SF*nsg*SH; // shared memory size per query in (half) + + //threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in half4 + threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0*D); // same as above but in half4x4 + threadgroup s_t * ss = (threadgroup s_t *) (shared + SF*sgitg*SH + 1*D); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + SF*sgitg*SH + 1*D); // same as above but in half4 + threadgroup s4x4_t * sr4x4 = (threadgroup s4x4_t *) (shared + SF*sgitg*D + Q*T); // scratch buffer for the results // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - float4x4 lo[D16/NW4]; + o4x4_t lo[D16/NW4]; // load heads from Q to shared memory device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); for (short i = tiisg; i < D4; i += NW) { if (iq1 < ne01) { - sq4[i] = (half4) q4[i]; + sq4[i] = (q4_t) q4[i]; } else { - sq4[i] = 0.0h; + sq4[i] = (q4_t) (float4) 0.0f; } } // zero out lo for (short i = 0; i < D16/NW4; i += NW4) { - lo[i] = float4x4(0.0f); + lo[i] = (o4x4_t) 0.0f; } // zero out shared memory SH for (short i = tiisg; i < SH/4; i += NW) { - ss4[i] = 0.0h; + ss4[i] = (s4_t) (float4) 0.0f; } threadgroup_barrier(mem_flags::mem_threadgroup); @@ -3412,10 +3431,10 @@ kernel void kernel_flash_attn_ext_vec( const short iv3 = iq3/rv3; // load the queries from shared memory into local memory - float4x4 mq[D16/NW4]; + k4x4_t mq[D16/NW4]; for (short ii = 0; ii < D16; ii += NW4) { - mq[ii/NW4] = (float4x4) sq44[ii + tx]; + mq[ii/NW4] = (k4x4_t) sq4x4[ii + tx]; } // pointer to the mask @@ -3445,7 +3464,7 @@ kernel void kernel_flash_attn_ext_vec( { // each simdgroup processes 1 query and 4 keys for (short cc = 0; cc < C/4; ++cc) { - float mqk = 0.0; + s_t mqk = 0.0; device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + ty)*nb11 + ik2*nb12 + ik3*nb13)); @@ -3453,8 +3472,8 @@ kernel void kernel_flash_attn_ext_vec( for (short ii = 0; ii < D16; ii += NW4) { const short i = ii + tx; - float4x4 mk; - dequantize_func(pk + i/nl, i%nl, mk); + k4x4_t mk; + deq_k(pk + i/nl_k, i%nl_k, mk); mqk += dot(mq[ii/NW4][0], mk[0]) + @@ -3482,7 +3501,7 @@ kernel void kernel_flash_attn_ext_vec( mqk = logit_softcap*precise::tanh(mqk); } - mqk += (mask != q) ? ((float) mp[ic + 4*cc + ty])*slope : (float) 0.0f; + mqk += (s_t) ((mask != q) ? ((float) mp[ic + 4*cc + ty])*slope : (float) 0.0f); ss[4*cc + ty] = mqk; } @@ -3523,16 +3542,16 @@ kernel void kernel_flash_attn_ext_vec( for (short cc = 0; cc < C/4; ++cc) { device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + ty)*nb21 + iv2*nb22 + iv3*nb23)); - const float4x4 lss(ss[4*cc + ty]); + const s4x4_t ms(ss[4*cc + ty]); #pragma unroll for (short ii = 0; ii < D16; ii += NW4) { const short i = ii + tx; - float4x4 mv; - dequantize_func(pv4 + i/nl, i%nl, mv); + v4x4_t mv; + deq_v(pv4 + i/nl_v, i%nl_v, mv); - lo[ii/NW4] += mv*lss; + lo[ii/NW4] += mv*ms; } } } @@ -3540,8 +3559,8 @@ kernel void kernel_flash_attn_ext_vec( // these are needed for reducing the results from the simdgroups (reuse the ss buffer) if (tiisg == 0) { - ss[0] = S; - ss[1] = M; + ss[0] = (s_t) S; + ss[1] = (s_t) M; } } @@ -3570,7 +3589,7 @@ kernel void kernel_flash_attn_ext_vec( // store results to shared memory for (short i = tiisg; i < D16; i += NW4) { - sr44[i] = lo[i/NW4]; + sr4x4[i] = lo[i/NW4]; } threadgroup_barrier(mem_flags::mem_threadgroup); @@ -3592,13 +3611,13 @@ kernel void kernel_flash_attn_ext_vec( const float S = S0*ms0 + S1*ms1; if (tiisg == 0) { - ss[0] = S; - ss[1] = M; + ss[0] = (s_t) S; + ss[1] = (s_t) M; } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 for (short i = tiisg; i < D16; i += NW) { - sr44[i] = sr44[i]*ms0 + sr44[i + r*D16]*ms1; + sr4x4[i] = sr4x4[i]*ms0 + sr4x4[i + r*D16]*ms1; } } @@ -3612,26 +3631,45 @@ kernel void kernel_flash_attn_ext_vec( const float S = ss[0]; for (short i = tiisg; i < D16; i += NW) { - dst44[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = sr44[i]/S; + dst44[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = (float4x4) sr4x4[i]/S; } } } -typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; +// NOTE: can use half instead of float precision for some extra perf +// however, by default use F32 since the op should be mostly memory bandwidth bound -template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#define S_T float +#define S4_T float4 +#define S4x4_T float4x4 -template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#define FA_TYPES \ + half4, half4x4, \ + float4x4, \ + float4x4, \ + float, float4, float4x4, \ + float4x4 + +typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +#undef FA_TYPES +#undef S4x4_T +#undef S4_T +#undef S_T template kernel void kernel_cpy(