diff --git a/ggml-metal.m b/ggml-metal.m
index 07535828d..204ccea1b 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -2573,7 +2573,7 @@ static enum ggml_status ggml_metal_graph_compute(
                         [encoder setBytes:&scale length:sizeof( float) atIndex:27];
 
                         // half8x8 kernel
-                        if (ne01 > 1) {
+                        if (ne01 > 1 || (ne00%128 != 0)) {
                             const int64_t nqptg = 8;  // queries per threadgroup    !! sync with kernel template arguments !!
                             const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
 
@@ -2603,8 +2603,13 @@ static enum ggml_status ggml_metal_graph_compute(
                             // simdgroups per threadgroup (a.k.a. warps)
                             // for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
-                            //const int64_t nsg = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
-                            const int64_t nsg = 8;
+                            const int64_t nsgt = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
+
+                            int64_t nsg = 1;
+                            while (nsg <= nsgt) {
+                                nsg *= 2;
+                            }
+                            nsg /= 2; // require power of 2
 
                             //{
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 7709865c9..63a5a175d 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -2575,21 +2575,20 @@ kernel void kernel_flash_attn_ext_vec_f16(
     const short iv3 = iq3 / rv3;
 
     // load the queries from shared memory into local memory
-    simdgroup_half8x8 mq[Q][D8];
+    half4 mq[Q][D4];
 
     for (short j = 0; j < Q; ++j) {
-        for (short i = 0; i < D8; ++i) {
-            simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T);
+        for (short ii = 0; ii < D4; ii += NW) {
+            short i = ii + tiisg;
+            mq[j][i] = sq4[j*T4 + i];
         }
     }
 
     // pointer to the mask
-    //device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
-    device const half * mp = (device const half *) (mask + iq1*nb31);
+    device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
 
     // prepare diagonal scale matrix
-    simdgroup_half8x8 mscale(scale);
-    //half mscale(scale);
+    half mscale(scale);
 
     // loop over the KV cache
     // each simdgroup handles blocks of Q rows and C columns
@@ -2599,79 +2598,45 @@ kernel void kernel_flash_attn_ext_vec_f16(
             break;
         }
 
-        // Q*K^T
-        //{
-        //    for (short cc = 0; cc < C/4; ++cc) {
-        //        half4 mqk[Q];
-        //        for (short j = 0; j < Q; ++j) {
-        //            mqk[j] = 0.0h;
-        //        }
-
-        //        device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
-
-        //        for (short i = tiisg; i < D4; i += NW) {
-        //            half4x4 mk;
-        //            mk[0] = pk4[i + 0*(nb11/8)];
-        //            mk[1] = pk4[i + 1*(nb11/8)];
-        //            mk[2] = pk4[i + 2*(nb11/8)];
-        //            mk[3] = pk4[i + 3*(nb11/8)];
-
-        //            for (short j = 0; j < Q; ++j) {
-        //                mqk[j] += mq[j][i] * mk;
-        //            }
-        //        }
-
-        //        // reduce the results from the threads in the simdgroup
-        //        simdgroup_barrier(mem_flags::mem_none);
-
-        //        for (short i = NW/2; i > 0; i /= 2) {
-        //            if (tiisg < i) {
-        //                for (short j = 0; j < Q; ++j) {
-        //                    mqk[j] += simd_shuffle_down(mqk[j], i);
-        //                }
-        //            }
-
-        //            simdgroup_barrier(mem_flags::mem_none);
-        //        }
-
-        //        // mqk = mqk*scale + mask
-        //        if (tiisg == 0) {
-        //            for (short j = 0; j < Q; ++j) {
-        //                half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc];
-        //                mqk[j] = mqk[j]*mscale + mm;
-
-        //                ss4[j*T4 + cc] = mqk[j];
-        //            }
-        //        }
-        //    }
-        //}
-
         // Q*K^T
         {
-            for (short cc = 0; cc < C/8; ++cc) {
-                simdgroup_half8x8 mqk[Q];
-                for (short j = 0; j < Q; ++j) {
-                    mqk[j] = make_filled_simdgroup_matrix<half, 8>(0.h);
-                }
+            for (short cc = 0; cc < C/4; ++cc) {
+                half4 mqk[Q] = { [0 ... Q-1] = 0.0h };
 
-                device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
+                device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
 
-                for (short i = 0; i < D8; ++i) {
-                    simdgroup_half8x8 mk;
-                    simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
+#pragma unroll
+                for (short ii = 0; ii < D4; ii += NW) {
+                    const short i = ii + tiisg;
+
+                    half4x4 mk;
+                    mk[0] = pk4[i + 0*(nb11/8)];
+                    mk[1] = pk4[i + 1*(nb11/8)];
+                    mk[2] = pk4[i + 2*(nb11/8)];
+                    mk[3] = pk4[i + 3*(nb11/8)];
 
                     for (short j = 0; j < Q; ++j) {
-                        simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]);
+                        mqk[j] += mq[j][i] * mk;
                     }
                 }
 
-                // mqk = mqk*scale + mask
+                // reduce the results from the threads in the simdgroup
                 for (short j = 0; j < Q; ++j) {
-                    simdgroup_half8x8 mm;
-                    simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false);
-                    simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm);
+                    mqk[j] += simd_shuffle_down(mqk[j], 16);
+                    mqk[j] += simd_shuffle_down(mqk[j],  8);
+                    mqk[j] += simd_shuffle_down(mqk[j],  4);
+                    mqk[j] += simd_shuffle_down(mqk[j],  2);
+                    mqk[j] += simd_shuffle_down(mqk[j],  1);
+                }
 
-                    simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false);
+                // mqk = mqk*scale + mask
+                if (tiisg == 0) {
+                    for (short j = 0; j < Q; ++j) {
+                        half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc];
+                        mqk[j] = mqk[j]*mscale + mm;
+
+                        ss4[j*T4 + cc] = mqk[j];
+                    }
                 }
             }
         }
@@ -2701,26 +2666,26 @@ kernel void kernel_flash_attn_ext_vec_f16(
            ss[tiisg*T + C + tiisg] = ms[tiisg];
         }
 
-        //threadgroup_barrier(mem_flags::mem_threadgroup);
-
         // O = diag(ms)*O
         for (short j = 0; j < Q; ++j) {
-            //simdgroup_half8x8 mm;
-            //simdgroup_load(mm, ss + 8*j*T + C + 8*j, T, 0, false);
             half mm(ss[j*T + C + j]);
 
-            for (short i = tiisg; i < D4; i += NW) {
-                //simdgroup_multiply(lo[j][i], mm, lo[j][i]);
+#pragma unroll
+            for (short ii = 0; ii < D4; ii += NW) {
+                const short i = ii + tiisg;
                 lo[j][i/NW] = lo[j][i/NW]*mm;
             }
         }
 
         // O = O + (Q*K^T)*V
         {
+#pragma unroll
            for (short cc = 0; cc < C; ++cc) {
                device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23));
 
-                for (short i = tiisg; i < D4; i += NW) {
+#pragma unroll
+                for (short ii = 0; ii < D4; ii += NW) {
+                    short i = ii + tiisg;
                     for (short j = 0; j < Q; ++j) {
                         lo[j][i/NW] += pv4[i]*ss[j*T + cc];
                     }
@@ -2738,15 +2703,16 @@ kernel void kernel_flash_attn_ext_vec_f16(
             }
         }
 
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
         // store results to shared memory
         for (short j = 0; j < Q; ++j) {
-            for (short i = tiisg; i < D4; i += NW) {
-                sr4[i] = lo[j][i/NW];
+            for (short ii = 0; ii < D4; ii += NW) {
+                short i = ii + tiisg;
+                sr4[i] = lo[j][ii/NW];
             }
         }
 
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
         // parallel reduce
         for (short r = nsg/2; r > 0; r >>= 1) {
             if (sgitg < r) {
@@ -2805,10 +2771,6 @@ kernel void kernel_flash_attn_ext_vec_f16(
         }
     }
 }
 
-template [[host_name("kernel_flash_attn_ext_vec_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 2, 32>;
-template [[host_name("kernel_flash_attn_ext_vec_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 3, 32>;
-template [[host_name("kernel_flash_attn_ext_vec_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 4, 32>;
-template [[host_name("kernel_flash_attn_ext_vec_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 5, 32>;
 template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 1, 32>;
 template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 1, 32>;