diff --git a/ggml-metal.m b/ggml-metal.m index bf6277d38..d942b673f 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2615,8 +2615,8 @@ static enum ggml_status ggml_metal_graph_compute( // simdgroups per threadgroup (a.k.a. warps) // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - //const int64_t nsg = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); - const int64_t nsg = 1; + const int64_t nsg = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); + //const int64_t nsg = 1; const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); diff --git a/ggml-metal.metal b/ggml-metal.metal index e4be0f69e..d7ce10274 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2608,9 +2608,8 @@ kernel void kernel_flash_attn_ext_vec_f16( for (short i = tiisg; i < D4; i += NW) { //simdgroup_half8x8 mk; - half4 mk; //simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose - mk = pk4[i]; + half4 mk = pk4[i]; for (short j = 0; j < Q; ++j) { //simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]); @@ -2779,66 +2778,61 @@ kernel void kernel_flash_attn_ext_vec_f16( } // reduce the warps sequentially - //for (short sg = 1; sg < nsg; ++sg) { - // half S = { 0.0h }; - // half M = { -INFINITY }; + for (short sg = 1; sg < nsg; ++sg) { + half S = { 0.0h }; + half M = { -INFINITY }; - // threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); - // // each simdgroup stores its output to shared memory, reusing sq - // if (sgitg == sg) { - // for (short j = 0; j < Q8; ++j) { - // for (short i = 0; i < D8; ++i) { - // simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); - // } - // } - // } + // each simdgroup stores its output to shared memory, reusing sq + if (sgitg == sg) { + for (short j = 0; j < Q; ++j) { + for (short i = tiisg; i < D4; i += NW) { + //simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); + sq4[j*T4 + i] = lo[j][i]; + } + } + } - // threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); - // // the first simdgroup accumulates the results from the other simdgroups - // if (sgitg == 0) { - // for (short j = 0; j < Q; ++j) { - // const half S0 = ss[j*T + 0]; - // const half S1 = ss[j*T + sg*SH + 0]; + // the first simdgroup accumulates the results from the other simdgroups + if (sgitg == 0) { + for (short j = 0; j < Q; ++j) { + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + sg*SH + 0]; - // const half M0 = ss[j*T + 1]; - // const half M1 = ss[j*T + sg*SH + 1]; + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + sg*SH + 1]; - // M = max(M0, M1); + M = max(M0, M1); - // const half ms0 = M0 == -INFINITY ? 0.0h : exp(M0 - M); - // const half ms1 = M1 == -INFINITY ? 0.0h : exp(M1 - M); + const half ms0 = M0 == -INFINITY ? 0.0h : exp(M0 - M); + const half ms1 = M1 == -INFINITY ? 0.0h : exp(M1 - M); - // S = S0*ms0 + S1*ms1; + S = S0*ms0 + S1*ms1; - // if (tiisg == 0) { - // ss[j*T + 0] = S; - // ss[j*T + 1] = M; + if (tiisg == 0) { + ss[j*T + 0] = S; + ss[j*T + 1] = M; - // ss[j*T + C + j ] = ms0; - // ss[j*T + C + j + sg*SH] = ms1; - // } - // } + ss[j*T + C + j ] = ms0; + ss[j*T + C + j + sg*SH] = ms1; + } + } - // // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - // for (short j = 0; j < Q8; ++j) { - // simdgroup_half8x8 t; - // simdgroup_half8x8 ms0; - // simdgroup_half8x8 ms1; + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (short j = 0; j < Q; ++j) { + for (short i = tiisg; i < D4; i += NW) { + half4 t = sq4[j*T4 + i]; + half ms0 = ss[j*T + C + j]; + half ms1 = ss[j*T + C + j + sg*SH]; - // simdgroup_load(ms0, ss + 8*j*T + C + 8*j, T, 0, false); - // simdgroup_load(ms1, ss + 8*j*T + C + 8*j + sg*SH, T, 0, false); - - // for (short i = 0; i < D8; ++i) { - // simdgroup_load (t, sq + 8*j*T + i*8, T, 0, false); - // simdgroup_multiply(t, ms1, t); - - // simdgroup_multiply_accumulate(lo[j][i], ms0, lo[j][i], t); - // } - // } - // } - //} + lo[j][i] = lo[j][i]*ms0 + t*ms1; + } + } + } + } // store result to shared memory (reuse sq) if (sgitg == 0) {