diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 99587fb98..741a4aef0 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -3277,7 +3277,7 @@ static void ggml_metal_encode_node( // the shared memory needed for the simdgroups to load the KV cache // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG // -#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + nhalfs*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16)) +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + nhalfs*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16)) int64_t nsgmax = 2; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 2e447c928..8e540e943 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2836,8 +2836,8 @@ kernel void kernel_flash_attn_ext( uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const short nsg = ntg.y; // number of simdgroups const int iq3 = tgpig[2]; @@ -2848,7 +2848,7 @@ kernel void kernel_flash_attn_ext( const short D8 = D/8; const short D16 = D/16; const short NW = N_SIMDWIDTH; - const short SH = (C + Q); // shared memory per simdgroup in (half) + const short SH = (2*C + Q); // shared memory per simdgroup in (half) const short SF = sizeof(s_t)/sizeof(half); @@ -2933,9 +2933,6 @@ kernel void kernel_flash_attn_ext( simdgroup_load(mq[i], sq + i*8, T); } - // pointer to the mask - device const half * mp = (device const half *) (mask + iq1*nb31); - const bool has_mask = mask != q; float slope = 1.0f; @@ -2958,6 +2955,26 @@ kernel void kernel_flash_attn_ext( break; } + if (has_mask) { + // used to detect blocks full of -INF + half smax = -INFINITY; + + for (short j = 0; j < Q; ++j) { + device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31); + + const half m = pm[ic + tiisg]; + + ss[j*TS + C + tiisg] = m; + smax = max(smax, m); + } + + smax = simd_max(smax); + + if (smax == -INFINITY) { + continue; + } + } + // Q*K^T { for (short cc = 0; cc < C/8; ++cc) { @@ -3033,9 +3050,6 @@ kernel void kernel_flash_attn_ext( } } - // used to detect blocks full of -INF - float smax = -INFINITY; - // online softmax { float ms[Q]; @@ -3052,10 +3066,10 @@ kernel void kernel_flash_attn_ext( if (has_mask) { // mqk = mqk + mask*slope - s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; // TODO: use ne30 + //s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; // TODO: use ne30 + s += slope*ss[j*TS + C + tiisg]; } - smax = simd_max(max(smax, s)); M[j] = simd_max(max(M[j], s)); ms[j] = exp(m - M[j]); @@ -3069,19 +3083,14 @@ kernel void kernel_flash_attn_ext( // create a QxQ diagonal matrix for rescaling the output if (tiisg < Q) { - ss[tiisg*TS + C + tiisg] = ms[tiisg]; + ss[tiisg*TS + 2*C + tiisg] = ms[tiisg]; } } - // skip -INF blocks - if (smax == -INFINITY) { - continue; - } - // O = diag(ms)*O { s8x8_t mm; - simdgroup_load(mm, ss + C, TS, 0, false); + simdgroup_load(mm, ss + 2*C, TS, 0, false); #pragma unroll for (short i = 0; i < D8; ++i) { @@ -3199,8 +3208,8 @@ kernel void kernel_flash_attn_ext( ss[j*TS + 0] = S; ss[j*TS + 1] = M; - ss[j*TS + C + j ] = ms0; - ss[j*TS + C + j + sg*SH] = ms1; + ss[j*TS + 2*C + j ] = ms0; + ss[j*TS + 2*C + j + sg*SH] = ms1; } } @@ -3209,8 +3218,8 @@ kernel void kernel_flash_attn_ext( s8x8_t ms0; s8x8_t ms1; - simdgroup_load(ms0, ss + C, TS, 0, false); - simdgroup_load(ms1, ss + C + sg*SH, TS, 0, false); + simdgroup_load(ms0, ss + 2*C, TS, 0, false); + simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false); for (short i = 0; i < D8; ++i) { o8x8_t t;