wip : disable skip
This commit is contained in:
parent
806382a3a6
commit
eb12e3c391
2 changed files with 52 additions and 72 deletions
|
@ -2253,12 +2253,12 @@ static bool ggml_metal_graph_compute(
|
|||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
|
||||
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
||||
|
||||
const int64_t nsg = 8; // simdgroups per threadgroup (a.k.a. warps)
|
||||
const int64_t nsg = 4; // simdgroups per threadgroup (a.k.a. warps)
|
||||
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
||||
const int64_t ncpsg = 8;
|
||||
|
||||
//const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2);
|
||||
const size_t smem = nqptg*(ne00 + nsg*(2*ncpsg))*(sizeof(float)/2);
|
||||
const size_t smem = nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2);
|
||||
|
||||
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
|
||||
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
||||
|
|
112
ggml-metal.metal
112
ggml-metal.metal
|
@ -2046,15 +2046,16 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
const int64_t L4 = (D4 + N4 - 1)/N4;
|
||||
const int64_t D8 = D/8;
|
||||
|
||||
const int64_t T = D + nsg*(2*C); // shared memory size per query in half
|
||||
const int64_t T = D + nsg*(D + 1*C); // shared memory size per query in half
|
||||
const int64_t T4 = T/4; // shared memory size per query in half4
|
||||
|
||||
threadgroup half * pq = (threadgroup half *) (shared + 0*D);
|
||||
threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*D);
|
||||
threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*C) + 1*D);
|
||||
//threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*C) + 1*D);
|
||||
threadgroup half * ps = (threadgroup half *) (shared + sgitg*(D + 1*C) + 1*D);
|
||||
threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(D + 1*C) + 1*D);
|
||||
threadgroup half * ss = (threadgroup half *) (shared + sgitg*(D + 1*C) + 2*D);
|
||||
|
||||
half4 ps4[Q][L4];
|
||||
half4 ls4[Q][L4];
|
||||
|
||||
// load heads from Q to shared memory
|
||||
for (int64_t i = 0; i < L4; ++i) {
|
||||
|
@ -2067,11 +2068,11 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
}
|
||||
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
ps4[j][i] = 0.0h;
|
||||
ls4[j][i] = 0.0h;
|
||||
}
|
||||
}
|
||||
|
||||
if (tiisg < 2) {
|
||||
if (tiisg < 1) {
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
ss[j*T + tiisg] = 0.0h;
|
||||
}
|
||||
|
@ -2127,15 +2128,15 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
}
|
||||
|
||||
for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) {
|
||||
{
|
||||
bool skip = true;
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
skip = skip && (mp[j][iic] == -INFINITY);
|
||||
}
|
||||
if (skip) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
//{
|
||||
// bool skip = true;
|
||||
// for (int64_t j = 0; j < Q; ++j) {
|
||||
// skip = skip && (mp[j][iic] == -INFINITY);
|
||||
// }
|
||||
// if (skip) {
|
||||
// continue;
|
||||
// }
|
||||
//}
|
||||
|
||||
{
|
||||
simdgroup_half8x8 mk;
|
||||
|
@ -2152,33 +2153,13 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
simdgroup_store(mqk, ss, T, 0, false);
|
||||
}
|
||||
|
||||
//if (tiisg < Q) {
|
||||
// const int64_t j = tiisg;
|
||||
|
||||
// for (int p = 0; p < C; ++p) {
|
||||
// const half s = ss[j*T + p + 0]*scale + (mp[j][iic + p]);
|
||||
|
||||
// const half m = M;
|
||||
|
||||
// M = max(M, s);
|
||||
|
||||
// const half ms = m == -INFINITY ? 0.0h : exp(m - M);
|
||||
// const half vs = s == -INFINITY ? 0.0h : exp(s - M);
|
||||
|
||||
// S = S*ms + vs;
|
||||
|
||||
// ss[j*T + 0 + p] = ms;
|
||||
// ss[j*T + C + p] = vs;
|
||||
// }
|
||||
//}
|
||||
|
||||
// not sure why this barrier is needed
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
const int64_t p = tiisg % C;
|
||||
|
||||
const half s = ss[j*T + p + 0]*scale + (mp[j][iic + p]);
|
||||
const half s = ss[j*T + p]*scale + (mp[j][iic + p]);
|
||||
|
||||
half m = M[j];
|
||||
|
||||
|
@ -2187,45 +2168,44 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]);
|
||||
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
|
||||
|
||||
S[j] = S[j]*ms + 0.25h*simd_sum(vs);
|
||||
S[j] = S[j]*ms + 0.25h*simd_sum(vs); // 4*8 = 32
|
||||
|
||||
for (int64_t i = 0; i < L4; ++i) {
|
||||
ps4[j][i] *= ms;
|
||||
ls4[j][i] *= ms;
|
||||
}
|
||||
|
||||
if (tiisg < C) {
|
||||
ss[j*T + C + p] = vs;
|
||||
ss[j*T + p] = vs;
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t p = 0; p < C; ++p) {
|
||||
const int64_t ic = iic + p;
|
||||
device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23));
|
||||
{
|
||||
simdgroup_half8x8 mv;
|
||||
simdgroup_half8x8 mp;
|
||||
simdgroup_half8x8 mqkv;
|
||||
|
||||
device const half * pv = (device const half *) ((device const char *) v + (iic*nb21 + iv2*nb22 + iv3*nb23));
|
||||
|
||||
// load mp
|
||||
simdgroup_load(mp, ss, T, 0, false);
|
||||
|
||||
for (int64_t i = 0; i < D8; ++i) {
|
||||
simdgroup_load (mv, pv + i*8, nb21/2, 0, false);
|
||||
simdgroup_multiply(mqkv, mp, mv);
|
||||
simdgroup_store (mqkv, ps + i*8, T, 0, false);
|
||||
}
|
||||
}
|
||||
|
||||
// not sure why this barrier is needed too
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
const half vs = ss[j*T + C + p];
|
||||
|
||||
for (int64_t i = 0; i < L4; ++i) {
|
||||
ps4[j][i] += pv4[N4*i + tiisg]*vs;
|
||||
ls4[j][i] += ps4[j*T4 + N4*i + tiisg];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//for (int p = 0; p < C; ++p) {
|
||||
// const int64_t ic = iic + p;
|
||||
// device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23));
|
||||
|
||||
// for (int64_t j = 0; j < Q; ++j) {
|
||||
// const half ms = ss[j*T + 0 + p];
|
||||
// const half vs = ss[j*T + C + p];
|
||||
|
||||
// for (int64_t i = 0; i < L4; ++i) {
|
||||
// ps4[j][i] = ps4[j][i]*ms + pv4[N4*i + tiisg]*vs;
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
}
|
||||
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
if (tiisg == 0) {
|
||||
ss[j*T + 0] = S[j];
|
||||
|
@ -2246,7 +2226,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
// store heads to shared memory - reuse pq4
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
for (int64_t i = 0; i < L4; ++i) {
|
||||
pq4[j*T4 + N4*i + tiisg] = ps4[j][i];
|
||||
pq4[j*T4 + N4*i + tiisg] = ls4[j][i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2256,10 +2236,10 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
if (sgitg == 0) {
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
const half S0 = ss[j*T + 0];
|
||||
const half S1 = ss[j*T + sg*(2*C) + 0];
|
||||
const half S1 = ss[j*T + sg*(D + 1*C) + 0];
|
||||
|
||||
const half M0 = ss[j*T + 1];
|
||||
const half M1 = ss[j*T + sg*(2*C) + 1];
|
||||
const half M1 = ss[j*T + sg*(D + 1*C) + 1];
|
||||
|
||||
M = max(M0, M1);
|
||||
|
||||
|
@ -2274,7 +2254,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
}
|
||||
|
||||
for (int64_t i = 0; i < L4; ++i) {
|
||||
ps4[j][i] = ps4[j][i]*ms0 + pq4[j*T4 + N4*i + tiisg]*ms1;
|
||||
ls4[j][i] = ls4[j][i]*ms0 + pq4[j*T4 + N4*i + tiisg]*ms1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2286,7 +2266,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
for (int64_t j = 0; j < Q; ++j) {
|
||||
const half S = ss[j*T + 0];
|
||||
for (int64_t i = 0; i < L4; ++i) {
|
||||
ps4[j][i] = ps4[j][i]/S;
|
||||
ls4[j][i] = ls4[j][i]/S;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2298,7 +2278,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
if (sgitg == 0) {
|
||||
for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
|
||||
for (int64_t i = 0; i < L4; ++i) {
|
||||
dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + N4*i + tiisg] = (float4) ps4[j][i];
|
||||
dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + N4*i + tiisg] = (float4) ls4[j][i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue