metal : opts
This commit is contained in:
parent
5eab7454dd
commit
8d2a61f068
2 changed files with 5 additions and 31 deletions
|
@ -2616,7 +2616,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||||
// simdgroups per threadgroup (a.k.a. warps)
|
// simdgroups per threadgroup (a.k.a. warps)
|
||||||
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
|
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
|
||||||
const int64_t nsg = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
|
const int64_t nsg = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
|
||||||
//const int64_t nsg = 1;
|
|
||||||
|
|
||||||
const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);
|
const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);
|
||||||
|
|
||||||
|
|
|
@ -2529,7 +2529,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
|
|
||||||
// zero out lo
|
// zero out lo
|
||||||
for (short j = 0; j < Q; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
for (short i = 0; i < D4; ++i) {
|
for (short i = tiisg; i < D4; i += NW) {
|
||||||
lo[j][i] = 0.0h;
|
lo[j][i] = 0.0h;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2648,10 +2648,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
// used to detect blocks full of -INF
|
|
||||||
half smax = -INFINITY;
|
|
||||||
|
|
||||||
// online softmax
|
// online softmax
|
||||||
if (C == 32) {
|
if (C == 32) {
|
||||||
|
@ -2663,7 +2660,6 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
const half m = M[j];
|
const half m = M[j];
|
||||||
const half s = ss[j*T + p];
|
const half s = ss[j*T + p];
|
||||||
|
|
||||||
smax = simd_max(max(smax, s));
|
|
||||||
M[j] = simd_max(max(M[j], s));
|
M[j] = simd_max(max(M[j], s));
|
||||||
|
|
||||||
ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);
|
ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);
|
||||||
|
@ -2688,11 +2684,9 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
for (short p = tiisg; p < C; p += NW) {
|
for (short p = tiisg; p < C; p += NW) {
|
||||||
const half s = ss[j*T + p];
|
const half s = ss[j*T + p];
|
||||||
|
|
||||||
smax = max(smax, s);
|
|
||||||
M[j] = max(M[j], s);
|
M[j] = max(M[j], s);
|
||||||
}
|
}
|
||||||
|
|
||||||
smax = simd_max(smax);
|
|
||||||
M[j] = simd_max(M[j]);
|
M[j] = simd_max(M[j]);
|
||||||
|
|
||||||
ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);
|
ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);
|
||||||
|
@ -2720,12 +2714,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// skip -INF blocks
|
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
if (smax == -INFINITY) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// O = diag(ms)*O
|
// O = diag(ms)*O
|
||||||
for (short j = 0; j < Q; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
|
@ -2742,26 +2731,12 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
// O = O + (Q*K^T)*V
|
// O = O + (Q*K^T)*V
|
||||||
{
|
{
|
||||||
for (short cc = 0; cc < C; ++cc) {
|
for (short cc = 0; cc < C; ++cc) {
|
||||||
//device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
|
||||||
device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23));
|
device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||||
|
|
||||||
//for (short i = 0; i < D8; ++i) {
|
half vsum[Q];
|
||||||
// simdgroup_half8x8 mk;
|
|
||||||
// simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
|
|
||||||
|
|
||||||
// for (short j = 0; j < Q8; ++j) {
|
|
||||||
// simdgroup_half8x8 mv;
|
|
||||||
// simdgroup_load(mv, ss + 8*j*T + 8*cc, T, 0, false);
|
|
||||||
|
|
||||||
// simdgroup_multiply_accumulate(lo[j][i], mv, mk, lo[j][i]);
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
|
|
||||||
for (short i = tiisg; i < D4; i += NW) {
|
for (short i = tiisg; i < D4; i += NW) {
|
||||||
half4 mk = pv4[i];
|
|
||||||
|
|
||||||
for (short j = 0; j < Q; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
lo[j][i] += mk*ss[j*T + cc];
|
lo[j][i] += pv4[i]*ss[j*T + cc];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue