metal : optimize softmax for C > 32

This commit is contained in:
Georgi Gerganov 2024-02-01 20:16:32 +02:00
parent 41d136b602
commit 56e45a239e
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 20 additions and 5 deletions

View file

@@ -2217,29 +2217,35 @@ kernel void kernel_flash_attn_ext_f16(
for (int64_t p = tiisg; p < C; p += NW) {
const half s = ss[j*T + p];
smax = simd_max(max(smax, s));
M[j] = simd_max(max(M[j], s));
smax = max(smax, s);
M[j] = max(M[j], s);
}
const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]);
smax = simd_max(smax);
M[j] = simd_max(M[j]);
S[j] = S[j]*ms;
const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]);
// create a QxQ diagonal matrix for rescaling the output
if (tiisg == j) {
ss[j*T + C + j] = ms;
}
// local sum
half ls = 0.0h;
for (int64_t p = tiisg; p < C; p += NW) {
const half s = ss[j*T + p];
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
S[j] = S[j] + simd_sum(vs);
ls += vs;
// the P matrix from the paper (Q rows, C columns)
ss[j*T + p] = vs;
}
S[j] = S[j]*ms + simd_sum(ls);
}
}