metal : faster inner loop for C == 32

This commit is contained in:
Georgi Gerganov 2024-01-29 19:46:22 +02:00
parent c6c1132e5e
commit 5fcb9c1c5a
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -2048,8 +2048,8 @@ kernel void kernel_flash_attn_ext_f16(
const int64_t T4 = T/4; // shared memory size per query in (half4)
threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // scratch buffer for attention
threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for diagonal matrix
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
simdgroup_half8x8 lo[Q8][D8];
@ -2164,35 +2164,60 @@ kernel void kernel_flash_attn_ext_f16(
half smax = -INFINITY;
// online softmax
for (int64_t j = 0; j < Q; ++j) {
const half m = M[j];
if (C == 32) {
for (int64_t j = 0; j < Q; ++j) {
const int64_t p = tiisg;
for (int64_t p = tiisg; p < C; p += NW) {
const half m = M[j];
const half s = ss[j*T + p];
smax = simd_max(max(smax, s));
M[j] = simd_max(max(M[j], s));
}
const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]);
S[j] = S[j]*ms;
// create a QxQ diagonal matrix for rescaling the output
if (tiisg == j) {
ss[j*T + C + j] = ms;
}
for (int64_t p = tiisg; p < C; p += NW) {
const half s = ss[j*T + p];
const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]);
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
S[j] = S[j] + simd_sum(vs);
S[j] = S[j]*ms + simd_sum(vs);
// create a QxQ diagonal matrix for rescaling the output
if (p == j) {
ss[j*T + C + j] = ms;
}
// the P matrix from the paper (Q rows, C columns)
ss[j*T + p] = vs;
}
} else {
for (int64_t j = 0; j < Q; ++j) {
const half m = M[j];
for (int64_t p = tiisg; p < C; p += NW) {
const half s = ss[j*T + p];
smax = simd_max(max(smax, s));
M[j] = simd_max(max(M[j], s));
}
const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]);
S[j] = S[j]*ms;
// create a QxQ diagonal matrix for rescaling the output
if (tiisg == j) {
ss[j*T + C + j] = ms;
}
for (int64_t p = tiisg; p < C; p += NW) {
const half s = ss[j*T + p];
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
S[j] = S[j] + simd_sum(vs);
// the P matrix from the paper (Q rows, C columns)
ss[j*T + p] = vs;
}
}
}
// skip -INF blocks