From e65fc9b8b25021dfb6c0e051df0d3cb71a88d77b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 26 Aug 2024 16:04:13 +0300 Subject: [PATCH] metal : separate scale and mask from QKT in FA kernel --- ggml/src/ggml-metal.metal | 33 ++++++++++++++------------------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index aba0b9a03..1b32e4384 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2140,24 +2140,21 @@ kernel void kernel_flash_attn_ext_f16( } simdgroup_store(mqk, ss + 8*cc, TF, 0, false); + } + } - const short tx = tiisg%4; - const short ty = tiisg/4; + // scale and apply the mask (assume C = 32) + for (short j = 0; j < Q && iq1 + j < ne01; ++j) { + // mqk = mqk*scale + ss[j*TF + tiisg] *= scale; - // mqk = mqk*scale - ss[8*cc + ty*TF + 2*tx + 0] *= scale; - ss[8*cc + ty*TF + 2*tx + 1] *= scale; + if (logit_softcap != 0.0f) { + ss[j*TF + tiisg] = logit_softcap*precise::tanh(ss[j*TF + tiisg]); + } - if (logit_softcap != 0.0f) { - ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]); - ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]); - } - - if (mask != q) { - // mqk = mqk + mask*slope - ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0]; - ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1]; - } + if (mask != q) { + // mqk = mqk + mask*slope + ss[j*TF + tiisg] += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; } } @@ -2169,10 +2166,8 @@ kernel void kernel_flash_attn_ext_f16( float ms[Q]; for (short j = 0; j < Q; ++j) { - const short p = tiisg; - const float m = M[j]; - const float s = ss[j*TF + p]; + const float s = ss[j*TF + tiisg]; smax = simd_max(max(smax, s)); M[j] = simd_max(max(M[j], s)); @@ -2183,7 +2178,7 @@ kernel void kernel_flash_attn_ext_f16( S[j] = S[j]*ms[j] + simd_sum(vs); // the P matrix from the paper (Q rows, C columns) - ss[j*TF + p] = vs; + ss[j*TF + tiisg] = vs; } // create a QxQ diagonal matrix for rescaling the output