metal : initial FA vec kernel

This commit is contained in:
Georgi Gerganov 2024-04-05 17:47:01 +03:00
parent f8d709f01a
commit b57af0c9dd
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 82 additions and 125 deletions

View file

@ -2603,7 +2603,7 @@ static enum ggml_status ggml_metal_graph_compute(
// simdgroups per threadgroup (a.k.a. warps)
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
const int64_t nsgt = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
int64_t nsg = 1;
while (nsg <= nsgt) {

View file

@ -2459,7 +2459,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f
#define HALF_MAX_HALF half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
template<int64_t D, int64_t Q, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
template<int64_t D, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
kernel void kernel_flash_attn_ext_vec_f16(
device const char * q,
device const char * k,
@ -2499,12 +2499,12 @@ kernel void kernel_flash_attn_ext_vec_f16(
const short iq3 = tgpig[2];
const short iq2 = tgpig[1];
const short iq1 = tgpig[0]*Q;
const short iq1 = tgpig[0];
const short D4 = D/4;
const short D8 = D/8;
const short NW = N_SIMDWIDTH;
const short SH = (C + Q); // shared memory per simdgroup in (half)
const short SH = (C + 1); // shared memory per simdgroup in (half)
const short T = D + nsg*SH; // shared memory size per query in (half)
const short T4 = T/4; // shared memory size per query in (half4)
@ -2513,43 +2513,37 @@ kernel void kernel_flash_attn_ext_vec_f16(
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*SH + 1*D); // same as above but in half4
threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + Q*T); // scratch buffer for the results
threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
half4 lo[Q][D4/NW];
half4 lo[D4/NW];
// load heads from Q to shared memory
for (short j = sgitg; j < Q; j += nsg) {
device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
for (short i = tiisg; i < D4; i += NW) {
if (iq1 + j < ne01) {
sq4[j*T4 + i] = (half4) q4[i];
if (iq1 < ne01) {
sq4[i] = (half4) q4[i];
} else {
sq4[j*T4 + i] = 0.0h;
}
sq4[i] = 0.0h;
}
}
// zero out lo
for (short j = 0; j < Q; ++j) {
for (short i = tiisg; i < D4; i += NW) {
lo[j][i/NW] = 0.0h;
}
lo[i/NW] = 0.0h;
}
// zero out shared memory SH
for (short j = 0; j < Q; ++j) {
for (short i = tiisg; i < SH/4; i += NW) {
ss4[j*T4 + i] = 0.0h;
}
ss4[i] = 0.0h;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
{
half S[Q] = { [0 ... Q-1] = 0.0h };
half M[Q] = { [0 ... Q-1] = -HALF_MAX_HALF };
half S = { 0.0h };
half M = { -HALF_MAX_HALF };
// assume K and V are same shape
const short ne22 = ne12;
@ -2575,21 +2569,16 @@ kernel void kernel_flash_attn_ext_vec_f16(
const short iv3 = iq3 / rv3;
// load the queries from shared memory into local memory
half4 mq[Q][D4];
half4 mq[D4];
for (short j = 0; j < Q; ++j) {
for (short ii = 0; ii < D4; ii += NW) {
short i = ii + tiisg;
mq[j][i] = sq4[j*T4 + i];
}
mq[i] = sq4[i];
}
// pointer to the mask
device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
// prepare diagonal scale matrix
half mscale(scale);
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
@ -2600,8 +2589,9 @@ kernel void kernel_flash_attn_ext_vec_f16(
// Q*K^T
{
#pragma unroll
for (short cc = 0; cc < C/4; ++cc) {
half4 mqk[Q] = { [0 ... Q-1] = 0.0h };
half4 mqk = { 0.0h };
device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
@ -2615,100 +2605,81 @@ kernel void kernel_flash_attn_ext_vec_f16(
mk[2] = pk4[i + 2*(nb11/8)];
mk[3] = pk4[i + 3*(nb11/8)];
for (short j = 0; j < Q; ++j) {
mqk[j] += mq[j][i] * mk;
}
mqk += mq[i] * mk;
}
// reduce the results from the threads in the simdgroup
for (short j = 0; j < Q; ++j) {
mqk[j] += simd_shuffle_down(mqk[j], 16);
mqk[j] += simd_shuffle_down(mqk[j], 8);
mqk[j] += simd_shuffle_down(mqk[j], 4);
mqk[j] += simd_shuffle_down(mqk[j], 2);
mqk[j] += simd_shuffle_down(mqk[j], 1);
}
mqk += simd_shuffle_down(mqk, 16);
mqk += simd_shuffle_down(mqk, 8);
mqk += simd_shuffle_down(mqk, 4);
mqk += simd_shuffle_down(mqk, 2);
mqk += simd_shuffle_down(mqk, 1);
// mqk = mqk*scale + mask
if (tiisg == 0) {
for (short j = 0; j < Q; ++j) {
half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc];
mqk[j] = mqk[j]*mscale + mm;
half4 mm = mp4[ic/4 + cc];
mqk = mqk*scale + mm;
ss4[j*T4 + cc] = mqk[j];
}
ss4[cc] = mqk;
}
}
}
// online softmax
half ms[Q];
for (short j = 0; j < Q; ++j) {
{
const short p = tiisg;
const half m = M[j];
const half s = ss[j*T + p];
const half m = M;
const half s = ss[p];
M[j] = simd_max(max(M[j], s));
M = simd_max(max(M, s));
ms[j] = exp(m - M[j]);
const half vs = exp(s - M[j]);
const half ms = exp(m - M);
const half vs = exp(s - M);
S[j] = S[j]*ms[j] + simd_sum(vs);
S = S*ms + simd_sum(vs);
// the P matrix from the paper (Q rows, C columns)
ss[j*T + p] = vs;
}
// create a QxQ diagonal matrix for rescaling the output
if (tiisg < Q) {
ss[tiisg*T + C + tiisg] = ms[tiisg];
}
ss[p] = vs;
// O = diag(ms)*O
for (short j = 0; j < Q; ++j) {
half mm(ss[j*T + C + j]);
#pragma unroll
for (short ii = 0; ii < D4; ii += NW) {
const short i = ii + tiisg;
lo[j][i/NW] = lo[j][i/NW]*mm;
lo[i/NW] *= ms;
}
}
// O = O + (Q*K^T)*V
{
#pragma unroll
for (short cc = 0; cc < C; ++cc) {
device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23));
for (short cc = 0; cc < C/4; ++cc) {
device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23));
#pragma unroll
for (short ii = 0; ii < D4; ii += NW) {
short i = ii + tiisg;
for (short j = 0; j < Q; ++j) {
lo[j][i/NW] += pv4[i]*ss[j*T + cc];
}
}
const short i = ii + tiisg;
lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
}
}
}
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
for (short j = 0; j < Q; ++j) {
if (tiisg == 0) {
ss[j*T + 0] = S[j];
ss[j*T + 1] = M[j];
}
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
if (tiisg == 0) {
ss[0] = S;
ss[1] = M;
}
}
// store results to shared memory
for (short j = 0; j < Q; ++j) {
for (short ii = 0; ii < D4; ii += NW) {
short i = ii + tiisg;
sr4[i] = lo[j][ii/NW];
}
sr4[i] = lo[ii/NW];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
@ -2716,13 +2687,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
// parallel reduce
for (short r = nsg/2; r > 0; r >>= 1) {
if (sgitg < r) {
if (tiisg == 0) {
for (short j = 0; j < Q; ++j) {
const half S0 = ss[j*T + 0];
const half S1 = ss[j*T + r*SH + 0];
const half S0 = ss[ 0];
const half S1 = ss[r*SH + 0];
const half M0 = ss[j*T + 1];
const half M1 = ss[j*T + r*SH + 1];
const half M0 = ss[ 1];
const half M1 = ss[r*SH + 1];
const half M = max(M0, M1);
@ -2731,28 +2700,17 @@ kernel void kernel_flash_attn_ext_vec_f16(
const half S = S0*ms0 + S1*ms1;
ss[j*T + 0] = S;
ss[j*T + 1] = M;
ss[j*T + C + j ] = ms0;
ss[j*T + C + j + r*SH] = ms1;
if (tiisg == 0) {
ss[0] = S;
ss[1] = M;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sgitg < r) {
for (short j = 0; j < Q; ++j) {
const half ms0 = ss[j*T + C + j];
const half ms1 = ss[j*T + C + j + r*SH];
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
for (short i = tiisg; i < D4; i += NW) {
for (short ii = 0; ii < D4; ii += NW) {
short i = ii + tiisg;
sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
@ -2761,18 +2719,17 @@ kernel void kernel_flash_attn_ext_vec_f16(
// final rescale with 1/S and store to global memory
if (sgitg == 0) {
for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
const half S = ss[j*T + 0];
const half S = ss[0];
for (short i = tiisg; i < D4; i += NW) {
dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sr4[i]/S;
}
for (short ii = 0; ii < D4; ii += NW) {
short i = ii + tiisg;
dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
}
}
}
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 1, 32>;
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 1, 32>;
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 32>;
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 32>;
kernel void kernel_cpy_f16_f16(
device const half * src0,