metal : faster inner loop for C == 32
This commit is contained in:
parent
c6c1132e5e
commit
5fcb9c1c5a
1 changed file with 44 additions and 19 deletions
|
@ -2048,8 +2048,8 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
const int64_t T4 = T/4; // shared memory size per query in (half4)
|
||||
|
||||
threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
||||
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // scratch buffer for attention
|
||||
threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for diagonal matrix
|
||||
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
||||
threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
||||
|
||||
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
||||
simdgroup_half8x8 lo[Q8][D8];
|
||||
|
@ -2164,6 +2164,30 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
half smax = -INFINITY;
|
||||
|
||||
// online softmax
|
||||
if (C == 32) {
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
const int64_t p = tiisg;
|
||||
|
||||
const half m = M[j];
|
||||
const half s = ss[j*T + p];
|
||||
|
||||
smax = simd_max(max(smax, s));
|
||||
M[j] = simd_max(max(M[j], s));
|
||||
|
||||
const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]);
|
||||
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
|
||||
|
||||
S[j] = S[j]*ms + simd_sum(vs);
|
||||
|
||||
// create a QxQ diagonal matrix for rescaling the output
|
||||
if (p == j) {
|
||||
ss[j*T + C + j] = ms;
|
||||
}
|
||||
|
||||
// the P matrix from the paper (Q rows, C columns)
|
||||
ss[j*T + p] = vs;
|
||||
}
|
||||
} else {
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
const half m = M[j];
|
||||
|
||||
|
@ -2194,6 +2218,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
ss[j*T + p] = vs;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// skip -INF blocks
|
||||
if (smax == -INFINITY) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue