diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 25f4d7b82..8f7515617 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -3297,7 +3297,7 @@ static void ggml_metal_encode_node( // ne00*(nsg) // each simdgroup has a full f16 head vector in shared mem to accumulate results // -#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16)) +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 4*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16)) int64_t nsgmax = 2; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 20b104611..964d2a877 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2844,7 +2844,7 @@ kernel void kernel_flash_attn_ext( const short D8 = D/8; const short D16 = D/16; const short NW = N_SIMDWIDTH; - const short SH = (2*C + Q); // shared memory per simdgroup in (half) + const short SH = (2*C + Q); // shared memory per simdgroup (s_t == float) const short TS = nsg*SH; // shared memory size per query in (s_t == float) const short T = D + 2*TS; // shared memory size per query in (half) @@ -3353,16 +3353,17 @@ kernel void kernel_flash_attn_ext_vec( const short D16 = D/16; const short NW = N_SIMDWIDTH; const short NW4 = NW/4; - const short SH = C; // shared memory per simdgroup in (half) + const short SH = 2*C; // shared memory per simdgroup const short T = D + 2*nsg*SH; // shared memory size per query in (half) - //threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in half4 - threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0*D); // same as above but in half4x4 - threadgroup s_t * ss = (threadgroup s_t *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention - threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + 2*sgitg*SH + Q*D); // same as above but in half4 - threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results + //threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t + threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0*D); // same as above but in q4x4_t + threadgroup s_t * ss = (threadgroup s_t *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + 2*sgitg*SH + Q*D); // same as above but in s4_t + threadgroup half * sm = (threadgroup half *) (shared + 2*sgitg*SH + SH + Q*D); // scratch buffer for mask + threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) o4x4_t lo[D16/NW4]; @@ -3412,8 +3413,10 @@ kernel void kernel_flash_attn_ext_vec( mq[ii/NW4] = sq4x4[ii + tx]; } + const bool has_mask = mask != q; + // pointer to the mask - device const half * mp = (device const half *) (mask + iq1*nb31); + device const half * pm = (device const half *) (mask + iq1*nb31); half slope = 1.0f; @@ -3435,6 +3438,10 @@ kernel void kernel_flash_attn_ext_vec( break; } + if (has_mask) { + sm[tiisg] = pm[ic + tiisg]; + } + // Q*K^T { // each simdgroup processes 1 query and 4 keys @@ -3476,7 +3483,7 @@ kernel void kernel_flash_attn_ext_vec( mqk = logit_softcap*precise::tanh(mqk); } - mqk += (s_t) ((mask != q) ? ((float) mp[ic + 4*cc + ty])*slope : (float) 0.0f); + mqk += sm[4*cc + ty]*slope; ss[4*cc + ty] = mqk; }