metal : switch to parallel reduce
This commit is contained in:
parent
5733b00e53
commit
e51778de5e
2 changed files with 119 additions and 93 deletions
16
ggml-metal.m
16
ggml-metal.m
|
@ -2615,13 +2615,23 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||||
|
|
||||||
// simdgroups per threadgroup (a.k.a. warps)
|
// simdgroups per threadgroup (a.k.a. warps)
|
||||||
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
|
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
|
||||||
const int64_t nsg = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
|
//const int64_t nsg = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
|
||||||
|
const int64_t nsg = 8;
|
||||||
|
|
||||||
const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);
|
// require power of 2
|
||||||
|
//{
|
||||||
|
// int64_t nsgm = 1;
|
||||||
|
// while (nsgm < nsg) {
|
||||||
|
// nsgm *= 2;
|
||||||
|
// }
|
||||||
|
// GGML_ASSERT(nsg == nsgm);
|
||||||
|
//}
|
||||||
|
|
||||||
|
const size_t smem = (nqptg*(ne00 + nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
|
||||||
|
|
||||||
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
|
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
|
||||||
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
||||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||||
}
|
}
|
||||||
|
|
188
ggml-metal.metal
188
ggml-metal.metal
|
@ -2457,6 +2457,8 @@ template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f
|
||||||
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>;
|
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>;
|
||||||
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>;
|
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>;
|
||||||
|
|
||||||
|
#define HALF_MAX_HALF half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
|
||||||
|
|
||||||
template<int64_t D, int64_t Q, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
|
template<int64_t D, int64_t Q, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
|
||||||
kernel void kernel_flash_attn_ext_vec_f16(
|
kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
device const char * q,
|
device const char * q,
|
||||||
|
@ -2500,6 +2502,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
const short iq1 = tgpig[0]*Q;
|
const short iq1 = tgpig[0]*Q;
|
||||||
|
|
||||||
const short D4 = D/4;
|
const short D4 = D/4;
|
||||||
|
const short D8 = D/8;
|
||||||
const short NW = N_SIMDWIDTH;
|
const short NW = N_SIMDWIDTH;
|
||||||
const short SH = (C + Q); // shared memory per simdgroup in (half)
|
const short SH = (C + Q); // shared memory per simdgroup in (half)
|
||||||
|
|
||||||
|
@ -2510,6 +2513,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
||||||
threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
||||||
threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*SH + 1*D); // same as above but in half4
|
threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*SH + 1*D); // same as above but in half4
|
||||||
|
threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + Q*T); // scratch buffer for the results
|
||||||
|
|
||||||
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
||||||
half4 lo[Q][D4];
|
half4 lo[Q][D4];
|
||||||
|
@ -2545,7 +2549,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
|
|
||||||
{
|
{
|
||||||
half S[Q] = { [0 ... Q-1] = 0.0h };
|
half S[Q] = { [0 ... Q-1] = 0.0h };
|
||||||
half M[Q] = { [0 ... Q-1] = -INFINITY };
|
half M[Q] = { [0 ... Q-1] = -HALF_MAX_HALF };
|
||||||
|
|
||||||
// assume K and V are same shape
|
// assume K and V are same shape
|
||||||
const short ne22 = ne12;
|
const short ne22 = ne12;
|
||||||
|
@ -2571,21 +2575,21 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
const short iv3 = iq3 / rv3;
|
const short iv3 = iq3 / rv3;
|
||||||
|
|
||||||
// load the queries from shared memory into local memory
|
// load the queries from shared memory into local memory
|
||||||
half4 mq[Q][D4];
|
simdgroup_half8x8 mq[Q][D8];
|
||||||
|
|
||||||
for (short j = 0; j < Q; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
for (short i = tiisg; i < D4; i += NW) {
|
for (short i = 0; i < D8; ++i) {
|
||||||
//simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T);
|
simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T);
|
||||||
mq[j][i] = sq4[j*T4 + i];
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// pointer to the mask
|
// pointer to the mask
|
||||||
device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
|
//device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
|
||||||
|
device const half * mp = (device const half *) (mask + iq1*nb31);
|
||||||
|
|
||||||
// prepare diagonal scale matrix
|
// prepare diagonal scale matrix
|
||||||
//simdgroup_half8x8 mscale(scale);
|
simdgroup_half8x8 mscale(scale);
|
||||||
half mscale(scale);
|
//half mscale(scale);
|
||||||
|
|
||||||
// loop over the KV cache
|
// loop over the KV cache
|
||||||
// each simdgroup handles blocks of Q rows and C columns
|
// each simdgroup handles blocks of Q rows and C columns
|
||||||
|
@ -2595,54 +2599,82 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Q*K^T
|
||||||
|
//{
|
||||||
|
// for (short cc = 0; cc < C/4; ++cc) {
|
||||||
|
// half4 mqk[Q];
|
||||||
|
// for (short j = 0; j < Q; ++j) {
|
||||||
|
// mqk[j] = 0.0h;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||||
|
|
||||||
|
// for (short i = tiisg; i < D4; i += NW) {
|
||||||
|
// half4x4 mk;
|
||||||
|
// mk[0] = pk4[i + 0*(nb11/8)];
|
||||||
|
// mk[1] = pk4[i + 1*(nb11/8)];
|
||||||
|
// mk[2] = pk4[i + 2*(nb11/8)];
|
||||||
|
// mk[3] = pk4[i + 3*(nb11/8)];
|
||||||
|
|
||||||
|
// for (short j = 0; j < Q; ++j) {
|
||||||
|
// mqk[j] += mq[j][i] * mk;
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // reduce the results from the threads in the simdgroup
|
||||||
|
// simdgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
// for (short i = NW/2; i > 0; i /= 2) {
|
||||||
|
// if (tiisg < i) {
|
||||||
|
// for (short j = 0; j < Q; ++j) {
|
||||||
|
// mqk[j] += simd_shuffle_down(mqk[j], i);
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// simdgroup_barrier(mem_flags::mem_none);
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // mqk = mqk*scale + mask
|
||||||
|
// if (tiisg == 0) {
|
||||||
|
// for (short j = 0; j < Q; ++j) {
|
||||||
|
// half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc];
|
||||||
|
// mqk[j] = mqk[j]*mscale + mm;
|
||||||
|
|
||||||
|
// ss4[j*T4 + cc] = mqk[j];
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//}
|
||||||
|
|
||||||
// Q*K^T
|
// Q*K^T
|
||||||
{
|
{
|
||||||
for (short cc = 0; cc < C/4; ++cc) {
|
for (short cc = 0; cc < C/8; ++cc) {
|
||||||
half4 mqk[Q];
|
simdgroup_half8x8 mqk[Q];
|
||||||
for (short j = 0; j < Q; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
mqk[j] = 0.0h;
|
mqk[j] = make_filled_simdgroup_matrix<half, 8>(0.h);
|
||||||
}
|
}
|
||||||
|
|
||||||
device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||||
|
|
||||||
for (short i = tiisg; i < D4; i += NW) {
|
for (short i = 0; i < D8; ++i) {
|
||||||
half4x4 mk;
|
simdgroup_half8x8 mk;
|
||||||
mk[0] = pk4[i + 0*(nb11/8)];
|
simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
|
||||||
mk[1] = pk4[i + 1*(nb11/8)];
|
|
||||||
mk[2] = pk4[i + 2*(nb11/8)];
|
|
||||||
mk[3] = pk4[i + 3*(nb11/8)];
|
|
||||||
|
|
||||||
for (short j = 0; j < Q; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
mqk[j] += mq[j][i] * mk;
|
simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// reduce the results from the threads in the simdgroup
|
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
|
||||||
|
|
||||||
for (short i = NW/2; i > 0; i /= 2) {
|
|
||||||
if (tiisg < i) {
|
|
||||||
for (short j = 0; j < Q; ++j) {
|
|
||||||
mqk[j] += simd_shuffle_down(mqk[j], i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
|
||||||
}
|
|
||||||
|
|
||||||
// mqk = mqk*scale + mask
|
// mqk = mqk*scale + mask
|
||||||
if (tiisg == 0) {
|
|
||||||
for (short j = 0; j < Q; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc];
|
simdgroup_half8x8 mm;
|
||||||
mqk[j] = mqk[j]*mscale + mm;
|
simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false);
|
||||||
|
simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm);
|
||||||
|
|
||||||
ss4[j*T4 + cc] = mqk[j];
|
simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// online softmax
|
// online softmax
|
||||||
half ms[Q];
|
half ms[Q];
|
||||||
|
@ -2655,8 +2687,8 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
|
|
||||||
M[j] = simd_max(max(M[j], s));
|
M[j] = simd_max(max(M[j], s));
|
||||||
|
|
||||||
ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);
|
ms[j] = exp(m - M[j]);
|
||||||
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
|
const half vs = exp(s - M[j]);
|
||||||
|
|
||||||
S[j] = S[j]*ms[j] + simd_sum(vs);
|
S[j] = S[j]*ms[j] + simd_sum(vs);
|
||||||
|
|
||||||
|
@ -2706,75 +2738,59 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// reduce the warps sequentially
|
|
||||||
for (short sg = 1; sg < nsg; ++sg) {
|
|
||||||
half S = { 0.0h };
|
|
||||||
half M = { -INFINITY };
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
// each simdgroup stores its output to shared memory, reusing sq
|
// store results to shared memory
|
||||||
if (sgitg == sg) {
|
|
||||||
for (short j = 0; j < Q; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
for (short i = tiisg; i < D4; i += NW) {
|
for (short i = tiisg; i < D4; i += NW) {
|
||||||
//simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false);
|
sr4[i] = lo[j][i];
|
||||||
sq4[j*T4 + i] = lo[j][i];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
// parallel reduce
|
||||||
|
for (short r = nsg/2; r > 0; r >>= 1) {
|
||||||
// the first simdgroup accumulates the results from the other simdgroups
|
if (sgitg < r) {
|
||||||
if (sgitg == 0) {
|
if (tiisg == 0) {
|
||||||
for (short j = 0; j < Q; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
const half S0 = ss[j*T + 0];
|
const half S0 = ss[j*T + 0];
|
||||||
const half S1 = ss[j*T + sg*SH + 0];
|
const half S1 = ss[j*T + r*SH + 0];
|
||||||
|
|
||||||
const half M0 = ss[j*T + 1];
|
const half M0 = ss[j*T + 1];
|
||||||
const half M1 = ss[j*T + sg*SH + 1];
|
const half M1 = ss[j*T + r*SH + 1];
|
||||||
|
|
||||||
M = max(M0, M1);
|
const half M = max(M0, M1);
|
||||||
|
|
||||||
const half ms0 = M0 == -INFINITY ? 0.0h : exp(M0 - M);
|
const half ms0 = exp(M0 - M);
|
||||||
const half ms1 = M1 == -INFINITY ? 0.0h : exp(M1 - M);
|
const half ms1 = exp(M1 - M);
|
||||||
|
|
||||||
S = S0*ms0 + S1*ms1;
|
const half S = S0*ms0 + S1*ms1;
|
||||||
|
|
||||||
if (tiisg == 0) {
|
|
||||||
ss[j*T + 0] = S;
|
ss[j*T + 0] = S;
|
||||||
ss[j*T + 1] = M;
|
ss[j*T + 1] = M;
|
||||||
|
|
||||||
ss[j*T + C + j ] = ms0;
|
ss[j*T + C + j ] = ms0;
|
||||||
ss[j*T + C + j + sg*SH] = ms1;
|
ss[j*T + C + j + r*SH] = ms1;
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
|
||||||
for (short j = 0; j < Q; ++j) {
|
|
||||||
for (short i = tiisg; i < D4; i += NW) {
|
|
||||||
half4 t = sq4[j*T4 + i];
|
|
||||||
half ms0 = ss[j*T + C + j];
|
|
||||||
half ms1 = ss[j*T + C + j + sg*SH];
|
|
||||||
|
|
||||||
lo[j][i] = lo[j][i]*ms0 + t*ms1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// store result to shared memory (reuse sq)
|
|
||||||
if (sgitg == 0) {
|
|
||||||
for (short j = 0; j < Q; ++j) {
|
|
||||||
for (short i = tiisg; i < D4; i += NW) {
|
|
||||||
//simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false);
|
|
||||||
sq4[j*T4 + i] = lo[j][i];
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
if (sgitg < r) {
|
||||||
|
for (short j = 0; j < Q; ++j) {
|
||||||
|
const half ms0 = ss[j*T + C + j];
|
||||||
|
const half ms1 = ss[j*T + C + j + r*SH];
|
||||||
|
|
||||||
|
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
||||||
|
for (short i = tiisg; i < D4; i += NW) {
|
||||||
|
sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
}
|
||||||
|
|
||||||
device float4 * dst4 = (device float4 *) dst;
|
device float4 * dst4 = (device float4 *) dst;
|
||||||
|
|
||||||
// final rescale with 1/S and store to global memory
|
// final rescale with 1/S and store to global memory
|
||||||
|
@ -2783,7 +2799,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
const half S = ss[j*T + 0];
|
const half S = ss[j*T + 0];
|
||||||
|
|
||||||
for (short i = tiisg; i < D4; i += NW) {
|
for (short i = tiisg; i < D4; i += NW) {
|
||||||
dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
|
dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sr4[i]/S;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue