metal : opts

This commit is contained in:
Georgi Gerganov 2024-04-05 13:57:54 +03:00
parent 5eab7454dd
commit 8d2a61f068
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 5 additions and 31 deletions

View file

@ -2616,7 +2616,6 @@ static enum ggml_status ggml_metal_graph_compute(
// simdgroups per threadgroup (a.k.a. warps)
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
const int64_t nsg = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
//const int64_t nsg = 1;
const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);

View file

@ -2529,7 +2529,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
// zero out lo
for (short j = 0; j < Q; ++j) {
for (short i = 0; i < D4; ++i) {
for (short i = tiisg; i < D4; i += NW) {
lo[j][i] = 0.0h;
}
}
@ -2648,10 +2648,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// used to detect blocks full of -INF
half smax = -INFINITY;
//threadgroup_barrier(mem_flags::mem_threadgroup);
// online softmax
if (C == 32) {
@ -2663,7 +2660,6 @@ kernel void kernel_flash_attn_ext_vec_f16(
const half m = M[j];
const half s = ss[j*T + p];
smax = simd_max(max(smax, s));
M[j] = simd_max(max(M[j], s));
ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);
@ -2688,11 +2684,9 @@ kernel void kernel_flash_attn_ext_vec_f16(
for (short p = tiisg; p < C; p += NW) {
const half s = ss[j*T + p];
smax = max(smax, s);
M[j] = max(M[j], s);
}
smax = simd_max(smax);
M[j] = simd_max(M[j]);
ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);
@ -2720,12 +2714,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
}
}
// skip -INF blocks
if (smax == -INFINITY) {
continue;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
//threadgroup_barrier(mem_flags::mem_threadgroup);
// O = diag(ms)*O
for (short j = 0; j < Q; ++j) {
@ -2742,26 +2731,12 @@ kernel void kernel_flash_attn_ext_vec_f16(
// O = O + (Q*K^T)*V
{
for (short cc = 0; cc < C; ++cc) {
//device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23));
//for (short i = 0; i < D8; ++i) {
// simdgroup_half8x8 mk;
// simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
// for (short j = 0; j < Q8; ++j) {
// simdgroup_half8x8 mv;
// simdgroup_load(mv, ss + 8*j*T + 8*cc, T, 0, false);
// simdgroup_multiply_accumulate(lo[j][i], mv, mk, lo[j][i]);
// }
//}
half vsum[Q];
for (short i = tiisg; i < D4; i += NW) {
half4 mk = pv4[i];
for (short j = 0; j < Q; ++j) {
lo[j][i] += mk*ss[j*T + cc];
lo[j][i] += pv4[i]*ss[j*T + cc];
}
}
}