diff --git a/ggml-metal.m b/ggml-metal.m index edb18aec8..93b499a12 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2253,12 +2253,12 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nsg = 8; // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsg = 4; // simdgroups per threadgroup (a.k.a. warps) const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! const int64_t ncpsg = 8; //const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2); - const size_t smem = nqptg*(ne00 + nsg*(2*ncpsg))*(sizeof(float)/2); + const size_t smem = nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2); //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); diff --git a/ggml-metal.metal b/ggml-metal.metal index be38c17f0..3d4719ea0 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2046,15 +2046,16 @@ kernel void kernel_flash_attn_ext_f16( const int64_t L4 = (D4 + N4 - 1)/N4; const int64_t D8 = D/8; - const int64_t T = D + nsg*(2*C); // shared memory size per query in half - const int64_t T4 = T/4; // shared memory size per query in half4 + const int64_t T = D + nsg*(D + 1*C); // shared memory size per query in half + const int64_t T4 = T/4; // shared memory size per query in half4 - threadgroup half * pq = (threadgroup half *) (shared + 0*D); - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*C) + 1*D); - //threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*C) + 1*D); + threadgroup half * pq = (threadgroup half *) (shared + 0*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*D); + threadgroup half * ps = (threadgroup half *) (shared + sgitg*(D + 1*C) + 1*D); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(D + 1*C) + 1*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(D + 1*C) + 2*D); - half4 ps4[Q][L4]; + half4 ls4[Q][L4]; // load heads from Q to shared memory for (int64_t i = 0; i < L4; ++i) { @@ -2067,11 +2068,11 @@ kernel void kernel_flash_attn_ext_f16( } for (int64_t j = 0; j < Q; ++j) { - ps4[j][i] = 0.0h; + ls4[j][i] = 0.0h; } } - if (tiisg < 2) { + if (tiisg < 1) { for (int64_t j = 0; j < Q; ++j) { ss[j*T + tiisg] = 0.0h; } @@ -2127,15 +2128,15 @@ kernel void kernel_flash_attn_ext_f16( } for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) { - { - bool skip = true; - for (int64_t j = 0; j < Q; ++j) { - skip = skip && (mp[j][iic] == -INFINITY); - } - if (skip) { - continue; - } - } + //{ + // bool skip = true; + // for (int64_t j = 0; j < Q; ++j) { + // skip = skip && (mp[j][iic] == -INFINITY); + // } + // if (skip) { + // continue; + // } + //} { simdgroup_half8x8 mk; @@ -2152,33 +2153,13 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_store(mqk, ss, T, 0, false); } - //if (tiisg < Q) { - // const int64_t j = tiisg; - - // for (int p = 0; p < C; ++p) { - // const half s = ss[j*T + p + 0]*scale + (mp[j][iic + p]); - - // const half m = M; - - // M = max(M, s); - - // const half ms = m == -INFINITY ? 0.0h : exp(m - M); - // const half vs = s == -INFINITY ? 0.0h : exp(s - M); - - // S = S*ms + vs; - - // ss[j*T + 0 + p] = ms; - // ss[j*T + C + p] = vs; - // } - //} - // not sure why this barrier is needed simdgroup_barrier(mem_flags::mem_none); for (int64_t j = 0; j < Q; ++j) { const int64_t p = tiisg % C; - const half s = ss[j*T + p + 0]*scale + (mp[j][iic + p]); + const half s = ss[j*T + p]*scale + (mp[j][iic + p]); half m = M[j]; @@ -2187,43 +2168,42 @@ kernel void kernel_flash_attn_ext_f16( const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); - S[j] = S[j]*ms + 0.25h*simd_sum(vs); + S[j] = S[j]*ms + 0.25h*simd_sum(vs); // 4*8 = 32 for (int64_t i = 0; i < L4; ++i) { - ps4[j][i] *= ms; + ls4[j][i] *= ms; } if (tiisg < C) { - ss[j*T + C + p] = vs; + ss[j*T + p] = vs; } } - for (int64_t p = 0; p < C; ++p) { - const int64_t ic = iic + p; - device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); + { + simdgroup_half8x8 mv; + simdgroup_half8x8 mp; + simdgroup_half8x8 mqkv; - for (int64_t j = 0; j < Q; ++j) { - const half vs = ss[j*T + C + p]; + device const half * pv = (device const half *) ((device const char *) v + (iic*nb21 + iv2*nb22 + iv3*nb23)); - for (int64_t i = 0; i < L4; ++i) { - ps4[j][i] += pv4[N4*i + tiisg]*vs; - } + // load mp + simdgroup_load(mp, ss, T, 0, false); + + for (int64_t i = 0; i < D8; ++i) { + simdgroup_load (mv, pv + i*8, nb21/2, 0, false); + simdgroup_multiply(mqkv, mp, mv); + simdgroup_store (mqkv, ps + i*8, T, 0, false); } } - //for (int p = 0; p < C; ++p) { - // const int64_t ic = iic + p; - // device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); + // not sure why this barrier is needed too + threadgroup_barrier(mem_flags::mem_none); - // for (int64_t j = 0; j < Q; ++j) { - // const half ms = ss[j*T + 0 + p]; - // const half vs = ss[j*T + C + p]; - - // for (int64_t i = 0; i < L4; ++i) { - // ps4[j][i] = ps4[j][i]*ms + pv4[N4*i + tiisg]*vs; - // } - // } - //} + for (int64_t j = 0; j < Q; ++j) { + for (int64_t i = 0; i < L4; ++i) { + ls4[j][i] += ps4[j*T4 + N4*i + tiisg]; + } + } } for (int64_t j = 0; j < Q; ++j) { @@ -2246,7 +2226,7 @@ kernel void kernel_flash_attn_ext_f16( // store heads to shared memory - reuse pq4 for (int64_t j = 0; j < Q; ++j) { for (int64_t i = 0; i < L4; ++i) { - pq4[j*T4 + N4*i + tiisg] = ps4[j][i]; + pq4[j*T4 + N4*i + tiisg] = ls4[j][i]; } } } @@ -2255,11 +2235,11 @@ kernel void kernel_flash_attn_ext_f16( if (sgitg == 0) { for (int64_t j = 0; j < Q; ++j) { - const half S0 = ss[j*T + 0]; - const half S1 = ss[j*T + sg*(2*C) + 0]; + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + sg*(D + 1*C) + 0]; - const half M0 = ss[j*T + 1]; - const half M1 = ss[j*T + sg*(2*C) + 1]; + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + sg*(D + 1*C) + 1]; M = max(M0, M1); @@ -2274,7 +2254,7 @@ kernel void kernel_flash_attn_ext_f16( } for (int64_t i = 0; i < L4; ++i) { - ps4[j][i] = ps4[j][i]*ms0 + pq4[j*T4 + N4*i + tiisg]*ms1; + ls4[j][i] = ls4[j][i]*ms0 + pq4[j*T4 + N4*i + tiisg]*ms1; } } } @@ -2286,7 +2266,7 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t j = 0; j < Q; ++j) { const half S = ss[j*T + 0]; for (int64_t i = 0; i < L4; ++i) { - ps4[j][i] = ps4[j][i]/S; + ls4[j][i] = ls4[j][i]/S; } } } @@ -2298,7 +2278,7 @@ kernel void kernel_flash_attn_ext_f16( if (sgitg == 0) { for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { for (int64_t i = 0; i < L4; ++i) { - dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + N4*i + tiisg] = (float4) ps4[j][i]; + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + N4*i + tiisg] = (float4) ls4[j][i]; } } }