Compare commits

...
Sign in to create a new pull request.

1 commit

Author SHA1 Message Date
Georgi Gerganov
a95225cdfd
metal : another fix for the fa kernel 2024-08-26 15:08:38 +03:00

View file

@ -2144,19 +2144,21 @@ kernel void kernel_flash_attn_ext_f16(
const short tx = tiisg%4; const short tx = tiisg%4;
const short ty = tiisg/4; const short ty = tiisg/4;
// mqk = mqk*scale if (iq1 + ty < ne01) {
ss[8*cc + ty*TF + 2*tx + 0] *= scale; // mqk = mqk*scale
ss[8*cc + ty*TF + 2*tx + 1] *= scale; ss[8*cc + ty*TF + 2*tx + 0] *= scale;
ss[8*cc + ty*TF + 2*tx + 1] *= scale;
if (logit_softcap != 0.0f) { if (logit_softcap != 0.0f) {
ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]); ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]);
ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]); ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]);
} }
if (mask != q) { if (mask != q) {
// mqk = mqk + mask*slope // mqk = mqk + mask*slope
ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0]; ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1]; ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
}
} }
} }
} }