fix equivalent fp16 math functions, compiler error 'undefined'

This commit is contained in:
FSSRepo 2024-01-24 10:57:05 -05:00
parent 6374bc5779
commit 6416821499

View file

@ -6114,6 +6114,42 @@ struct __align__(8) half4 {
half w; half w;
}; };
__device__ half4 make_half4(half x) {
half4 t;
t.x = x; t.y = x; t.z = x; t.w = x;
return t;
}
__device__ half4 __h4fma(half4 a, half b, half4 c) {
half4 t;
t.x = __hfma(a.x, b, c.x); t.y = __hfma(a.y, b, c.y); t.z = __hfma(a.z, b, c.z); t.w = __hfma(a.w, b, c.w);
return t;
}
__device__ half4 __h4fma(half4 a, half4 b, half4 c) {
half4 t;
t.x = __hfma(a.x, b.x, c.x); t.y = __hfma(a.y, b.y, c.y); t.z = __hfma(a.z, b.z, c.z); t.w = __hfma(a.w, b.w, c.w);
return t;
}
__device__ half4 __h4mul(half4 a, half b) {
half4 t;
t.x = __hmul(a.x, b); t.y = __hmul(a.y, b); t.z =__hmul(a.z, b); t.w =__hmul(a.w, b);
return t;
}
__device__ half4 __h4mul(half4 a, half4 b) {
half4 t;
t.x = __hmul(a.x, b.x); t.y = __hmul(a.y, b.y); t.z =__hmul(a.z, b.z); t.w =__hmul(a.w, b.w);
return t;
}
__device__ half4 __h4div(half4 a, half b) {
half4 t;
t.x = __hdiv(a.x, b); t.y = __hdiv(a.y, b); t.z =__hdiv(a.z, b); t.w =__hdiv(a.w, b);
return t;
}
// based on metal version // based on metal version
template<int D, int R> // head size, rows per block template<int D, int R> // head size, rows per block
static __global__ void flash_attn_ext_f16( static __global__ void flash_attn_ext_f16(
@ -6166,12 +6202,12 @@ static __global__ void flash_attn_ext_f16(
// kv indices // kv indices
const int ik2 = iq2 / rk2; const int ik2 = iq2 / rk2;
const int ik3 = iq3 / rk3; const int ik3 = iq3 / rk3;
const int iv2 = iq2 / rv2; // const int iv2 = iq2 / rv2;
const int iv3 = iq3 / rv3; // const int iv3 = iq3 / rv3;
const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
const float * mp = mask ? mask + (ir % ne31)*nb31 : nullptr; const float * mp = mask ? (const float*)(mask + (ir % ne31)*nb31) : nullptr;
extern __shared__ char shmem__[]; extern __shared__ char shmem__[];
@ -6187,30 +6223,30 @@ static __global__ void flash_attn_ext_f16(
// load R heads from Q to shared memory // load R heads from Q to shared memory
for (int64_t i = 0; i < D4/tph; ++i) { for (int64_t i = 0; i < D4/tph; ++i) {
if (warp_id == 0) { if (warp_id == 0) {
pq4[hiisg*D4 + tph*i + tiih] = (const half4*)((const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03))[tph*i + tiih]; pq4[hiisg*D4 + tph*i + tiih] = ((half4*)(q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih];
} }
ps4[hiisg*D4 + tph*i + tiih] = 0.0h; ps4[hiisg*D4 + tph*i + tiih] = make_half4(0.0);
} }
__syncthreads(); __syncthreads();
half S = 0.0h; half S(0.0);
half M = -INFINITY; half M(-INFINITY);
for (int64_t ic = warp_id; ic < ne11; ic += nwraps) { for (int64_t ic = warp_id; ic < ne11; ic += nwraps) {
const half mv = mp ? mp[ic] : 0.0h; const half mv = mp ? mp[ic] : 0.0;
if (mv == -INFINITY) { if (__hisinf(mv) == -1) { // mv == -INFINITY
continue; continue;
} }
const half4 * pk4 = (const half4 *) ((char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); const half4 * pk4 = (const half4 *) ((char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13));
const half4 * pv4 = (const half4 *) ((char *) v + (ic*nb11 + iv2*nb12 + iv3*nb13)); // assumes V same shape of K const half4 * pv4 = (const half4 *) ((char *) v + (ic*nb11 + ik2*nb12 + ik3*nb13)); // assumes V same shape of K
half4 s4 = 0.0h; half4 s4 = make_half4(0.0);
#pragma unroll #pragma unroll
for (int i = 0; i < D4/tph; ++i) { for (int i = 0; i < D4/tph; ++i) {
s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; s4 = __h4fma(pq4[hiisg*D4 + tph*i + tiih], pk4[tph*i + tiih], s4);
} }
ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w); ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w);
@ -6218,23 +6254,23 @@ static __global__ void flash_attn_ext_f16(
__syncthreads(); __syncthreads();
if (tiih == 0) { if (tiih == 0) {
half s = 0.0h; half s = 0.0;
#pragma unroll #pragma unroll
for (int i = 0; i < tph; ++i) { for (int i = 0; i < tph; ++i) {
s += ss[hiisg*tph + i]; s += ss[hiisg*tph + i];
} }
s = s*scale + mv; s = __hfma(s, __float2half(scale), mv); // s*scale + mv
const half m = M; const half m = M;
M = max(M, s); M = __hmax(M, s);
const half ms = exp(m - M); const half ms = hexp(m - M);
const half vs = exp(s - M); const half vs = hexp(s - M);
S = S*ms + vs; S = __hfma(S, ms, vs);
ss[2*hiisg + 0] = ms; ss[2*hiisg + 0] = ms;
ss[2*hiisg + 1] = vs; ss[2*hiisg + 1] = vs;
@ -6247,7 +6283,7 @@ static __global__ void flash_attn_ext_f16(
#pragma unroll #pragma unroll
for (int i = 0; i < D4/tph; ++i) { for (int i = 0; i < D4/tph; ++i) {
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs; ps4[hiisg*D4 + tph*i + tiih] = __h4fma(ps4[hiisg*D4 + tph*i + tiih], ms, __h4mul(pv4[tph*i + tiih], vs));
} }
} }
@ -6267,12 +6303,12 @@ static __global__ void flash_attn_ext_f16(
const half M0 = ss[ 2*hiisg + 1]; const half M0 = ss[ 2*hiisg + 1];
const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1]; const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1];
M = max(M0, M1); M = __hmax(M0, M1);
const half ms0 = exp(M0 - M); const half ms0 = hexp(M0 - M);
const half ms1 = exp(M1 - M); const half ms1 = hexp(M1 - M);
S = S0*ms0 + S1*ms1; S = __hfma(S0, ms0, __hmul(S1, ms1));
if (tiih == 0) { if (tiih == 0) {
ss[2*hiisg + 0] = S; ss[2*hiisg + 0] = S;
@ -6280,12 +6316,12 @@ static __global__ void flash_attn_ext_f16(
} }
for (int i = 0; i < D4/tph; ++i) { for (int i = 0; i < D4/tph; ++i) {
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih]*ms1; ps4[hiisg*D4 + tph*i + tiih] = __h4fma(ps4[hiisg*D4 + tph*i + tiih], ms0, __h4mul(ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih], ms1));
} }
} }
for (int i = 0; i < D4/tph; ++i) { for (int i = 0; i < D4/tph; ++i) {
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]/S; ps4[hiisg*D4 + tph*i + tiih] = __h4div(ps4[hiisg*D4 + tph*i + tiih], S);
} }
} }