metal : use F32 attention accumulators

Georgi Gerganov 2024-04-18 20:08:52 +03:00
parent fa9e8c6689
commit c16a7c2688
3 changed files with 81 additions and 93 deletions

ggml-metal.m

@@ -2636,10 +2636,9 @@ static enum ggml_status ggml_metal_graph_compute(
 GGML_ASSERT(ncpsg % 32 == 0);
 // simdgroups per threadgroup (a.k.a. warps)
-// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
 const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4;
-const size_t smem = nqptg*(ne00 +   nsg*(ncpsg + nqptg))*(sizeof(float)/2);
+const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
 //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
 GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
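For intuition: smem is counted in half-sized slots (the trailing sizeof(float)/2 is just sizeof(half) = 2 bytes), and the per-simdgroup scratch of ncpsg + nqptg values now holds F32 instead of F16, so it occupies twice as many slots. A minimal host-side sketch in C; the concrete parameter values below are illustrative assumptions, not values from this diff:

    #include <stdio.h>

    int main(void) {
        // illustrative values, not from the diff
        const long ne00  = 128; // head size D
        const long nqptg = 8;   // queries per threadgroup (Q)
        const long ncpsg = 32;  // cache items per simdgroup (C)
        const long nsg   = 4;   // simdgroups per threadgroup

        // old: query rows + per-simdgroup scratch, all counted as half (2 bytes)
        const size_t smem_old = nqptg*(ne00 +   nsg*(ncpsg + nqptg))*(sizeof(float)/2);

        // new: the scratch is F32 now, so it takes two half slots per value
        const size_t smem_new = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);

        printf("smem old: %zu bytes, new: %zu bytes\n", smem_old, smem_new);
        return 0;
    }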
@@ -2656,7 +2655,6 @@ static enum ggml_status ggml_metal_graph_compute(
 GGML_ASSERT(ncpsg % 32 == 0);
 // simdgroups per threadgroup (a.k.a. warps)
-// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
 const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
 int64_t nsg = 1;
@@ -2665,16 +2663,7 @@ static enum ggml_status ggml_metal_graph_compute(
 }
 nsg /= 2;
-// require power of 2
-//{
-//    int64_t nsgm = 1;
-//    while (nsgm < nsg) {
-//        nsgm *= 2;
-//    }
-//    GGML_ASSERT(nsg == nsgm);
-//}
-
-const size_t smem = (nqptg*(ne00 +   nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
+const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
 //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
 GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
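The deleted power-of-two assertion is redundant: nsg is produced by repeated doubling and then halved once (the nsg /= 2 above), so it is a power of two by construction. A sketch of that selection in C; the doubling loop itself sits outside this hunk, so its exact form here is an assumption:

    #include <stdio.h>

    // sketch: largest power-of-two simdgroup count not exceeding nsgt
    static long pick_nsg(long nsgt) {
        long nsg = 1;
        while (nsg <= nsgt) {
            nsg *= 2;    // doubles each step -> always a power of two
        }
        return nsg / 2;  // step back below the limit, still a power of two
    }

    int main(void) {
        for (long nsgt = 2; nsgt <= 32; nsgt *= 2) {
            printf("nsgt = %2ld -> nsg = %2ld\n", nsgt, pick_nsg(nsgt));
        }
        return 0;
    }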

ggml-metal.metal

@@ -2169,12 +2169,13 @@ kernel void kernel_flash_attn_ext_f16(
 const short NW = N_SIMDWIDTH;
 const short SH = (C + Q); // shared memory per simdgroup in (half)
-const short T  = D + nsg*SH;   // shared memory size per query in (half)
+const short T  = D + 2*nsg*SH; // shared memory size per query in (half)
+const short TF = T/2;          // shared memory size per query in (float)
 const short T4 = T/4;          // shared memory size per query in (half4)
 threadgroup half  * sq  = (threadgroup half  *) (shared + 0*D); // holds the query data
 threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
-threadgroup half  * ss  = (threadgroup half  *) (shared +   sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+threadgroup float * ss  = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
 // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
 simdgroup_half8x8 lo[Q8][D8];
@@ -2202,15 +2203,15 @@ kernel void kernel_flash_attn_ext_f16(
 // zero out shared memory SH
 for (short j = 0; j < Q; ++j) {
 for (short i = tiisg; i < SH; i += NW) {
-ss[j*T  + i] = 0.0h;
+ss[j*TF + i] = 0.0f;
 }
 }
 threadgroup_barrier(mem_flags::mem_threadgroup);
 {
-half  S[Q] = { [0 ... Q-1] = 0.0h };
-half  M[Q] = { [0 ... Q-1] = -INFINITY };
+float S[Q] = { [0 ... Q-1] = 0.0h };
+float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
 // assume K and V are same shape
 const short ne22 = ne12;
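Why -FLT_MAX/2 instead of -INFINITY: if M starts at -INFINITY and a block is fully masked, the rescale factor exp(m - M) evaluates exp(-INFINITY - (-INFINITY)) = exp(NaN), which is why the old code needed the m == -INFINITY ternary guards. A large-but-finite sentinel keeps the subtraction finite (exp simply underflows to 0), and halving FLT_MAX leaves headroom so even a difference of two sentinels cannot overflow. A minimal C demonstration of the failure mode being avoided:

    #include <float.h>
    #include <math.h>
    #include <stdio.h>

    int main(void) {
        const float m_inf = -INFINITY;
        const float m_fin = -FLT_MAX/2;

        // -INFINITY - (-INFINITY) is NaN, so exp() is NaN too
        printf("exp(-inf - -inf)    = %f\n", expf(m_inf - m_inf));

        // a finite sentinel stays finite; exp underflows cleanly to 0
        printf("exp(-FLT_MAX/2 - 0) = %f\n", expf(m_fin - 0.0f));
        return 0;
    }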
@@ -2248,7 +2249,7 @@ kernel void kernel_flash_attn_ext_f16(
 device const half * mp = (device const half *) (mask + iq1*nb31);
 // prepare diagonal scale matrix
-simdgroup_half8x8  mscale(scale);
+simdgroup_float8x8 mscale(scale);
 // loop over the KV cache
 // each simdgroup handles blocks of Q rows and C columns
@@ -2261,9 +2262,9 @@ kernel void kernel_flash_attn_ext_f16(
 // Q*K^T
 {
 for (short cc = 0; cc < C/8; ++cc) {
-simdgroup_half8x8  mqk[Q8];
+simdgroup_float8x8 mqk[Q8];
 for (short j = 0; j < Q8; ++j) {
-mqk[j] = make_filled_simdgroup_matrix<half,  8>(0.h);
+mqk[j] = make_filled_simdgroup_matrix<float, 8>(0.h);
 }
 device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
@@ -2283,48 +2284,48 @@ kernel void kernel_flash_attn_ext_f16(
 simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false);
 simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm);
-simdgroup_store(mqk[j], ss + 8*j*T  + 8*cc, T,  0, false);
+simdgroup_store(mqk[j], ss + 8*j*TF + 8*cc, TF, 0, false);
 }
 }
 }
 // used to detect blocks full of -INF
-half  smax = -INFINITY;
+float smax = -INFINITY;
 // online softmax
 if (C == 32) {
-half  ms[Q];
+float ms[Q];
 for (short j = 0; j < Q; ++j) {
 const short p = tiisg;
-const half  m = M[j];
-const half  s = ss[j*T  + p];
+const float m = M[j];
+const float s = ss[j*TF + p];
 smax = simd_max(max(smax, s));
 M[j] = simd_max(max(M[j], s));
-ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);
-const half  vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
+ms[j] = exp(m - M[j]);
+const float vs = exp(s - M[j]);
 S[j] = S[j]*ms[j] + simd_sum(vs);
 // the P matrix from the paper (Q rows, C columns)
-ss[j*T  + p] = vs;
+ss[j*TF + p] = vs;
 }
 // create a QxQ diagonal matrix for rescaling the output
 if (tiisg < Q) {
-ss[tiisg*T  + C + tiisg] = ms[tiisg];
+ss[tiisg*TF + C + tiisg] = ms[tiisg];
 }
 } else {
-half  ms[Q];
+float ms[Q];
 for (short j = 0; j < Q; ++j) {
-const half  m = M[j];
+const float m = M[j];
 for (short p = tiisg; p < C; p += NW) {
-const half  s = ss[j*T  + p];
+const float s = ss[j*TF + p];
 smax = max(smax, s);
 M[j] = max(M[j], s);
@@ -2333,20 +2334,20 @@ kernel void kernel_flash_attn_ext_f16(
 smax = simd_max(smax);
 M[j] = simd_max(M[j]);
-ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);
+ms[j] = exp(m - M[j]);
 // local sum
-half  ls = 0.0h;
+float ls = 0.0h;
 for (short p = tiisg; p < C; p += NW) {
-const half  s = ss[j*T  + p];
-const half  vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
+const float s = ss[j*TF + p];
+const float vs = exp(s - M[j]);
 ls += vs;
 // the P matrix from the paper (Q rows, C columns)
-ss[j*T  + p] = vs;
+ss[j*TF + p] = vs;
 }
 S[j] = S[j]*ms[j] + simd_sum(ls);
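For orientation, this is the online-softmax recurrence from the flash-attention paper: M tracks the running maximum of the scores, S the running sum of exponentials, and each new block rescales the old sum by exp(m_old - M_new). A scalar C sketch for a single query row (the kernel does the same per row, block by block):

    #include <float.h>
    #include <math.h>
    #include <stdio.h>

    int main(void) {
        const float scores[] = { 1.0f, 3.0f, 2.0f, 5.0f, 4.0f };
        const int   n        = 5;

        float M = -FLT_MAX/2; // running max (finite sentinel, as in the kernel)
        float S = 0.0f;       // running sum of exp(score - M)

        for (int i = 0; i < n; ++i) {
            const float m = M;
            M = fmaxf(M, scores[i]);
            const float ms = expf(m - M);         // rescale factor for the old sum
            const float vs = expf(scores[i] - M); // contribution of the new score
            S = S*ms + vs;
        }

        // S == sum(exp(score - max)); softmax weights are exp(score - M)/S
        printf("M = %f, S = %f\n", M, S);
        return 0;
    }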
@@ -2354,7 +2355,7 @@ kernel void kernel_flash_attn_ext_f16(
 // create a QxQ diagonal matrix for rescaling the output
 if (tiisg < Q) {
-ss[tiisg*T  + C + tiisg] = ms[tiisg];
+ss[tiisg*TF + C + tiisg] = ms[tiisg];
 }
 }
@@ -2365,8 +2366,8 @@ kernel void kernel_flash_attn_ext_f16(
 // O = diag(ms)*O
 for (short j = 0; j < Q8; ++j) {
-simdgroup_half8x8  mm;
-simdgroup_load(mm, ss + 8*j*T  + C + 8*j, T,  0, false);
+simdgroup_float8x8 mm;
+simdgroup_load(mm, ss + 8*j*TF + C + 8*j, TF, 0, false);
 for (short i = 0; i < D8; ++i) {
 simdgroup_multiply(lo[j][i], mm, lo[j][i]);
@@ -2383,8 +2384,8 @@ kernel void kernel_flash_attn_ext_f16(
 simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
 for (short j = 0; j < Q8; ++j) {
-simdgroup_half8x8  mv;
-simdgroup_load(mv, ss + 8*j*T  + 8*cc, T,  0, false);
+simdgroup_float8x8 mv;
+simdgroup_load(mv, ss + 8*j*TF + 8*cc, TF, 0, false);
 simdgroup_multiply_accumulate(lo[j][i], mv, mk, lo[j][i]);
 }
@@ -2396,16 +2397,16 @@ kernel void kernel_flash_attn_ext_f16(
 // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
 for (short j = 0; j < Q; ++j) {
 if (tiisg == 0) {
-ss[j*T  + 0] = S[j];
-ss[j*T  + 1] = M[j];
+ss[j*TF + 0] = S[j];
+ss[j*TF + 1] = M[j];
 }
 }
 }
 // reduce the warps sequentially
 for (short sg = 1; sg < nsg; ++sg) {
-half  S = { 0.0h };
-half  M = { -INFINITY };
+float S = { 0.0h };
+float M = { -FLT_MAX/2 };
 threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -2423,36 +2424,36 @@ kernel void kernel_flash_attn_ext_f16(
 // the first simdgroup accumulates the results from the other simdgroups
 if (sgitg == 0) {
 for (short j = 0; j < Q; ++j) {
-const half  S0 = ss[j*T  +         0];
-const half  S1 = ss[j*T  + sg*SH + 0];
-const half  M0 = ss[j*T  +         1];
-const half  M1 = ss[j*T  + sg*SH + 1];
+const float S0 = ss[j*TF +         0];
+const float S1 = ss[j*TF + sg*SH + 0];
+const float M0 = ss[j*TF +         1];
+const float M1 = ss[j*TF + sg*SH + 1];
 M = max(M0, M1);
-const half  ms0 = M0 == -INFINITY ? 0.0h : exp(M0 - M);
-const half  ms1 = M1 == -INFINITY ? 0.0h : exp(M1 - M);
+const float ms0 = exp(M0 - M);
+const float ms1 = exp(M1 - M);
 S = S0*ms0 + S1*ms1;
 if (tiisg == 0) {
-ss[j*T  + 0] = S;
-ss[j*T  + 1] = M;
-ss[j*T  + C + j        ] = ms0;
-ss[j*T  + C + j + sg*SH] = ms1;
+ss[j*TF + 0] = S;
+ss[j*TF + 1] = M;
+ss[j*TF + C + j        ] = ms0;
+ss[j*TF + C + j + sg*SH] = ms1;
 }
 }
 // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
 for (short j = 0; j < Q8; ++j) {
 simdgroup_half8x8 t;
-simdgroup_half8x8  ms0;
-simdgroup_half8x8  ms1;
-simdgroup_load(ms0, ss + 8*j*T  + C + 8*j,         T,  0, false);
-simdgroup_load(ms1, ss + 8*j*T  + C + 8*j + sg*SH, T,  0, false);
+simdgroup_float8x8 ms0;
+simdgroup_float8x8 ms1;
+simdgroup_load(ms0, ss + 8*j*TF + C + 8*j,         TF, 0, false);
+simdgroup_load(ms1, ss + 8*j*TF + C + 8*j + sg*SH, TF, 0, false);
 for (short i = 0; i < D8; ++i) {
 simdgroup_load (t, sq + 8*j*T + i*8, T, 0, false);
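The cross-simdgroup reduction combines two partial softmax states (S0, M0) and (S1, M1) into one, then rescales the partial outputs by the resulting diagonal factors. Sketched in C for a single scalar pair, illustrative only:

    #include <math.h>
    #include <stdio.h>

    // merge two partial (sum, max) softmax states into one
    static void merge(float S0, float M0, float S1, float M1,
                      float *S, float *M, float *ms0, float *ms1) {
        *M   = fmaxf(M0, M1);
        *ms0 = expf(M0 - *M); // rescale factor for partial output O_0
        *ms1 = expf(M1 - *M); // rescale factor for partial output O_1
        *S   = S0*(*ms0) + S1*(*ms1);
    }

    int main(void) {
        float S, M, ms0, ms1;
        merge(2.5f, 4.0f, 1.5f, 6.0f, &S, &M, &ms0, &ms1);
        // the kernel then forms O = diag(ms0)*O_0 + diag(ms1)*O_1
        printf("S = %f, M = %f, ms0 = %f, ms1 = %f\n", S, M, ms0, ms1);
        return 0;
    }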
@@ -2478,7 +2479,7 @@ kernel void kernel_flash_attn_ext_f16(
 // final rescale with 1/S and store to global memory
 if (sgitg == 0) {
 for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
-const half  S = ss[j*T  + 0];
+const float S = ss[j*TF + 0];
 for (short i = tiisg; i < D4; i += NW) {
 dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
@@ -2494,8 +2495,6 @@ template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f
 template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>;
 template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>;
-#define HALF_MAX_HALF half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
-
 template<int64_t D, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
 kernel void kernel_flash_attn_ext_vec_f16(
 device const char * q,
@@ -2539,18 +2538,16 @@ kernel void kernel_flash_attn_ext_vec_f16(
 const short iq1 = tgpig[0];
 const short D4 = D/4;
-const short D8 = D/8;
 const short NW = N_SIMDWIDTH;
 const short SH = (C + 1); // shared memory per simdgroup in (half)
-const short T  = D + nsg*SH;   // shared memory size per query in (half)
+const short T  = D + 2*nsg*SH; // shared memory size per query in (half)
-const short T4 = T/4; // shared memory size per query in (half4)
-threadgroup half  * sq  = (threadgroup half  *) (shared + 0*D); // holds the query data
+//threadgroup half * sq  = (threadgroup half  *) (shared + 0*D); // holds the query data
 threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
-threadgroup half   * ss  = (threadgroup half   *) (shared +   sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
-threadgroup half4  * ss4 = (threadgroup half4  *) (shared +   sgitg*SH + 1*D); // same as above but in half4
+threadgroup float  * ss  = (threadgroup float  *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
 threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results
 // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
 half4 lo[D4/NW];
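The vec kernel's shared-memory layout after this change, counted in half slots: D halfs of query data, then 2*nsg*SH slots holding the F32 scratch (two half slots per float, hence the factor of 2 in the ss/ss4 offsets), then nsg*D halfs of per-simdgroup results. A small C sketch of the offsets; the parameter values are illustrative assumptions:

    #include <stdio.h>

    int main(void) {
        // illustrative values, not taken from the diff
        const int D   = 128;   // head size, in half elements
        const int C   = 32;    // cache items per simdgroup
        const int SH  = C + 1; // scratch per simdgroup, counted in half slots
        const int nsg = 2;     // simdgroups per threadgroup

        // F32 scratch occupies two half slots per value
        const int T = D + 2*nsg*SH; // half slots per query row

        for (int sgitg = 0; sgitg < nsg; ++sgitg) {
            printf("sg %d: ss at half offset %4d, sr4 at half offset %4d\n",
                   sgitg, D + 2*sgitg*SH, T + sgitg*D);
        }
        return 0;
    }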
@@ -2579,8 +2576,8 @@ kernel void kernel_flash_attn_ext_vec_f16(
 threadgroup_barrier(mem_flags::mem_threadgroup);
 {
-half  S = { 0.0h };
-half  M = { -HALF_MAX_HALF };
+float S = { 0.0h };
+float M = { -FLT_MAX/2 };
 // assume K and V are same shape
 const short ne22 = ne12;
@@ -2628,7 +2625,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
 {
 #pragma unroll
 for (short cc = 0; cc < C/4; ++cc) {
-half4  mqk = { 0.0h };
+float4 mqk = { 0.0h };
 device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
@@ -2642,7 +2639,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
 mk[2] = pk4[i + 2*(nb11/8)];
 mk[3] = pk4[i + 3*(nb11/8)];
-mqk += mq[i] * mk;
+mqk += (float4) (mq[i] * mk);
 }
 // reduce the results from the threads in the simdgroup
@@ -2654,7 +2651,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
 // mqk = mqk*scale + mask
 if (tiisg == 0) {
-half4  mm = mp4[ic/4 + cc];
+float4 mm = (float4) mp4[ic/4 + cc];
 mqk = mqk*scale + mm;
 ss4[cc] = mqk;
@@ -2666,13 +2663,13 @@ kernel void kernel_flash_attn_ext_vec_f16(
 {
 const short p = tiisg;
-const half  m = M;
-const half  s = ss[p];
+const float m = M;
+const float s = ss[p];
 M = simd_max(max(M, s));
-const half  ms = exp(m - M);
-const half  vs = exp(s - M);
+const float ms = exp(m - M);
+const float vs = exp(s - M);
 S = S*ms + simd_sum(vs);
@@ -2696,6 +2693,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
 #pragma unroll
 for (short ii = 0; ii < D4; ii += NW) {
 const short i = ii + tiisg;
+
 lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
 lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
 lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
@@ -2724,18 +2722,18 @@ kernel void kernel_flash_attn_ext_vec_f16(
 // parallel reduce
 for (short r = nsg/2; r > 0; r >>= 1) {
 if (sgitg < r) {
-const half  S0 = ss[       0];
-const half  S1 = ss[r*SH + 0];
-const half  M0 = ss[       1];
-const half  M1 = ss[r*SH + 1];
+const float S0 = ss[       0];
+const float S1 = ss[r*SH + 0];
+const float M0 = ss[       1];
+const float M1 = ss[r*SH + 1];
-const half  M = max(M0, M1);
+const float M = max(M0, M1);
-const half  ms0 = exp(M0 - M);
-const half  ms1 = exp(M1 - M);
-const half  S = S0*ms0 + S1*ms1;
+const float ms0 = exp(M0 - M);
+const float ms1 = exp(M1 - M);
+const float S = S0*ms0 + S1*ms1;
 if (tiisg == 0) {
 ss[0] = S;
@@ -2756,7 +2754,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
 // final rescale with 1/S and store to global memory
 if (sgitg == 0) {
-const half  S = ss[0];
+const float S = ss[0];
 for (short ii = 0; ii < D4; ii += NW) {
 short i = ii + tiisg;

ggml.c

@@ -14882,12 +14882,13 @@ static void ggml_compute_forward_flash_attn_ext(
 struct ggml_tensor * dst) {
 switch (dst->op_params[1]) {
 case GGML_PREC_DEFAULT:
+case GGML_PREC_F32:
 {
+// uses F32 accumulators
 ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
 } break;
 default:
 {
-// TODO: implement F32 precision
 GGML_ASSERT(false);
 } break;
 }
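With this change GGML_PREC_DEFAULT and GGML_PREC_F32 dispatch to the same path, since the F16 kernel now accumulates in F32 either way. A hedged usage sketch in C, assuming the ggml_flash_attn_ext / ggml_flash_attn_ext_set_prec API from the same branch (neither signature is shown in this diff):

    // sketch, assuming the flash-attention API from this branch
    #include "ggml.h"

    struct ggml_tensor * build_attn(struct ggml_context * ctx,
                                    struct ggml_tensor  * q,
                                    struct ggml_tensor  * k,
                                    struct ggml_tensor  * v,
                                    struct ggml_tensor  * mask,
                                    float scale) {
        struct ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, scale);

        // requesting F32 precision now selects the same code path as the
        // default: the F16 kernel accumulates in F32 regardless
        ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32);

        return out;
    }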