From 1a88565b4489381923aa0c9a6741badfb6766b23 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 19 Apr 2024 15:52:49 +0300
Subject: [PATCH] metal : clean-up kernel code

---
 ggml-metal.metal | 142 ++++++++++++++---------------------------------
 1 file changed, 43 insertions(+), 99 deletions(-)

diff --git a/ggml-metal.metal b/ggml-metal.metal
index 1ed5632b4..32cbef9dc 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -2121,7 +2121,7 @@ typedef void (flash_attn_ext_f16_t)(
         ushort sgitg[[simdgroup_index_in_threadgroup]]);
 
 // ref: https://arxiv.org/pdf/2307.08691.pdf
-template<int64_t D, int64_t Q, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
+template<int64_t D, int64_t Q = 8, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
 kernel void kernel_flash_attn_ext_f16(
         device const char * q,
         device const char * k,
@@ -2178,7 +2178,7 @@ kernel void kernel_flash_attn_ext_f16(
     threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
 
     // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
-    simdgroup_half8x8 lo[Q8][D8];
+    simdgroup_half8x8 lo[D8];
 
     // load heads from Q to shared memory
     for (short j = sgitg; j < Q; j += nsg) {
@@ -2194,10 +2194,8 @@ kernel void kernel_flash_attn_ext_f16(
     }
 
     // zero out lo
-    for (short j = 0; j < Q8; ++j) {
-        for (short i = 0; i < D8; ++i) {
-            lo[j][i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
-        }
+    for (short i = 0; i < D8; ++i) {
+        lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
     }
 
     // zero out shared memory SH
@@ -2229,20 +2227,18 @@ kernel void kernel_flash_attn_ext_f16(
         const short rv3 = ne03/ne23;
 
         // k indices
-        const short ik2 = iq2 / rk2;
-        const short ik3 = iq3 / rk3;
+        const short ik2 = iq2/rk2;
+        const short ik3 = iq3/rk3;
 
         // v indices
-        const short iv2 = iq2 / rv2;
-        const short iv3 = iq3 / rv3;
+        const short iv2 = iq2/rv2;
+        const short iv3 = iq3/rv3;
 
         // load the queries from shared memory into local memory
-        simdgroup_half8x8 mq[Q8][D8];
+        simdgroup_half8x8 mq[D8];
 
-        for (short j = 0; j < Q8; ++j) {
-            for (short i = 0; i < D8; ++i) {
-                simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T);
-            }
+        for (short i = 0; i < D8; ++i) {
+            simdgroup_load(mq[i], sq + i*8, T);
         }
 
         // pointer to the mask
@@ -2262,10 +2258,7 @@ kernel void kernel_flash_attn_ext_f16(
             // Q*K^T
             {
                 for (short cc = 0; cc < C/8; ++cc) {
-                    simdgroup_float8x8 mqk[Q8];
-                    for (short j = 0; j < Q8; ++j) {
-                        mqk[j] = make_filled_simdgroup_matrix<float, 8>(0.h);
-                    }
+                    simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
 
                     device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
 
@@ -2273,19 +2266,15 @@ kernel void kernel_flash_attn_ext_f16(
                         simdgroup_half8x8 mk;
                         simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
 
-                        for (short j = 0; j < Q8; ++j) {
-                            simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]);
-                        }
+                        simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
                     }
 
                     // mqk = mqk*scale + mask
-                    for (short j = 0; j < Q8; ++j) {
-                        simdgroup_half8x8 mm;
-                        simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false);
-                        simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm);
+                    simdgroup_half8x8 mm;
+                    simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
+                    simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
 
-                        simdgroup_store(mqk[j], ss + 8*j*TF + 8*cc, TF, 0, false);
-                    }
+                    simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
                 }
             }
 
@@ -2293,7 +2282,7 @@ kernel void kernel_flash_attn_ext_f16(
             float smax = -INFINITY;
 
             // online softmax
-            if (C == 32) {
+            {
                 float ms[Q];
 
                 for (short j = 0; j < Q; ++j) {
@@ -2314,45 +2303,6 @@ kernel void kernel_flash_attn_ext_f16(
                     ss[j*TF + p] = vs;
                 }
 
-                // create a QxQ diagonal matrix for rescaling the output
-                if (tiisg < Q) {
-                    ss[tiisg*TF + C + tiisg] = ms[tiisg];
-                }
-            } else {
-                float ms[Q];
-
-                for (short j = 0; j < Q; ++j) {
-                    const float m = M[j];
-
-                    for (short p = tiisg; p < C; p += NW) {
-                        const float s = ss[j*TF + p];
-
-                        smax = max(smax, s);
-                        M[j] = max(M[j], s);
-                    }
-
-                    smax = simd_max(smax);
-                    M[j] = simd_max(M[j]);
-
-                    ms[j] = exp(m - M[j]);
-
-                    // local sum
-                    float ls = 0.0h;
-
-                    for (short p = tiisg; p < C; p += NW) {
-                        const float s = ss[j*TF + p];
-
-                        const float vs = exp(s - M[j]);
-
-                        ls += vs;
-
-                        // the P matrix from the paper (Q rows, C columns)
-                        ss[j*TF + p] = vs;
-                    }
-
-                    S[j] = S[j]*ms[j] + simd_sum(ls);
-                }
-
                 // create a QxQ diagonal matrix for rescaling the output
                 if (tiisg < Q) {
                     ss[tiisg*TF + C + tiisg] = ms[tiisg];
@@ -2365,12 +2315,12 @@ kernel void kernel_flash_attn_ext_f16(
             }
 
             // O = diag(ms)*O
-            for (short j = 0; j < Q8; ++j) {
+            {
                 simdgroup_float8x8 mm;
-                simdgroup_load(mm, ss + 8*j*TF + C + 8*j, TF, 0, false);
+                simdgroup_load(mm, ss + C, TF, 0, false);
 
                 for (short i = 0; i < D8; ++i) {
-                    simdgroup_multiply(lo[j][i], mm, lo[j][i]);
+                    simdgroup_multiply(lo[i], mm, lo[i]);
                 }
             }
 
@@ -2383,12 +2333,10 @@ kernel void kernel_flash_attn_ext_f16(
                         simdgroup_half8x8 mk;
                         simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
 
-                        for (short j = 0; j < Q8; ++j) {
-                            simdgroup_float8x8 mv;
-                            simdgroup_load(mv, ss + 8*j*TF + 8*cc, TF, 0, false);
+                        simdgroup_float8x8 mv;
+                        simdgroup_load(mv, ss + 8*cc, TF, 0, false);
 
-                            simdgroup_multiply_accumulate(lo[j][i], mv, mk, lo[j][i]);
-                        }
+                        simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]);
                     }
                 }
             }
@@ -2412,10 +2360,8 @@ kernel void kernel_flash_attn_ext_f16(
 
         // each simdgroup stores its output to shared memory, reusing sq
        if (sgitg == sg) {
-            for (short j = 0; j < Q8; ++j) {
-                for (short i = 0; i < D8; ++i) {
-                    simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false);
-                }
+            for (short i = 0; i < D8; ++i) {
+                simdgroup_store(lo[i], sq + i*8, T, 0, false);
             }
         }
 
@@ -2447,19 +2393,19 @@ kernel void kernel_flash_attn_ext_f16(
             }
 
             // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
-            for (short j = 0; j < Q8; ++j) {
+            {
                 simdgroup_half8x8 t;
                 simdgroup_float8x8 ms0;
                 simdgroup_float8x8 ms1;
 
-                simdgroup_load(ms0, ss + 8*j*TF + C + 8*j, TF, 0, false);
-                simdgroup_load(ms1, ss + 8*j*TF + C + 8*j + sg*SH, TF, 0, false);
+                simdgroup_load(ms0, ss + C, TF, 0, false);
+                simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);
 
                 for (short i = 0; i < D8; ++i) {
-                    simdgroup_load (t, sq + 8*j*T + i*8, T, 0, false);
+                    simdgroup_load (t, sq + i*8, T, 0, false);
                     simdgroup_multiply(t, ms1, t);
 
-                    simdgroup_multiply_accumulate(lo[j][i], ms0, lo[j][i], t);
+                    simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
                 }
             }
         }
@@ -2467,10 +2413,8 @@ kernel void kernel_flash_attn_ext_f16(
 
     // store result to shared memory (reuse sq)
     if (sgitg == 0) {
-        for (short j = 0; j < Q8; ++j) {
-            for (short i = 0; i < D8; ++i) {
-                simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false);
-            }
+        for (short i = 0; i < D8; ++i) {
+            simdgroup_store(lo[i], sq + i*8, T, 0, false);
         }
     }
 
@@ -2488,14 +2432,14 @@ kernel void kernel_flash_attn_ext_f16(
     }
 }
 
-template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>;
-template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>;
-template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96, 8, 32>;
-template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112, 8, 32>;
-template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>;
-template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>;
+template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>;
+template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>;
+template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
+template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
+template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
+template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
 
-template<int64_t D, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
+template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
 kernel void kernel_flash_attn_ext_vec_f16(
         device const char * q,
         device const char * k,
@@ -2539,7 +2483,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
 
     const short D4 = D/4;
     const short NW = N_SIMDWIDTH;
-    const short SH = (C + 1); // shared memory per simdgroup in (half)
+    const short SH = (C + Q); // shared memory per simdgroup in (half)
 
     const short T = D + 2*nsg*SH; // shared memory size per query in (half)
 
@@ -2763,8 +2707,8 @@ kernel void kernel_flash_attn_ext_vec_f16(
     }
 }
 
-template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 32>;
-template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 32>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
 
 kernel void kernel_cpy_f16_f16(
         device const half * src0,