metal : keep data in local memory
This commit is contained in:
parent
e865686c21
commit
ff23e8e9f0
1 changed file with 12 additions and 16 deletions
|
@ -2143,21 +2143,6 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
}
|
||||
}
|
||||
|
||||
// scale and apply the mask (assume C = 32)
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
// mqk = mqk*scale
|
||||
ss[j*TF + tiisg] *= scale;
|
||||
|
||||
if (logit_softcap != 0.0f) {
|
||||
ss[j*TF + tiisg] = logit_softcap*precise::tanh(ss[j*TF + tiisg]);
|
||||
}
|
||||
|
||||
if (mask != q) {
|
||||
// mqk = mqk + mask*slope
|
||||
ss[j*TF + tiisg] += slope*mp[ic + j*nb31/sizeof(half) + tiisg];
|
||||
}
|
||||
}
|
||||
|
||||
// used to detect blocks full of -INF
|
||||
float smax = -INFINITY;
|
||||
|
||||
|
@ -2167,7 +2152,18 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
const float m = M[j];
|
||||
const float s = ss[j*TF + tiisg];
|
||||
|
||||
// scale and apply the logitcap / mask
|
||||
float s = ss[j*TF + tiisg]*scale;
|
||||
|
||||
if (logit_softcap != 0.0f) {
|
||||
s = logit_softcap*precise::tanh(s);
|
||||
}
|
||||
|
||||
if (mask != q) {
|
||||
// mqk = mqk + mask*slope
|
||||
s += slope*mp[ic + j*nb31/sizeof(half) + tiisg];
|
||||
}
|
||||
|
||||
smax = simd_max(max(smax, s));
|
||||
M[j] = simd_max(max(M[j], s));
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue