diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 3a5149d2d..76bfedc66 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -3187,7 +3187,7 @@ static void ggml_metal_encode_node(
                 }
                 nsg /= 2;
 
-                const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
+                const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 2*nsg*ne00)*(sizeof(float)/2);
 
                 //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
                 GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 3ae1be095..1db0f4b86 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -2819,22 +2819,25 @@ kernel void kernel_flash_attn_ext(
         float S[Q] = { [0 ... Q-1] = 0.0h };
         float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
 
+        // thread indices inside the simdgroup
+        const short tx = tiisg%4;
+        const short ty = tiisg/4;
+
         // assume K and V are same shape
         const short ne22 = ne12;
         const short ne23 = ne13;
 
-        // broadcast
+        // broadcast k
        const short rk2 = ne02/ne12;
         const short rk3 = ne03/ne13;
 
-        const short rv2 = ne02/ne22;
-        const short rv3 = ne03/ne23;
-
-        // k indices
         const short ik2 = iq2/rk2;
         const short ik3 = iq3/rk3;
 
-        // v indices
+        // broadcast v
+        const short rv2 = ne02/ne22;
+        const short rv3 = ne03/ne23;
+
         const short iv2 = iq2/rv2;
         const short iv3 = iq3/rv3;
 
@@ -2885,15 +2888,12 @@ kernel void kernel_flash_attn_ext(
                     }
                 } else {
                     for (short ii = 0; ii < D16; ii += 4) {
-                        const short i = tiisg%4;
-                        const short j = tiisg/4;
-
-                        device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8*cc + j)*nb11 + ik2*nb12 + ik3*nb13));
+                        device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
 
                         if (D16%4 == 0) {
                             half4x4 tmp;
-                            dequantize_func(pk4 + (ii + i)/nl, (ii + i)%nl, tmp);
-                            skv4[4*j + i] = tmp;
+                            dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
+                            skv4[4*ty + tx] = tmp;
 
                             simdgroup_barrier(mem_flags::mem_threadgroup);
 
@@ -2908,10 +2908,10 @@
                                 simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
                             }
                         } else {
-                            if (ii + i < D16) {
+                            if (ii + tx < D16) {
                                 half4x4 tmp;
-                                dequantize_func(pk4 + (ii + i)/nl, (ii + i)%nl, tmp);
-                                skv4[4*j + i] = tmp;
+                                dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
+                                skv4[4*ty + tx] = tmp;
                             }
 
                             simdgroup_barrier(mem_flags::mem_threadgroup);
@@ -3006,15 +3006,12 @@
                     }
                 } else {
                     for (short ii = 0; ii < D16; ii += 4) {
-                        const short i = tiisg%4;
-                        const short j = tiisg/4;
-
-                        device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8*cc + j)*nb21 + iv2*nb22 + iv3*nb23));
+                        device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
 
                         if (D16%4 == 0) {
                             half4x4 tmp;
-                            dequantize_func(pv4 + (ii + i)/nl, (ii + i)%nl, tmp);
-                            skv4[4*j + i] = tmp;
+                            dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
+                            skv4[4*ty + tx] = tmp;
 
                             simdgroup_barrier(mem_flags::mem_threadgroup);
 
@@ -3029,10 +3026,10 @@
                                 simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
                             }
                         } else {
-                            if (ii + i < D16) {
+                            if (ii + tx < D16) {
                                 half4x4 tmp;
-                                dequantize_func(pv4 + (ii + i)/nl, (ii + i)%nl, tmp);
-                                skv4[4*j + i] = tmp;
+                                dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
+                                skv4[4*ty + tx] = tmp;
                             }
 
                             simdgroup_barrier(mem_flags::mem_threadgroup);
@@ -3187,6 +3184,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_
 template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 
+// NOTE: can use half instead of float precision for some extra perf
 // D - head size, Q - queries per threadgroup, C - cache items per threadgroup
 template
 kernel void kernel_flash_attn_ext_vec(
@@ -3239,26 +3237,15 @@ kernel void kernel_flash_attn_ext_vec(
     const short T = D + 2*nsg*SH; // shared memory size per query in (half)
 
-    float slope = 1.0f;
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const uint32_t h = iq2;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slope = pow(base, exp);
-    }
-
-  //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
-    threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
-    threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
-    threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
-    threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + Q*T); // scratch buffer for the results
+  //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
+    threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
+    threadgroup half4x4 * sq44 = (threadgroup half4x4 *) (shared + 0*D); // same as above but in half4x4
+    threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+    threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
+    threadgroup float4x4 * sr44 = (threadgroup float4x4 *) (shared + 2*sgitg*D + Q*T); // scratch buffer for the results
 
     // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
-    half4x4 lo[D16/NW4];
+    float4x4 lo[D16/NW4];
 
     // load heads from Q to shared memory
     device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
 
@@ -3273,7 +3260,7 @@ kernel void kernel_flash_attn_ext_vec(
 
     // zero out lo
     for (short i = 0; i < D16/NW4; i += NW4) {
-        lo[i] = half4x4(0.0h);
+        lo[i] = float4x4(0.0h);
     }
 
     // zero out shared memory SH
@@ -3284,42 +3271,53 @@
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     {
-        float S = { 0.0h };
-        float M = { -FLT_MAX/2 };
+        float S = 0.0h;
+        float M = -FLT_MAX/2;
+
+        // thread indices inside the simdgroup
+        const short tx = tiisg%8;
+        const short ty = tiisg/8;
 
         // assume K and V are same shape
         const short ne22 = ne12;
         const short ne23 = ne13;
 
-        // broadcast
+        // broadcast k
         const short rk2 = ne02/ne12;
         const short rk3 = ne03/ne13;
 
+        const short ik2 = iq2/rk2;
+        const short ik3 = iq3/rk3;
+
+        // broadcast v
         const short rv2 = ne02/ne22;
         const short rv3 = ne03/ne23;
 
-        // k indices
-        const short ik2 = iq2 / rk2;
-        const short ik3 = iq3 / rk3;
-
-        // v indices
-        const short iv2 = iq2 / rv2;
-        const short iv3 = iq3 / rv3;
+        const short iv2 = iq2/rv2;
+        const short iv3 = iq3/rv3;
 
         // load the queries from shared memory into local memory
         float4x4 mq[D16/NW4];
 
         for (short ii = 0; ii < D16; ii += NW4) {
-            short i = ii + tiisg%8;
-            mq[ii/NW4][0] = (float4) sq4[4*i + 0];
-            mq[ii/NW4][1] = (float4) sq4[4*i + 1];
-            mq[ii/NW4][2] = (float4) sq4[4*i + 2];
-            mq[ii/NW4][3] = (float4) sq4[4*i + 3];
+            mq[ii/NW4] = (float4x4) sq44[ii + tx];
         }
 
         // pointer to the mask
         device const half * mp = (device const half *) (mask + iq1*nb31);
 
+        float slope = 1.0f;
+
+        // ALiBi
+        if (max_bias > 0.0f) {
+            const uint32_t h = iq2;
+
+            const float base = h < n_head_log2 ? m0 : m1;
+            const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+            slope = pow(base, exp);
+        }
+
         // loop over the KV cache
         // each simdgroup handles blocks of Q rows and C columns
         for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
@@ -3331,18 +3329,16 @@
             // Q*K^T
             {
                 // each simdgroup processes 1 query and 4 keys
-                const short j = tiisg/8;
-#pragma unroll
                 for (short cc = 0; cc < C/4; ++cc) {
                     float mqk = 0.0;
 
-                    device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + j)*nb11 + ik2*nb12 + ik3*nb13));
+                    device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
 
-                    float4x4 mk;
 #pragma unroll
                     for (short ii = 0; ii < D16; ii += NW4) {
-                        const short i = ii + tiisg%8; // 0..7
+                        const short i = ii + tx;
 
+                        float4x4 mk;
                         dequantize_func(pk + i/nl, i%nl, mk);
 
                         mqk +=
@@ -3364,16 +3360,16 @@
                     mqk += simd_shuffle_down(mqk, 1);
 
                     // mqk = mqk*scale + mask*slope
-                    if (tiisg%8 == 0) {
+                    if (tx == 0) {
                         mqk *= scale;
 
                         if (logit_softcap != 0.0f) {
                             mqk = logit_softcap*precise::tanh(mqk);
                         }
 
-                        mqk += (mask != q) ? ((float) mp[ic + 4*cc + j])*slope : (float) 0.0f;
+                        mqk += (mask != q) ? ((float) mp[ic + 4*cc + ty])*slope : (float) 0.0f;
 
-                        ss[4*cc + j] = mqk;
+                        ss[4*cc + ty] = mqk;
                     }
                 }
             }
@@ -3408,20 +3404,20 @@
 
             // O = O + (Q*K^T)*V
            {
-                const short j = tiisg/8;
 #pragma unroll
                 for (short cc = 0; cc < C/4; ++cc) {
-                    device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + j)*nb21 + iv2*nb22 + iv3*nb23));
+                    device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
+
+                    const float4x4 lss(ss[4*cc + ty]);
 
-                    float4x4 mv;
-                    const float4x4 lss(ss[4*cc + j]);
 #pragma unroll
                     for (short ii = 0; ii < D16; ii += NW4) {
-                        const short i = ii + tiisg%8;
+                        const short i = ii + tx;
 
+                        float4x4 mv;
                         dequantize_func(pv4 + i/nl, i%nl, mv);
 
-                        lo[ii/NW4] += (half4x4)(mv*lss);
+                        lo[ii/NW4] += mv*lss;
                     }
                 }
             }
