metal : simplify
This commit is contained in:
parent
c4dff1ec91
commit
f8d709f01a
2 changed files with 54 additions and 87 deletions
11
ggml-metal.m
11
ggml-metal.m
|
@ -2573,7 +2573,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||||
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
||||||
|
|
||||||
// half8x8 kernel
|
// half8x8 kernel
|
||||||
if (ne01 > 1) {
|
if (ne01 > 1 || (ne00%128 != 0)) {
|
||||||
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
||||||
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
||||||
|
|
||||||
|
@ -2603,8 +2603,13 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||||
|
|
||||||
// simdgroups per threadgroup (a.k.a. warps)
|
// simdgroups per threadgroup (a.k.a. warps)
|
||||||
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
|
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
|
||||||
//const int64_t nsg = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
|
const int64_t nsgt = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
|
||||||
const int64_t nsg = 8;
|
|
||||||
|
int64_t nsg = 1;
|
||||||
|
while (nsg <= nsgt) {
|
||||||
|
nsg *= 2;
|
||||||
|
}
|
||||||
|
nsg /= 2;
|
||||||
|
|
||||||
// require power of 2
|
// require power of 2
|
||||||
//{
|
//{
|
||||||
|
|
130
ggml-metal.metal
130
ggml-metal.metal
|
@ -2575,21 +2575,20 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
const short iv3 = iq3 / rv3;
|
const short iv3 = iq3 / rv3;
|
||||||
|
|
||||||
// load the queries from shared memory into local memory
|
// load the queries from shared memory into local memory
|
||||||
simdgroup_half8x8 mq[Q][D8];
|
half4 mq[Q][D4];
|
||||||
|
|
||||||
for (short j = 0; j < Q; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
for (short i = 0; i < D8; ++i) {
|
for (short ii = 0; ii < D4; ii += NW) {
|
||||||
simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T);
|
short i = ii + tiisg;
|
||||||
|
mq[j][i] = sq4[j*T4 + i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// pointer to the mask
|
// pointer to the mask
|
||||||
//device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
|
device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
|
||||||
device const half * mp = (device const half *) (mask + iq1*nb31);
|
|
||||||
|
|
||||||
// prepare diagonal scale matrix
|
// prepare diagonal scale matrix
|
||||||
simdgroup_half8x8 mscale(scale);
|
half mscale(scale);
|
||||||
//half mscale(scale);
|
|
||||||
|
|
||||||
// loop over the KV cache
|
// loop over the KV cache
|
||||||
// each simdgroup handles blocks of Q rows and C columns
|
// each simdgroup handles blocks of Q rows and C columns
|
||||||
|
@ -2599,79 +2598,45 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Q*K^T
|
|
||||||
//{
|
|
||||||
// for (short cc = 0; cc < C/4; ++cc) {
|
|
||||||
// half4 mqk[Q];
|
|
||||||
// for (short j = 0; j < Q; ++j) {
|
|
||||||
// mqk[j] = 0.0h;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
|
||||||
|
|
||||||
// for (short i = tiisg; i < D4; i += NW) {
|
|
||||||
// half4x4 mk;
|
|
||||||
// mk[0] = pk4[i + 0*(nb11/8)];
|
|
||||||
// mk[1] = pk4[i + 1*(nb11/8)];
|
|
||||||
// mk[2] = pk4[i + 2*(nb11/8)];
|
|
||||||
// mk[3] = pk4[i + 3*(nb11/8)];
|
|
||||||
|
|
||||||
// for (short j = 0; j < Q; ++j) {
|
|
||||||
// mqk[j] += mq[j][i] * mk;
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // reduce the results from the threads in the simdgroup
|
|
||||||
// simdgroup_barrier(mem_flags::mem_none);
|
|
||||||
|
|
||||||
// for (short i = NW/2; i > 0; i /= 2) {
|
|
||||||
// if (tiisg < i) {
|
|
||||||
// for (short j = 0; j < Q; ++j) {
|
|
||||||
// mqk[j] += simd_shuffle_down(mqk[j], i);
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// simdgroup_barrier(mem_flags::mem_none);
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // mqk = mqk*scale + mask
|
|
||||||
// if (tiisg == 0) {
|
|
||||||
// for (short j = 0; j < Q; ++j) {
|
|
||||||
// half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc];
|
|
||||||
// mqk[j] = mqk[j]*mscale + mm;
|
|
||||||
|
|
||||||
// ss4[j*T4 + cc] = mqk[j];
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
|
|
||||||
// Q*K^T
|
// Q*K^T
|
||||||
{
|
{
|
||||||
for (short cc = 0; cc < C/8; ++cc) {
|
for (short cc = 0; cc < C/4; ++cc) {
|
||||||
simdgroup_half8x8 mqk[Q];
|
half4 mqk[Q] = { [0 ... Q-1] = 0.0h };
|
||||||
for (short j = 0; j < Q; ++j) {
|
|
||||||
mqk[j] = make_filled_simdgroup_matrix<half, 8>(0.h);
|
|
||||||
}
|
|
||||||
|
|
||||||
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||||
|
|
||||||
for (short i = 0; i < D8; ++i) {
|
#pragma unroll
|
||||||
simdgroup_half8x8 mk;
|
for (short ii = 0; ii < D4; ii += NW) {
|
||||||
simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
|
const short i = ii + tiisg;
|
||||||
|
|
||||||
|
half4x4 mk;
|
||||||
|
mk[0] = pk4[i + 0*(nb11/8)];
|
||||||
|
mk[1] = pk4[i + 1*(nb11/8)];
|
||||||
|
mk[2] = pk4[i + 2*(nb11/8)];
|
||||||
|
mk[3] = pk4[i + 3*(nb11/8)];
|
||||||
|
|
||||||
for (short j = 0; j < Q; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]);
|
mqk[j] += mq[j][i] * mk;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// mqk = mqk*scale + mask
|
// reduce the results from the threads in the simdgroup
|
||||||
for (short j = 0; j < Q; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
simdgroup_half8x8 mm;
|
mqk[j] += simd_shuffle_down(mqk[j], 16);
|
||||||
simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false);
|
mqk[j] += simd_shuffle_down(mqk[j], 8);
|
||||||
simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm);
|
mqk[j] += simd_shuffle_down(mqk[j], 4);
|
||||||
|
mqk[j] += simd_shuffle_down(mqk[j], 2);
|
||||||
|
mqk[j] += simd_shuffle_down(mqk[j], 1);
|
||||||
|
}
|
||||||
|
|
||||||
simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false);
|
// mqk = mqk*scale + mask
|
||||||
|
if (tiisg == 0) {
|
||||||
|
for (short j = 0; j < Q; ++j) {
|
||||||
|
half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc];
|
||||||
|
mqk[j] = mqk[j]*mscale + mm;
|
||||||
|
|
||||||
|
ss4[j*T4 + cc] = mqk[j];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2701,26 +2666,26 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
ss[tiisg*T + C + tiisg] = ms[tiisg];
|
ss[tiisg*T + C + tiisg] = ms[tiisg];
|
||||||
}
|
}
|
||||||
|
|
||||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// O = diag(ms)*O
|
// O = diag(ms)*O
|
||||||
for (short j = 0; j < Q; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
//simdgroup_half8x8 mm;
|
|
||||||
//simdgroup_load(mm, ss + 8*j*T + C + 8*j, T, 0, false);
|
|
||||||
half mm(ss[j*T + C + j]);
|
half mm(ss[j*T + C + j]);
|
||||||
|
|
||||||
for (short i = tiisg; i < D4; i += NW) {
|
#pragma unroll
|
||||||
//simdgroup_multiply(lo[j][i], mm, lo[j][i]);
|
for (short ii = 0; ii < D4; ii += NW) {
|
||||||
|
const short i = ii + tiisg;
|
||||||
lo[j][i/NW] = lo[j][i/NW]*mm;
|
lo[j][i/NW] = lo[j][i/NW]*mm;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// O = O + (Q*K^T)*V
|
// O = O + (Q*K^T)*V
|
||||||
{
|
{
|
||||||
|
#pragma unroll
|
||||||
for (short cc = 0; cc < C; ++cc) {
|
for (short cc = 0; cc < C; ++cc) {
|
||||||
device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23));
|
device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||||
|
|
||||||
for (short i = tiisg; i < D4; i += NW) {
|
#pragma unroll
|
||||||
|
for (short ii = 0; ii < D4; ii += NW) {
|
||||||
|
short i = ii + tiisg;
|
||||||
for (short j = 0; j < Q; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
lo[j][i/NW] += pv4[i]*ss[j*T + cc];
|
lo[j][i/NW] += pv4[i]*ss[j*T + cc];
|
||||||
}
|
}
|
||||||
|
@ -2738,15 +2703,16 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// store results to shared memory
|
// store results to shared memory
|
||||||
for (short j = 0; j < Q; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
for (short i = tiisg; i < D4; i += NW) {
|
for (short ii = 0; ii < D4; ii += NW) {
|
||||||
sr4[i] = lo[j][i/NW];
|
short i = ii + tiisg;
|
||||||
|
sr4[i] = lo[j][ii/NW];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
// parallel reduce
|
// parallel reduce
|
||||||
for (short r = nsg/2; r > 0; r >>= 1) {
|
for (short r = nsg/2; r > 0; r >>= 1) {
|
||||||
if (sgitg < r) {
|
if (sgitg < r) {
|
||||||
|
@ -2805,10 +2771,6 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 2, 32>;
|
|
||||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 3, 32>;
|
|
||||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 4, 32>;
|
|
||||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 5, 32>;
|
|
||||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 1, 32>;
|
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 1, 32>;
|
||||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 1, 32>;
|
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 1, 32>;
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue