metal : opt

This commit is contained in:
Georgi Gerganov 2024-04-05 14:26:28 +03:00
parent 8d2a61f068
commit 5733b00e53
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -2581,7 +2581,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
} }
// pointer to the mask // pointer to the mask
device const half * mp = (device const half *) (mask + iq1*nb31); device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
// prepare diagonal scale matrix // prepare diagonal scale matrix
//simdgroup_half8x8 mscale(scale); //simdgroup_half8x8 mscale(scale);
@ -2597,23 +2597,23 @@ kernel void kernel_flash_attn_ext_vec_f16(
// Q*K^T // Q*K^T
{ {
for (short cc = 0; cc < C; ++cc) { for (short cc = 0; cc < C/4; ++cc) {
half mqk[Q]; half4 mqk[Q];
for (short j = 0; j < Q; ++j) { for (short j = 0; j < Q; ++j) {
mqk[j] = 0.0h; mqk[j] = 0.0h;
} }
//device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + cc)*nb11 + ik2*nb12 + ik3*nb13));
for (short i = tiisg; i < D4; i += NW) { for (short i = tiisg; i < D4; i += NW) {
//simdgroup_half8x8 mk; half4x4 mk;
//simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose mk[0] = pk4[i + 0*(nb11/8)];
half4 mk = pk4[i]; mk[1] = pk4[i + 1*(nb11/8)];
mk[2] = pk4[i + 2*(nb11/8)];
mk[3] = pk4[i + 3*(nb11/8)];
for (short j = 0; j < Q; ++j) { for (short j = 0; j < Q; ++j) {
//simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]); mqk[j] += mq[j][i] * mk;
mqk[j] += dot(mq[j][i], mk);
} }
} }
@ -2633,85 +2633,40 @@ kernel void kernel_flash_attn_ext_vec_f16(
// mqk = mqk*scale + mask // mqk = mqk*scale + mask
if (tiisg == 0) { if (tiisg == 0) {
for (short j = 0; j < Q; ++j) { for (short j = 0; j < Q; ++j) {
//simdgroup_half8x8 mm; half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc];
//simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false);
//simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm);
//simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false);
half mm = mp[j*(nb31/sizeof(half)) + ic + cc];
mqk[j] = mqk[j]*mscale + mm; mqk[j] = mqk[j]*mscale + mm;
ss[j*T + cc] = mqk[j]; ss4[j*T4 + cc] = mqk[j];
} }
} }
} }
} }
//threadgroup_barrier(mem_flags::mem_threadgroup); simdgroup_barrier(mem_flags::mem_threadgroup);
// online softmax // online softmax
if (C == 32) { half ms[Q];
half ms[Q];
for (short j = 0; j < Q; ++j) { for (short j = 0; j < Q; ++j) {
const short p = tiisg; const short p = tiisg;
const half m = M[j]; const half m = M[j];
const half s = ss[j*T + p]; const half s = ss[j*T + p];
M[j] = simd_max(max(M[j], s)); M[j] = simd_max(max(M[j], s));
ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
S[j] = S[j]*ms[j] + simd_sum(vs); S[j] = S[j]*ms[j] + simd_sum(vs);
// the P matrix from the paper (Q rows, C columns) // the P matrix from the paper (Q rows, C columns)
ss[j*T + p] = vs; ss[j*T + p] = vs;
} }
// create a QxQ diagonal matrix for rescaling the output // create a QxQ diagonal matrix for rescaling the output
if (tiisg < Q) { if (tiisg < Q) {
ss[tiisg*T + C + tiisg] = ms[tiisg]; ss[tiisg*T + C + tiisg] = ms[tiisg];
}
} else {
half ms[Q];
for (short j = 0; j < Q; ++j) {
const half m = M[j];
for (short p = tiisg; p < C; p += NW) {
const half s = ss[j*T + p];
M[j] = max(M[j], s);
}
M[j] = simd_max(M[j]);
ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);
// local sum
half ls = 0.0h;
for (short p = tiisg; p < C; p += NW) {
const half s = ss[j*T + p];
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
ls += vs;
// the P matrix from the paper (Q rows, C columns)
ss[j*T + p] = vs;
}
S[j] = S[j]*ms[j] + simd_sum(ls);
}
// create a QxQ diagonal matrix for rescaling the output
if (tiisg < Q) {
ss[tiisg*T + C + tiisg] = ms[tiisg];
}
} }
//threadgroup_barrier(mem_flags::mem_threadgroup); //threadgroup_barrier(mem_flags::mem_threadgroup);
@ -2733,7 +2688,6 @@ kernel void kernel_flash_attn_ext_vec_f16(
for (short cc = 0; cc < C; ++cc) { for (short cc = 0; cc < C; ++cc) {
device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23)); device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23));
half vsum[Q];
for (short i = tiisg; i < D4; i += NW) { for (short i = tiisg; i < D4; i += NW) {
for (short j = 0; j < Q; ++j) { for (short j = 0; j < Q; ++j) {
lo[j][i] += pv4[i]*ss[j*T + cc]; lo[j][i] += pv4[i]*ss[j*T + cc];