metal : separate scale and mask from QK^T in FA (flash attention) kernel

This commit is contained in:
Georgi Gerganov 2024-08-26 16:04:13 +03:00
parent 7a3df798fc
commit e65fc9b8b2
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -2140,24 +2140,21 @@ kernel void kernel_flash_attn_ext_f16(
} }
simdgroup_store(mqk, ss + 8*cc, TF, 0, false); simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
}
}
const short tx = tiisg%4; // scale and apply the mask (assume C = 32)
const short ty = tiisg/4; for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
// mqk = mqk*scale // mqk = mqk*scale
ss[8*cc + ty*TF + 2*tx + 0] *= scale; ss[j*TF + tiisg] *= scale;
ss[8*cc + ty*TF + 2*tx + 1] *= scale;
if (logit_softcap != 0.0f) { if (logit_softcap != 0.0f) {
ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]); ss[j*TF + tiisg] = logit_softcap*precise::tanh(ss[j*TF + tiisg]);
ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]);
} }
if (mask != q) { if (mask != q) {
// mqk = mqk + mask*slope // mqk = mqk + mask*slope
ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0]; ss[j*TF + tiisg] += slope*mp[ic + j*nb31/sizeof(half) + tiisg];
ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
}
} }
} }
@ -2169,10 +2166,8 @@ kernel void kernel_flash_attn_ext_f16(
float ms[Q]; float ms[Q];
for (short j = 0; j < Q; ++j) { for (short j = 0; j < Q; ++j) {
const short p = tiisg;
const float m = M[j]; const float m = M[j];
const float s = ss[j*TF + p]; const float s = ss[j*TF + tiisg];
smax = simd_max(max(smax, s)); smax = simd_max(max(smax, s));
M[j] = simd_max(max(M[j], s)); M[j] = simd_max(max(M[j], s));
@ -2183,7 +2178,7 @@ kernel void kernel_flash_attn_ext_f16(
S[j] = S[j]*ms[j] + simd_sum(vs); S[j] = S[j]*ms[j] + simd_sum(vs);
// the P matrix from the paper (Q rows, C columns) // the P matrix from the paper (Q rows, C columns)
ss[j*TF + p] = vs; ss[j*TF + tiisg] = vs;
} }
// create a QxQ diagonal matrix for rescaling the output // create a QxQ diagonal matrix for rescaling the output