diff --git a/ggml-metal.metal b/ggml-metal.metal index 282ec3eb6..533b9fef6 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2581,7 +2581,7 @@ kernel void kernel_flash_attn_ext_vec_f16( } // pointer to the mask - device const half * mp = (device const half *) (mask + iq1*nb31); + device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31); // prepare diagonal scale matrix //simdgroup_half8x8 mscale(scale); @@ -2597,23 +2597,23 @@ kernel void kernel_flash_attn_ext_vec_f16( // Q*K^T { - for (short cc = 0; cc < C; ++cc) { - half mqk[Q]; + for (short cc = 0; cc < C/4; ++cc) { + half4 mqk[Q]; for (short j = 0; j < Q; ++j) { mqk[j] = 0.0h; } - //device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); - device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + cc)*nb11 + ik2*nb12 + ik3*nb13)); + device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); for (short i = tiisg; i < D4; i += NW) { - //simdgroup_half8x8 mk; - //simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose - half4 mk = pk4[i]; + half4x4 mk; + mk[0] = pk4[i + 0*(nb11/8)]; + mk[1] = pk4[i + 1*(nb11/8)]; + mk[2] = pk4[i + 2*(nb11/8)]; + mk[3] = pk4[i + 3*(nb11/8)]; for (short j = 0; j < Q; ++j) { - //simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]); - mqk[j] += dot(mq[j][i], mk); + mqk[j] += mq[j][i] * mk; } } @@ -2633,85 +2633,40 @@ kernel void kernel_flash_attn_ext_vec_f16( // mqk = mqk*scale + mask if (tiisg == 0) { for (short j = 0; j < Q; ++j) { - //simdgroup_half8x8 mm; - //simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false); - //simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm); - - //simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false); - - half mm = mp[j*(nb31/sizeof(half)) + ic + cc]; + half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc]; mqk[j] = mqk[j]*mscale + mm; - ss[j*T + cc] = mqk[j]; + ss4[j*T4 + cc] = mqk[j]; } } } } - //threadgroup_barrier(mem_flags::mem_threadgroup); + simdgroup_barrier(mem_flags::mem_threadgroup); // online softmax - if (C == 32) { - half ms[Q]; + half ms[Q]; - for (short j = 0; j < Q; ++j) { - const short p = tiisg; + for (short j = 0; j < Q; ++j) { + const short p = tiisg; - const half m = M[j]; - const half s = ss[j*T + p]; + const half m = M[j]; + const half s = ss[j*T + p]; - M[j] = simd_max(max(M[j], s)); + M[j] = simd_max(max(M[j], s)); - ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); - const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); + ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); - S[j] = S[j]*ms[j] + simd_sum(vs); + S[j] = S[j]*ms[j] + simd_sum(vs); - // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; - } + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; + } - // create a QxQ diagonal matrix for rescaling the output - if (tiisg < Q) { - ss[tiisg*T + C + tiisg] = ms[tiisg]; - } - } else { - half ms[Q]; - - for (short j = 0; j < Q; ++j) { - const half m = M[j]; - - for (short p = tiisg; p < C; p += NW) { - const half s = ss[j*T + p]; - - M[j] = max(M[j], s); - } - - M[j] = simd_max(M[j]); - - ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); - - // local sum - half ls = 0.0h; - - for (short p = tiisg; p < C; p += NW) { - const half s = ss[j*T + p]; - - const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); - - ls += vs; - - // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; - } - - S[j] = S[j]*ms[j] + simd_sum(ls); - } - - // create a QxQ diagonal matrix for rescaling the output - if (tiisg < Q) { - ss[tiisg*T + C + tiisg] = ms[tiisg]; - } + // create a QxQ diagonal matrix for rescaling the output + if (tiisg < Q) { + ss[tiisg*T + C + tiisg] = ms[tiisg]; } //threadgroup_barrier(mem_flags::mem_threadgroup); @@ -2733,7 +2688,6 @@ kernel void kernel_flash_attn_ext_vec_f16( for (short cc = 0; cc < C; ++cc) { device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23)); - half vsum[Q]; for (short i = tiisg; i < D4; i += NW) { for (short j = 0; j < Q; ++j) { lo[j][i] += pv4[i]*ss[j*T + cc];