metal : simplify

This commit is contained in:
Georgi Gerganov 2024-04-05 16:29:29 +03:00
parent c4dff1ec91
commit f8d709f01a
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 54 additions and 87 deletions

View file

@ -2573,7 +2573,7 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder setBytes:&scale length:sizeof( float) atIndex:27]; [encoder setBytes:&scale length:sizeof( float) atIndex:27];
// half8x8 kernel // half8x8 kernel
if (ne01 > 1) { if (ne01 > 1 || (ne00%128 != 0)) {
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
@ -2603,8 +2603,13 @@ static enum ggml_status ggml_metal_graph_compute(
// simdgroups per threadgroup (a.k.a. warps) // simdgroups per threadgroup (a.k.a. warps)
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it) // for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
//const int64_t nsg = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); const int64_t nsgt = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
const int64_t nsg = 8;
int64_t nsg = 1;
while (nsg <= nsgt) {
nsg *= 2;
}
nsg /= 2;
// require power of 2 // require power of 2
//{ //{

View file

@ -2575,21 +2575,20 @@ kernel void kernel_flash_attn_ext_vec_f16(
const short iv3 = iq3 / rv3; const short iv3 = iq3 / rv3;
// load the queries from shared memory into local memory // load the queries from shared memory into local memory
simdgroup_half8x8 mq[Q][D8]; half4 mq[Q][D4];
for (short j = 0; j < Q; ++j) { for (short j = 0; j < Q; ++j) {
for (short i = 0; i < D8; ++i) { for (short ii = 0; ii < D4; ii += NW) {
simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T); short i = ii + tiisg;
mq[j][i] = sq4[j*T4 + i];
} }
} }
// pointer to the mask // pointer to the mask
//device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31); device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
device const half * mp = (device const half *) (mask + iq1*nb31);
// prepare diagonal scale matrix // prepare diagonal scale matrix
simdgroup_half8x8 mscale(scale); half mscale(scale);
//half mscale(scale);
// loop over the KV cache // loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns // each simdgroup handles blocks of Q rows and C columns
@ -2599,79 +2598,45 @@ kernel void kernel_flash_attn_ext_vec_f16(
break; break;
} }
// Q*K^T
//{
// for (short cc = 0; cc < C/4; ++cc) {
// half4 mqk[Q];
// for (short j = 0; j < Q; ++j) {
// mqk[j] = 0.0h;
// }
// device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
// for (short i = tiisg; i < D4; i += NW) {
// half4x4 mk;
// mk[0] = pk4[i + 0*(nb11/8)];
// mk[1] = pk4[i + 1*(nb11/8)];
// mk[2] = pk4[i + 2*(nb11/8)];
// mk[3] = pk4[i + 3*(nb11/8)];
// for (short j = 0; j < Q; ++j) {
// mqk[j] += mq[j][i] * mk;
// }
// }
// // reduce the results from the threads in the simdgroup
// simdgroup_barrier(mem_flags::mem_none);
// for (short i = NW/2; i > 0; i /= 2) {
// if (tiisg < i) {
// for (short j = 0; j < Q; ++j) {
// mqk[j] += simd_shuffle_down(mqk[j], i);
// }
// }
// simdgroup_barrier(mem_flags::mem_none);
// }
// // mqk = mqk*scale + mask
// if (tiisg == 0) {
// for (short j = 0; j < Q; ++j) {
// half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc];
// mqk[j] = mqk[j]*mscale + mm;
// ss4[j*T4 + cc] = mqk[j];
// }
// }
// }
//}
// Q*K^T // Q*K^T
{ {
for (short cc = 0; cc < C/8; ++cc) { for (short cc = 0; cc < C/4; ++cc) {
simdgroup_half8x8 mqk[Q]; half4 mqk[Q] = { [0 ... Q-1] = 0.0h };
for (short j = 0; j < Q; ++j) {
mqk[j] = make_filled_simdgroup_matrix<half, 8>(0.h);
}
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
for (short i = 0; i < D8; ++i) { #pragma unroll
simdgroup_half8x8 mk; for (short ii = 0; ii < D4; ii += NW) {
simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose const short i = ii + tiisg;
half4x4 mk;
mk[0] = pk4[i + 0*(nb11/8)];
mk[1] = pk4[i + 1*(nb11/8)];
mk[2] = pk4[i + 2*(nb11/8)];
mk[3] = pk4[i + 3*(nb11/8)];
for (short j = 0; j < Q; ++j) { for (short j = 0; j < Q; ++j) {
simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]); mqk[j] += mq[j][i] * mk;
} }
} }
// reduce the results from the threads in the simdgroup
for (short j = 0; j < Q; ++j) {
mqk[j] += simd_shuffle_down(mqk[j], 16);
mqk[j] += simd_shuffle_down(mqk[j], 8);
mqk[j] += simd_shuffle_down(mqk[j], 4);
mqk[j] += simd_shuffle_down(mqk[j], 2);
mqk[j] += simd_shuffle_down(mqk[j], 1);
}
// mqk = mqk*scale + mask // mqk = mqk*scale + mask
if (tiisg == 0) {
for (short j = 0; j < Q; ++j) { for (short j = 0; j < Q; ++j) {
simdgroup_half8x8 mm; half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc];
simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false); mqk[j] = mqk[j]*mscale + mm;
simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm);
simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false); ss4[j*T4 + cc] = mqk[j];
}
} }
} }
} }
@ -2701,26 +2666,26 @@ kernel void kernel_flash_attn_ext_vec_f16(
ss[tiisg*T + C + tiisg] = ms[tiisg]; ss[tiisg*T + C + tiisg] = ms[tiisg];
} }
//threadgroup_barrier(mem_flags::mem_threadgroup);
// O = diag(ms)*O // O = diag(ms)*O
for (short j = 0; j < Q; ++j) { for (short j = 0; j < Q; ++j) {
//simdgroup_half8x8 mm;
//simdgroup_load(mm, ss + 8*j*T + C + 8*j, T, 0, false);
half mm(ss[j*T + C + j]); half mm(ss[j*T + C + j]);
for (short i = tiisg; i < D4; i += NW) { #pragma unroll
//simdgroup_multiply(lo[j][i], mm, lo[j][i]); for (short ii = 0; ii < D4; ii += NW) {
const short i = ii + tiisg;
lo[j][i/NW] = lo[j][i/NW]*mm; lo[j][i/NW] = lo[j][i/NW]*mm;
} }
} }
// O = O + (Q*K^T)*V // O = O + (Q*K^T)*V
{ {
#pragma unroll
for (short cc = 0; cc < C; ++cc) { for (short cc = 0; cc < C; ++cc) {
device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23)); device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23));
for (short i = tiisg; i < D4; i += NW) { #pragma unroll
for (short ii = 0; ii < D4; ii += NW) {
short i = ii + tiisg;
for (short j = 0; j < Q; ++j) { for (short j = 0; j < Q; ++j) {
lo[j][i/NW] += pv4[i]*ss[j*T + cc]; lo[j][i/NW] += pv4[i]*ss[j*T + cc];
} }
@ -2738,15 +2703,16 @@ kernel void kernel_flash_attn_ext_vec_f16(
} }
} }
threadgroup_barrier(mem_flags::mem_threadgroup);
// store results to shared memory // store results to shared memory
for (short j = 0; j < Q; ++j) { for (short j = 0; j < Q; ++j) {
for (short i = tiisg; i < D4; i += NW) { for (short ii = 0; ii < D4; ii += NW) {
sr4[i] = lo[j][i/NW]; short i = ii + tiisg;
sr4[i] = lo[j][ii/NW];
} }
} }
threadgroup_barrier(mem_flags::mem_threadgroup);
// parallel reduce // parallel reduce
for (short r = nsg/2; r > 0; r >>= 1) { for (short r = nsg/2; r > 0; r >>= 1) {
if (sgitg < r) { if (sgitg < r) {
@ -2805,10 +2771,6 @@ kernel void kernel_flash_attn_ext_vec_f16(
} }
} }
template [[host_name("kernel_flash_attn_ext_vec_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 2, 32>;
template [[host_name("kernel_flash_attn_ext_vec_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 3, 32>;
template [[host_name("kernel_flash_attn_ext_vec_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 4, 32>;
template [[host_name("kernel_flash_attn_ext_vec_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 5, 32>;
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 1, 32>; template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 1, 32>;
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 1, 32>; template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 1, 32>;