metal : faster inner loop for C == 32
This commit is contained in:
parent
c6c1132e5e
commit
5fcb9c1c5a
1 changed file with 44 additions and 19 deletions
|
@@ -2048,8 +2048,8 @@ kernel void kernel_flash_attn_ext_f16(
|
||||||
const int64_t T4 = T/4; // shared memory size per query in (half4)
|
const int64_t T4 = T/4; // shared memory size per query in (half4)
|
||||||
|
|
||||||
threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
||||||
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // scratch buffer for attention
|
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
||||||
threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for diagonal matrix
|
threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
||||||
|
|
||||||
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
||||||
simdgroup_half8x8 lo[Q8][D8];
|
simdgroup_half8x8 lo[Q8][D8];
|
||||||
|
@@ -2164,6 +2164,30 @@ kernel void kernel_flash_attn_ext_f16(
|
||||||
half smax = -INFINITY;
|
half smax = -INFINITY;
|
||||||
|
|
||||||
// online softmax
|
// online softmax
|
||||||
|
if (C == 32) {
|
||||||
|
for (int64_t j = 0; j < Q; ++j) {
|
||||||
|
const int64_t p = tiisg;
|
||||||
|
|
||||||
|
const half m = M[j];
|
||||||
|
const half s = ss[j*T + p];
|
||||||
|
|
||||||
|
smax = simd_max(max(smax, s));
|
||||||
|
M[j] = simd_max(max(M[j], s));
|
||||||
|
|
||||||
|
const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]);
|
||||||
|
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
|
||||||
|
|
||||||
|
S[j] = S[j]*ms + simd_sum(vs);
|
||||||
|
|
||||||
|
// create a QxQ diagonal matrix for rescaling the output
|
||||||
|
if (p == j) {
|
||||||
|
ss[j*T + C + j] = ms;
|
||||||
|
}
|
||||||
|
|
||||||
|
// the P matrix from the paper (Q rows, C columns)
|
||||||
|
ss[j*T + p] = vs;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
for (int64_t j = 0; j < Q; ++j) {
|
for (int64_t j = 0; j < Q; ++j) {
|
||||||
const half m = M[j];
|
const half m = M[j];
|
||||||
|
|
||||||
|
@@ -2194,6 +2218,7 @@ kernel void kernel_flash_attn_ext_f16(
|
||||||
ss[j*T + p] = vs;
|
ss[j*T + p] = vs;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// skip -INF blocks
|
// skip -INF blocks
|
||||||
if (smax == -INFINITY) {
|
if (smax == -INFINITY) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue