diff --git a/ggml-metal.m b/ggml-metal.m index d942b673f..0f405b112 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2616,7 +2616,6 @@ static enum ggml_status ggml_metal_graph_compute( // simdgroups per threadgroup (a.k.a. warps) // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) const int64_t nsg = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); - //const int64_t nsg = 1; const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); diff --git a/ggml-metal.metal b/ggml-metal.metal index d7ce10274..282ec3eb6 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2529,7 +2529,7 @@ kernel void kernel_flash_attn_ext_vec_f16( // zero out lo for (short j = 0; j < Q; ++j) { - for (short i = 0; i < D4; ++i) { + for (short i = tiisg; i < D4; i += NW) { lo[j][i] = 0.0h; } } @@ -2648,10 +2648,7 @@ kernel void kernel_flash_attn_ext_vec_f16( } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // used to detect blocks full of -INF - half smax = -INFINITY; + //threadgroup_barrier(mem_flags::mem_threadgroup); // online softmax if (C == 32) { @@ -2663,7 +2660,6 @@ kernel void kernel_flash_attn_ext_vec_f16( const half m = M[j]; const half s = ss[j*T + p]; - smax = simd_max(max(smax, s)); M[j] = simd_max(max(M[j], s)); ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); @@ -2688,11 +2684,9 @@ kernel void kernel_flash_attn_ext_vec_f16( for (short p = tiisg; p < C; p += NW) { const half s = ss[j*T + p]; - smax = max(smax, s); M[j] = max(M[j], s); } - smax = simd_max(smax); M[j] = simd_max(M[j]); ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); @@ -2720,12 +2714,7 @@ kernel void kernel_flash_attn_ext_vec_f16( } } - // skip -INF blocks - if (smax == -INFINITY) { - continue; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); + //threadgroup_barrier(mem_flags::mem_threadgroup); // O = diag(ms)*O for (short j = 0; j < Q; ++j) { @@ -2742,26 +2731,12 @@ kernel void kernel_flash_attn_ext_vec_f16( // O = O + (Q*K^T)*V { for (short cc = 0; cc < C; ++cc) { - //device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23)); - //for (short i = 0; i < D8; ++i) { - // simdgroup_half8x8 mk; - // simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); - - // for (short j = 0; j < Q8; ++j) { - // simdgroup_half8x8 mv; - // simdgroup_load(mv, ss + 8*j*T + 8*cc, T, 0, false); - - // simdgroup_multiply_accumulate(lo[j][i], mv, mk, lo[j][i]); - // } - //} - + half vsum[Q]; for (short i = tiisg; i < D4; i += NW) { - half4 mk = pv4[i]; - for (short j = 0; j < Q; ++j) { - lo[j][i] += mk*ss[j*T + cc]; + lo[j][i] += pv4[i]*ss[j*T + cc]; } } }