metal : support more than 1 warps
This commit is contained in:
parent
d15898481a
commit
5eab7454dd
2 changed files with 46 additions and 52 deletions
|
@ -2615,8 +2615,8 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||||
|
|
||||||
// simdgroups per threadgroup (a.k.a. warps)
|
// simdgroups per threadgroup (a.k.a. warps)
|
||||||
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
|
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
|
||||||
//const int64_t nsg = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
|
const int64_t nsg = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
|
||||||
const int64_t nsg = 1;
|
//const int64_t nsg = 1;
|
||||||
|
|
||||||
const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);
|
const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);
|
||||||
|
|
||||||
|
|
|
@ -2608,9 +2608,8 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
|
|
||||||
for (short i = tiisg; i < D4; i += NW) {
|
for (short i = tiisg; i < D4; i += NW) {
|
||||||
//simdgroup_half8x8 mk;
|
//simdgroup_half8x8 mk;
|
||||||
half4 mk;
|
|
||||||
//simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
|
//simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
|
||||||
mk = pk4[i];
|
half4 mk = pk4[i];
|
||||||
|
|
||||||
for (short j = 0; j < Q; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
//simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]);
|
//simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]);
|
||||||
|
@ -2779,66 +2778,61 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
}
|
}
|
||||||
|
|
||||||
// reduce the warps sequentially
|
// reduce the warps sequentially
|
||||||
//for (short sg = 1; sg < nsg; ++sg) {
|
for (short sg = 1; sg < nsg; ++sg) {
|
||||||
// half S = { 0.0h };
|
half S = { 0.0h };
|
||||||
// half M = { -INFINITY };
|
half M = { -INFINITY };
|
||||||
|
|
||||||
// threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
// // each simdgroup stores its output to shared memory, reusing sq
|
// each simdgroup stores its output to shared memory, reusing sq
|
||||||
// if (sgitg == sg) {
|
if (sgitg == sg) {
|
||||||
// for (short j = 0; j < Q8; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
// for (short i = 0; i < D8; ++i) {
|
for (short i = tiisg; i < D4; i += NW) {
|
||||||
// simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false);
|
//simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false);
|
||||||
// }
|
sq4[j*T4 + i] = lo[j][i];
|
||||||
// }
|
}
|
||||||
// }
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
// // the first simdgroup accumulates the results from the other simdgroups
|
// the first simdgroup accumulates the results from the other simdgroups
|
||||||
// if (sgitg == 0) {
|
if (sgitg == 0) {
|
||||||
// for (short j = 0; j < Q; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
// const half S0 = ss[j*T + 0];
|
const half S0 = ss[j*T + 0];
|
||||||
// const half S1 = ss[j*T + sg*SH + 0];
|
const half S1 = ss[j*T + sg*SH + 0];
|
||||||
|
|
||||||
// const half M0 = ss[j*T + 1];
|
const half M0 = ss[j*T + 1];
|
||||||
// const half M1 = ss[j*T + sg*SH + 1];
|
const half M1 = ss[j*T + sg*SH + 1];
|
||||||
|
|
||||||
// M = max(M0, M1);
|
M = max(M0, M1);
|
||||||
|
|
||||||
// const half ms0 = M0 == -INFINITY ? 0.0h : exp(M0 - M);
|
const half ms0 = M0 == -INFINITY ? 0.0h : exp(M0 - M);
|
||||||
// const half ms1 = M1 == -INFINITY ? 0.0h : exp(M1 - M);
|
const half ms1 = M1 == -INFINITY ? 0.0h : exp(M1 - M);
|
||||||
|
|
||||||
// S = S0*ms0 + S1*ms1;
|
S = S0*ms0 + S1*ms1;
|
||||||
|
|
||||||
// if (tiisg == 0) {
|
if (tiisg == 0) {
|
||||||
// ss[j*T + 0] = S;
|
ss[j*T + 0] = S;
|
||||||
// ss[j*T + 1] = M;
|
ss[j*T + 1] = M;
|
||||||
|
|
||||||
// ss[j*T + C + j ] = ms0;
|
ss[j*T + C + j ] = ms0;
|
||||||
// ss[j*T + C + j + sg*SH] = ms1;
|
ss[j*T + C + j + sg*SH] = ms1;
|
||||||
// }
|
}
|
||||||
// }
|
}
|
||||||
|
|
||||||
// // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
||||||
// for (short j = 0; j < Q8; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
// simdgroup_half8x8 t;
|
for (short i = tiisg; i < D4; i += NW) {
|
||||||
// simdgroup_half8x8 ms0;
|
half4 t = sq4[j*T4 + i];
|
||||||
// simdgroup_half8x8 ms1;
|
half ms0 = ss[j*T + C + j];
|
||||||
|
half ms1 = ss[j*T + C + j + sg*SH];
|
||||||
|
|
||||||
// simdgroup_load(ms0, ss + 8*j*T + C + 8*j, T, 0, false);
|
lo[j][i] = lo[j][i]*ms0 + t*ms1;
|
||||||
// simdgroup_load(ms1, ss + 8*j*T + C + 8*j + sg*SH, T, 0, false);
|
}
|
||||||
|
}
|
||||||
// for (short i = 0; i < D8; ++i) {
|
}
|
||||||
// simdgroup_load (t, sq + 8*j*T + i*8, T, 0, false);
|
}
|
||||||
// simdgroup_multiply(t, ms1, t);
|
|
||||||
|
|
||||||
// simdgroup_multiply_accumulate(lo[j][i], ms0, lo[j][i], t);
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
|
|
||||||
// store result to shared memory (reuse sq)
|
// store result to shared memory (reuse sq)
|
||||||
if (sgitg == 0) {
|
if (sgitg == 0) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue