metal : add tests, fix scaling, support C > 32

This commit is contained in:
Georgi Gerganov 2024-01-28 15:42:57 +02:00
parent 77f6976a87
commit ecc466a460
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 47 additions and 37 deletions

View file

@ -2041,7 +2041,6 @@ kernel void kernel_flash_attn_ext_f16(
const int64_t D4 = D/4;
const int64_t D8 = D/8;
const int64_t NW = N_SIMDWIDTH;
const int64_t L4 = (D4 + NW - 1)/NW;
const int64_t SH = (C + Q); // shared memory per simdgroup in (half)
const int64_t T = D + nsg*SH; // shared memory size per query in (half)
@ -2054,14 +2053,15 @@ kernel void kernel_flash_attn_ext_f16(
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
simdgroup_half8x8 lo[D8];
for (int64_t i = 0; i < L4; ++i) {
// load heads from Q to shared memory
for (int64_t j = sgitg; j < Q; j += nsg) {
device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
// load heads from Q to shared memory
for (int64_t j = sgitg; j < Q; j += nsg) {
device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
for (int64_t i = tiisg; i < D4; i += NW) {
if (iq1 + j < ne01) {
sq4[j*T4 + NW*i + tiisg] = (half4) q4[NW*i + tiisg];
sq4[j*T4 + i] = (half4) q4[i];
} else {
sq4[j*T4 + NW*i + tiisg] = 0.0h;
sq4[j*T4 + i] = 0.0h;
}
}
}
@ -2072,12 +2072,9 @@ kernel void kernel_flash_attn_ext_f16(
}
// zero out shared memory SH
if (tiisg < C) {
for (int64_t j = 0; j < Q; ++j) {
ss[j*T + tiisg] = 0.0h;
if (tiisg < Q) {
ss[j*T + C + tiisg] = 0.0h;
}
for (int64_t j = 0; j < Q; ++j) {
for (int64_t i = tiisg; i < SH; i += NW) {
ss[j*T + i] = 0.0h;
}
}
@ -2157,27 +2154,34 @@ kernel void kernel_flash_attn_ext_f16(
// online softmax
for (int64_t j = 0; j < Q; ++j) {
const int64_t p = tiisg;
const half s = ss[j*T + p];
smax = simd_max(max(smax, s));
M[j] = simd_max(max(M[j], s));
const half m = M[j];
const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]);
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
for (int64_t p = tiisg; p < C; p += NW) {
const half s = ss[j*T + p];
S[j] = S[j]*ms + simd_sum(vs);
smax = simd_max(max(smax, s));
M[j] = simd_max(max(M[j], s));
}
const half ms = exp(m - M[j]);
S[j] = S[j]*ms;
// create an 8x8 diagonal matrix for rescaling the output
if (p == j) {
if (tiisg == j) {
ss[j*T + C + j] = ms;
}
// the P matrix from the paper (Q rows, C columns)
ss[j*T + p] = vs;
for (int64_t p = tiisg; p < C; p += NW) {
const half s = ss[j*T + p];
const half vs = exp(s - M[j]);
S[j] = S[j] + simd_sum(vs);
// the P matrix from the paper (Q rows, C columns)
ss[j*T + p] = vs;
}
}
// skip -INF blocks
@ -2231,7 +2235,7 @@ kernel void kernel_flash_attn_ext_f16(
threadgroup_barrier(mem_flags::mem_threadgroup);
// each simdgroup stores its output to shared memory, reusing sq4
// each simdgroup stores its output to shared memory, reusing sq
if (sgitg == sg) {
for (int64_t i = 0; i < D8; ++i) {
simdgroup_store(lo[i], sq + i*8, T, 0, false);
@ -2284,7 +2288,7 @@ kernel void kernel_flash_attn_ext_f16(
}
}
// store result to shared memory (reuse sq4)
// store result to shared memory (reuse sq)
if (sgitg == 0) {
for (int64_t i = 0; i < D8; ++i) {
simdgroup_store(lo[i], sq + i*8, T, 0, false);
@ -2298,8 +2302,8 @@ kernel void kernel_flash_attn_ext_f16(
for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
const half S = ss[j*T + 0];
for (int64_t i = 0; i < L4; ++i) {
dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + NW*i + tiisg] = (float4) sq4[j*T4 + NW*i + tiisg]/S;
for (int64_t i = tiisg; i < D4; i += NW) {
dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
}
}
}