diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 940ffbfc8..9d2b99ac9 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6114,6 +6114,42 @@ struct __align__(8) half4 { half w; }; +__device__ half4 make_half4(half x) { + half4 t; + t.x = x; t.y = x; t.z = x; t.w = x; + return t; +} + +__device__ half4 __h4fma(half4 a, half b, half4 c) { + half4 t; + t.x = __hfma(a.x, b, c.x); t.y = __hfma(a.y, b, c.y); t.z = __hfma(a.z, b, c.z); t.w = __hfma(a.w, b, c.w); + return t; +} + +__device__ half4 __h4fma(half4 a, half4 b, half4 c) { + half4 t; + t.x = __hfma(a.x, b.x, c.x); t.y = __hfma(a.y, b.y, c.y); t.z = __hfma(a.z, b.z, c.z); t.w = __hfma(a.w, b.w, c.w); + return t; +} + +__device__ half4 __h4mul(half4 a, half b) { + half4 t; + t.x = __hmul(a.x, b); t.y = __hmul(a.y, b); t.z =__hmul(a.z, b); t.w =__hmul(a.w, b); + return t; +} + +__device__ half4 __h4mul(half4 a, half4 b) { + half4 t; + t.x = __hmul(a.x, b.x); t.y = __hmul(a.y, b.y); t.z =__hmul(a.z, b.z); t.w =__hmul(a.w, b.w); + return t; +} + +__device__ half4 __h4div(half4 a, half b) { + half4 t; + t.x = __hdiv(a.x, b); t.y = __hdiv(a.y, b); t.z =__hdiv(a.z, b); t.w =__hdiv(a.w, b); + return t; +} + // based on metal version template // head size, rows per block static __global__ void flash_attn_ext_f16( @@ -6166,12 +6202,12 @@ static __global__ void flash_attn_ext_f16( // kv indices const int ik2 = iq2 / rk2; const int ik3 = iq3 / rk3; - const int iv2 = iq2 / rv2; - const int iv3 = iq3 / rv3; + // const int iv2 = iq2 / rv2; + // const int iv3 = iq3 / rv3; const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - const float * mp = mask ? mask + (ir % ne31)*nb31 : nullptr; + const float * mp = mask ? (const float*)(mask + (ir % ne31)*nb31) : nullptr; extern __shared__ char shmem__[]; @@ -6187,30 +6223,30 @@ static __global__ void flash_attn_ext_f16( // load R heads from Q to shared memory for (int64_t i = 0; i < D4/tph; ++i) { if (warp_id == 0) { - pq4[hiisg*D4 + tph*i + tiih] = (const half4*)((const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03))[tph*i + tiih]; + pq4[hiisg*D4 + tph*i + tiih] = ((half4*)(q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; } - ps4[hiisg*D4 + tph*i + tiih] = 0.0h; + ps4[hiisg*D4 + tph*i + tiih] = make_half4(0.0); } __syncthreads(); - half S = 0.0h; - half M = -INFINITY; + half S(0.0); + half M(-INFINITY); for (int64_t ic = warp_id; ic < ne11; ic += nwraps) { - const half mv = mp ? mp[ic] : 0.0h; - if (mv == -INFINITY) { + const half mv = mp ? mp[ic] : 0.0; + if (__hisinf(mv) == -1) { // mv == -INFINITY continue; } const half4 * pk4 = (const half4 *) ((char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); - const half4 * pv4 = (const half4 *) ((char *) v + (ic*nb11 + iv2*nb12 + iv3*nb13)); // assumes V same shape of K + const half4 * pv4 = (const half4 *) ((char *) v + (ic*nb11 + ik2*nb12 + ik3*nb13)); // assumes V same shape of K - half4 s4 = 0.0h; + half4 s4 = make_half4(0.0); #pragma unroll for (int i = 0; i < D4/tph; ++i) { - s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; + s4 = __h4fma(pq4[hiisg*D4 + tph*i + tiih], pk4[tph*i + tiih], s4); } ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w); @@ -6218,23 +6254,23 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); if (tiih == 0) { - half s = 0.0h; + half s = 0.0; #pragma unroll for (int i = 0; i < tph; ++i) { s += ss[hiisg*tph + i]; } - s = s*scale + mv; + s = __hfma(s, __float2half(scale), mv); // s*scale + mv const half m = M; - M = max(M, s); + M = __hmax(M, s); - const half ms = exp(m - M); - const half vs = exp(s - M); + const half ms = hexp(m - M); + const half vs = hexp(s - M); - S = S*ms + vs; + S = __hfma(S, ms, vs); ss[2*hiisg + 0] = ms; ss[2*hiisg + 1] = vs; @@ -6247,7 +6283,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int i = 0; i < D4/tph; ++i) { - ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs; + ps4[hiisg*D4 + tph*i + tiih] = __h4fma(ps4[hiisg*D4 + tph*i + tiih], ms, __h4mul(pv4[tph*i + tiih], vs)); } } @@ -6267,12 +6303,12 @@ static __global__ void flash_attn_ext_f16( const half M0 = ss[ 2*hiisg + 1]; const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1]; - M = max(M0, M1); + M = __hmax(M0, M1); - const half ms0 = exp(M0 - M); - const half ms1 = exp(M1 - M); + const half ms0 = hexp(M0 - M); + const half ms1 = hexp(M1 - M); - S = S0*ms0 + S1*ms1; + S = __hfma(S0, ms0, __hmul(S1, ms1)); if (tiih == 0) { ss[2*hiisg + 0] = S; @@ -6280,12 +6316,12 @@ static __global__ void flash_attn_ext_f16( } for (int i = 0; i < D4/tph; ++i) { - ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih]*ms1; + ps4[hiisg*D4 + tph*i + tiih] = __h4fma(ps4[hiisg*D4 + tph*i + tiih], ms0, __h4mul(ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih], ms1)); } } for (int i = 0; i < D4/tph; ++i) { - ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]/S; + ps4[hiisg*D4 + tph*i + tiih] = __h4div(ps4[hiisg*D4 + tph*i + tiih], S); } }