metal : opt
This commit is contained in:
parent
8d2a61f068
commit
5733b00e53
1 changed files with 28 additions and 74 deletions
102
ggml-metal.metal
102
ggml-metal.metal
|
@ -2581,7 +2581,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
}
|
}
|
||||||
|
|
||||||
// pointer to the mask
|
// pointer to the mask
|
||||||
device const half * mp = (device const half *) (mask + iq1*nb31);
|
device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
|
||||||
|
|
||||||
// prepare diagonal scale matrix
|
// prepare diagonal scale matrix
|
||||||
//simdgroup_half8x8 mscale(scale);
|
//simdgroup_half8x8 mscale(scale);
|
||||||
|
@ -2597,23 +2597,23 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
|
|
||||||
// Q*K^T
|
// Q*K^T
|
||||||
{
|
{
|
||||||
for (short cc = 0; cc < C; ++cc) {
|
for (short cc = 0; cc < C/4; ++cc) {
|
||||||
half mqk[Q];
|
half4 mqk[Q];
|
||||||
for (short j = 0; j < Q; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
mqk[j] = 0.0h;
|
mqk[j] = 0.0h;
|
||||||
}
|
}
|
||||||
|
|
||||||
//device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||||
device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + cc)*nb11 + ik2*nb12 + ik3*nb13));
|
|
||||||
|
|
||||||
for (short i = tiisg; i < D4; i += NW) {
|
for (short i = tiisg; i < D4; i += NW) {
|
||||||
//simdgroup_half8x8 mk;
|
half4x4 mk;
|
||||||
//simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
|
mk[0] = pk4[i + 0*(nb11/8)];
|
||||||
half4 mk = pk4[i];
|
mk[1] = pk4[i + 1*(nb11/8)];
|
||||||
|
mk[2] = pk4[i + 2*(nb11/8)];
|
||||||
|
mk[3] = pk4[i + 3*(nb11/8)];
|
||||||
|
|
||||||
for (short j = 0; j < Q; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
//simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]);
|
mqk[j] += mq[j][i] * mk;
|
||||||
mqk[j] += dot(mq[j][i], mk);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2633,85 +2633,40 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
// mqk = mqk*scale + mask
|
// mqk = mqk*scale + mask
|
||||||
if (tiisg == 0) {
|
if (tiisg == 0) {
|
||||||
for (short j = 0; j < Q; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
//simdgroup_half8x8 mm;
|
half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc];
|
||||||
//simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false);
|
|
||||||
//simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm);
|
|
||||||
|
|
||||||
//simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false);
|
|
||||||
|
|
||||||
half mm = mp[j*(nb31/sizeof(half)) + ic + cc];
|
|
||||||
mqk[j] = mqk[j]*mscale + mm;
|
mqk[j] = mqk[j]*mscale + mm;
|
||||||
|
|
||||||
ss[j*T + cc] = mqk[j];
|
ss4[j*T4 + cc] = mqk[j];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
// online softmax
|
// online softmax
|
||||||
if (C == 32) {
|
half ms[Q];
|
||||||
half ms[Q];
|
|
||||||
|
|
||||||
for (short j = 0; j < Q; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
const short p = tiisg;
|
const short p = tiisg;
|
||||||
|
|
||||||
const half m = M[j];
|
const half m = M[j];
|
||||||
const half s = ss[j*T + p];
|
const half s = ss[j*T + p];
|
||||||
|
|
||||||
M[j] = simd_max(max(M[j], s));
|
M[j] = simd_max(max(M[j], s));
|
||||||
|
|
||||||
ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);
|
ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);
|
||||||
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
|
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
|
||||||
|
|
||||||
S[j] = S[j]*ms[j] + simd_sum(vs);
|
S[j] = S[j]*ms[j] + simd_sum(vs);
|
||||||
|
|
||||||
// the P matrix from the paper (Q rows, C columns)
|
// the P matrix from the paper (Q rows, C columns)
|
||||||
ss[j*T + p] = vs;
|
ss[j*T + p] = vs;
|
||||||
}
|
}
|
||||||
|
|
||||||
// create a QxQ diagonal matrix for rescaling the output
|
// create a QxQ diagonal matrix for rescaling the output
|
||||||
if (tiisg < Q) {
|
if (tiisg < Q) {
|
||||||
ss[tiisg*T + C + tiisg] = ms[tiisg];
|
ss[tiisg*T + C + tiisg] = ms[tiisg];
|
||||||
}
|
|
||||||
} else {
|
|
||||||
half ms[Q];
|
|
||||||
|
|
||||||
for (short j = 0; j < Q; ++j) {
|
|
||||||
const half m = M[j];
|
|
||||||
|
|
||||||
for (short p = tiisg; p < C; p += NW) {
|
|
||||||
const half s = ss[j*T + p];
|
|
||||||
|
|
||||||
M[j] = max(M[j], s);
|
|
||||||
}
|
|
||||||
|
|
||||||
M[j] = simd_max(M[j]);
|
|
||||||
|
|
||||||
ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);
|
|
||||||
|
|
||||||
// local sum
|
|
||||||
half ls = 0.0h;
|
|
||||||
|
|
||||||
for (short p = tiisg; p < C; p += NW) {
|
|
||||||
const half s = ss[j*T + p];
|
|
||||||
|
|
||||||
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
|
|
||||||
|
|
||||||
ls += vs;
|
|
||||||
|
|
||||||
// the P matrix from the paper (Q rows, C columns)
|
|
||||||
ss[j*T + p] = vs;
|
|
||||||
}
|
|
||||||
|
|
||||||
S[j] = S[j]*ms[j] + simd_sum(ls);
|
|
||||||
}
|
|
||||||
|
|
||||||
// create a QxQ diagonal matrix for rescaling the output
|
|
||||||
if (tiisg < Q) {
|
|
||||||
ss[tiisg*T + C + tiisg] = ms[tiisg];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
@ -2733,7 +2688,6 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
for (short cc = 0; cc < C; ++cc) {
|
for (short cc = 0; cc < C; ++cc) {
|
||||||
device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23));
|
device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||||
|
|
||||||
half vsum[Q];
|
|
||||||
for (short i = tiisg; i < D4; i += NW) {
|
for (short i = tiisg; i < D4; i += NW) {
|
||||||
for (short j = 0; j < Q; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
lo[j][i] += pv4[i]*ss[j*T + cc];
|
lo[j][i] += pv4[i]*ss[j*T + cc];
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue