metal : clean-up kernel code

This commit is contained in:
Georgi Gerganov 2024-04-19 15:52:49 +03:00
parent 97eaece7d6
commit 1a88565b44
No known key found for this signature in database
GPG key ID: BF970631944C16B7

View file

@ -2121,7 +2121,7 @@ typedef void (flash_attn_ext_f16_t)(
ushort sgitg[[simdgroup_index_in_threadgroup]]); ushort sgitg[[simdgroup_index_in_threadgroup]]);
// ref: https://arxiv.org/pdf/2307.08691.pdf // ref: https://arxiv.org/pdf/2307.08691.pdf
template<int64_t D, int64_t Q, int64_t C> // head size, queries per threadgroup, cache items per threadgroup template<int64_t D, int64_t Q = 8, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
kernel void kernel_flash_attn_ext_f16( kernel void kernel_flash_attn_ext_f16(
device const char * q, device const char * q,
device const char * k, device const char * k,
@ -2178,7 +2178,7 @@ kernel void kernel_flash_attn_ext_f16(
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
simdgroup_half8x8 lo[Q8][D8]; simdgroup_half8x8 lo[D8];
// load heads from Q to shared memory // load heads from Q to shared memory
for (short j = sgitg; j < Q; j += nsg) { for (short j = sgitg; j < Q; j += nsg) {
@ -2194,10 +2194,8 @@ kernel void kernel_flash_attn_ext_f16(
} }
// zero out lo // zero out lo
for (short j = 0; j < Q8; ++j) { for (short i = 0; i < D8; ++i) {
for (short i = 0; i < D8; ++i) { lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
lo[j][i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
}
} }
// zero out shared memory SH // zero out shared memory SH
@ -2229,20 +2227,18 @@ kernel void kernel_flash_attn_ext_f16(
const short rv3 = ne03/ne23; const short rv3 = ne03/ne23;
// k indices // k indices
const short ik2 = iq2 / rk2; const short ik2 = iq2/rk2;
const short ik3 = iq3 / rk3; const short ik3 = iq3/rk3;
// v indices // v indices
const short iv2 = iq2 / rv2; const short iv2 = iq2/rv2;
const short iv3 = iq3 / rv3; const short iv3 = iq3/rv3;
// load the queries from shared memory into local memory // load the queries from shared memory into local memory
simdgroup_half8x8 mq[Q8][D8]; simdgroup_half8x8 mq[D8];
for (short j = 0; j < Q8; ++j) { for (short i = 0; i < D8; ++i) {
for (short i = 0; i < D8; ++i) { simdgroup_load(mq[i], sq + i*8, T);
simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T);
}
} }
// pointer to the mask // pointer to the mask
@ -2262,10 +2258,7 @@ kernel void kernel_flash_attn_ext_f16(
// Q*K^T // Q*K^T
{ {
for (short cc = 0; cc < C/8; ++cc) { for (short cc = 0; cc < C/8; ++cc) {
simdgroup_float8x8 mqk[Q8]; simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
for (short j = 0; j < Q8; ++j) {
mqk[j] = make_filled_simdgroup_matrix<float, 8>(0.h);
}
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
@ -2273,19 +2266,15 @@ kernel void kernel_flash_attn_ext_f16(
simdgroup_half8x8 mk; simdgroup_half8x8 mk;
simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
for (short j = 0; j < Q8; ++j) { simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]);
}
} }
// mqk = mqk*scale + mask // mqk = mqk*scale + mask
for (short j = 0; j < Q8; ++j) { simdgroup_half8x8 mm;
simdgroup_half8x8 mm; simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false); simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm);
simdgroup_store(mqk[j], ss + 8*j*TF + 8*cc, TF, 0, false); simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
}
} }
} }
@ -2293,7 +2282,7 @@ kernel void kernel_flash_attn_ext_f16(
float smax = -INFINITY; float smax = -INFINITY;
// online softmax // online softmax
if (C == 32) { {
float ms[Q]; float ms[Q];
for (short j = 0; j < Q; ++j) { for (short j = 0; j < Q; ++j) {
@ -2314,45 +2303,6 @@ kernel void kernel_flash_attn_ext_f16(
ss[j*TF + p] = vs; ss[j*TF + p] = vs;
} }
// create a QxQ diagonal matrix for rescaling the output
if (tiisg < Q) {
ss[tiisg*TF + C + tiisg] = ms[tiisg];
}
} else {
float ms[Q];
for (short j = 0; j < Q; ++j) {
const float m = M[j];
for (short p = tiisg; p < C; p += NW) {
const float s = ss[j*TF + p];
smax = max(smax, s);
M[j] = max(M[j], s);
}
smax = simd_max(smax);
M[j] = simd_max(M[j]);
ms[j] = exp(m - M[j]);
// local sum
float ls = 0.0h;
for (short p = tiisg; p < C; p += NW) {
const float s = ss[j*TF + p];
const float vs = exp(s - M[j]);
ls += vs;
// the P matrix from the paper (Q rows, C columns)
ss[j*TF + p] = vs;
}
S[j] = S[j]*ms[j] + simd_sum(ls);
}
// create a QxQ diagonal matrix for rescaling the output // create a QxQ diagonal matrix for rescaling the output
if (tiisg < Q) { if (tiisg < Q) {
ss[tiisg*TF + C + tiisg] = ms[tiisg]; ss[tiisg*TF + C + tiisg] = ms[tiisg];
@ -2365,12 +2315,12 @@ kernel void kernel_flash_attn_ext_f16(
} }
// O = diag(ms)*O // O = diag(ms)*O
for (short j = 0; j < Q8; ++j) { {
simdgroup_float8x8 mm; simdgroup_float8x8 mm;
simdgroup_load(mm, ss + 8*j*TF + C + 8*j, TF, 0, false); simdgroup_load(mm, ss + C, TF, 0, false);
for (short i = 0; i < D8; ++i) { for (short i = 0; i < D8; ++i) {
simdgroup_multiply(lo[j][i], mm, lo[j][i]); simdgroup_multiply(lo[i], mm, lo[i]);
} }
} }
@ -2383,12 +2333,10 @@ kernel void kernel_flash_attn_ext_f16(
simdgroup_half8x8 mk; simdgroup_half8x8 mk;
simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
for (short j = 0; j < Q8; ++j) { simdgroup_float8x8 mv;
simdgroup_float8x8 mv; simdgroup_load(mv, ss + 8*cc, TF, 0, false);
simdgroup_load(mv, ss + 8*j*TF + 8*cc, TF, 0, false);
simdgroup_multiply_accumulate(lo[j][i], mv, mk, lo[j][i]); simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]);
}
} }
} }
} }
@ -2412,10 +2360,8 @@ kernel void kernel_flash_attn_ext_f16(
// each simdgroup stores its output to shared memory, reusing sq // each simdgroup stores its output to shared memory, reusing sq
if (sgitg == sg) { if (sgitg == sg) {
for (short j = 0; j < Q8; ++j) { for (short i = 0; i < D8; ++i) {
for (short i = 0; i < D8; ++i) { simdgroup_store(lo[i], sq + i*8, T, 0, false);
simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false);
}
} }
} }
@ -2447,19 +2393,19 @@ kernel void kernel_flash_attn_ext_f16(
} }
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
for (short j = 0; j < Q8; ++j) { {
simdgroup_half8x8 t; simdgroup_half8x8 t;
simdgroup_float8x8 ms0; simdgroup_float8x8 ms0;
simdgroup_float8x8 ms1; simdgroup_float8x8 ms1;
simdgroup_load(ms0, ss + 8*j*TF + C + 8*j, TF, 0, false); simdgroup_load(ms0, ss + C, TF, 0, false);
simdgroup_load(ms1, ss + 8*j*TF + C + 8*j + sg*SH, TF, 0, false); simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);
for (short i = 0; i < D8; ++i) { for (short i = 0; i < D8; ++i) {
simdgroup_load (t, sq + 8*j*T + i*8, T, 0, false); simdgroup_load (t, sq + i*8, T, 0, false);
simdgroup_multiply(t, ms1, t); simdgroup_multiply(t, ms1, t);
simdgroup_multiply_accumulate(lo[j][i], ms0, lo[j][i], t); simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
} }
} }
} }
@ -2467,10 +2413,8 @@ kernel void kernel_flash_attn_ext_f16(
// store result to shared memory (reuse sq) // store result to shared memory (reuse sq)
if (sgitg == 0) { if (sgitg == 0) {
for (short j = 0; j < Q8; ++j) { for (short i = 0; i < D8; ++i) {
for (short i = 0; i < D8; ++i) { simdgroup_store(lo[i], sq + i*8, T, 0, false);
simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false);
}
} }
} }
@ -2488,14 +2432,14 @@ kernel void kernel_flash_attn_ext_f16(
} }
} }
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>;
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
template<int64_t D, int64_t C> // head size, queries per threadgroup, cache items per threadgroup template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
kernel void kernel_flash_attn_ext_vec_f16( kernel void kernel_flash_attn_ext_vec_f16(
device const char * q, device const char * q,
device const char * k, device const char * k,
@ -2539,7 +2483,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
const short D4 = D/4; const short D4 = D/4;
const short NW = N_SIMDWIDTH; const short NW = N_SIMDWIDTH;
const short SH = (C + 1); // shared memory per simdgroup in (half) const short SH = (C + Q); // shared memory per simdgroup in (half)
const short T = D + 2*nsg*SH; // shared memory size per query in (half) const short T = D + 2*nsg*SH; // shared memory size per query in (half)
@ -2763,8 +2707,8 @@ kernel void kernel_flash_attn_ext_vec_f16(
} }
} }
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 32>; template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 32>; template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
kernel void kernel_cpy_f16_f16( kernel void kernel_cpy_f16_f16(
device const half * src0, device const half * src0,