@@ -3458,14 +3454,8 @@
         }
 
         // store results to shared memory
-        for (short ii = 0; ii < D16; ii += NW4) {
-            short i = ii + tiisg;
-            if (tiisg < 8) {
-                sr4[4*i + 0] = lo[ii/NW4][0];
-                sr4[4*i + 1] = lo[ii/NW4][1];
-                sr4[4*i + 2] = lo[ii/NW4][2];
-                sr4[4*i + 3] = lo[ii/NW4][3];
-            }
+        for (short i = tiisg; i < D16; i += NW4) {
+            sr44[i] = lo[i/NW4];
         }
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -3492,24 +3482,22 @@
             }
 
            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
-            for (short ii = 0; ii < D4; ii += NW) {
-                short i = ii + tiisg;
-                sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
+            for (short i = tiisg; i < D16; i += NW) {
+                sr44[i] = sr44[i]*ms0 + sr44[i + r*D16]*ms1;
             }
         }
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
-    device float4 * dst4 = (device float4 *) dst;
+    device float4x4 * dst44 = (device float4x4 *) dst;
 
     // final rescale with 1/S and store to global memory
     if (sgitg == 0) {
         const float S = ss[0];
 
-        for (short ii = 0; ii < D4; ii += NW) {
-            short i = ii + tiisg;
-            dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
+        for (short i = tiisg; i < D16; i += NW) {
+            dst44[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = sr44[i]/S;
         }
     }
 }