fix equivalent fp16 math functions, compiler error 'undefined'
This commit is contained in:
parent
6374bc5779
commit
6416821499
1 changed files with 61 additions and 25 deletions
86
ggml-cuda.cu
86
ggml-cuda.cu
|
@ -6114,6 +6114,42 @@ struct __align__(8) half4 {
|
|||
half w;
|
||||
};
|
||||
|
||||
__device__ half4 make_half4(half x) {
|
||||
half4 t;
|
||||
t.x = x; t.y = x; t.z = x; t.w = x;
|
||||
return t;
|
||||
}
|
||||
|
||||
__device__ half4 __h4fma(half4 a, half b, half4 c) {
|
||||
half4 t;
|
||||
t.x = __hfma(a.x, b, c.x); t.y = __hfma(a.y, b, c.y); t.z = __hfma(a.z, b, c.z); t.w = __hfma(a.w, b, c.w);
|
||||
return t;
|
||||
}
|
||||
|
||||
__device__ half4 __h4fma(half4 a, half4 b, half4 c) {
|
||||
half4 t;
|
||||
t.x = __hfma(a.x, b.x, c.x); t.y = __hfma(a.y, b.y, c.y); t.z = __hfma(a.z, b.z, c.z); t.w = __hfma(a.w, b.w, c.w);
|
||||
return t;
|
||||
}
|
||||
|
||||
__device__ half4 __h4mul(half4 a, half b) {
|
||||
half4 t;
|
||||
t.x = __hmul(a.x, b); t.y = __hmul(a.y, b); t.z =__hmul(a.z, b); t.w =__hmul(a.w, b);
|
||||
return t;
|
||||
}
|
||||
|
||||
__device__ half4 __h4mul(half4 a, half4 b) {
|
||||
half4 t;
|
||||
t.x = __hmul(a.x, b.x); t.y = __hmul(a.y, b.y); t.z =__hmul(a.z, b.z); t.w =__hmul(a.w, b.w);
|
||||
return t;
|
||||
}
|
||||
|
||||
__device__ half4 __h4div(half4 a, half b) {
|
||||
half4 t;
|
||||
t.x = __hdiv(a.x, b); t.y = __hdiv(a.y, b); t.z =__hdiv(a.z, b); t.w =__hdiv(a.w, b);
|
||||
return t;
|
||||
}
|
||||
|
||||
// based on metal version
|
||||
template<int D, int R> // head size, rows per block
|
||||
static __global__ void flash_attn_ext_f16(
|
||||
|
@ -6166,12 +6202,12 @@ static __global__ void flash_attn_ext_f16(
|
|||
// kv indices
|
||||
const int ik2 = iq2 / rk2;
|
||||
const int ik3 = iq3 / rk3;
|
||||
const int iv2 = iq2 / rv2;
|
||||
const int iv3 = iq3 / rv3;
|
||||
// const int iv2 = iq2 / rv2;
|
||||
// const int iv3 = iq3 / rv3;
|
||||
|
||||
const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
|
||||
|
||||
const float * mp = mask ? mask + (ir % ne31)*nb31 : nullptr;
|
||||
const float * mp = mask ? (const float*)(mask + (ir % ne31)*nb31) : nullptr;
|
||||
|
||||
extern __shared__ char shmem__[];
|
||||
|
||||
|
@ -6187,30 +6223,30 @@ static __global__ void flash_attn_ext_f16(
|
|||
// load R heads from Q to shared memory
|
||||
for (int64_t i = 0; i < D4/tph; ++i) {
|
||||
if (warp_id == 0) {
|
||||
pq4[hiisg*D4 + tph*i + tiih] = (const half4*)((const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03))[tph*i + tiih];
|
||||
pq4[hiisg*D4 + tph*i + tiih] = ((half4*)(q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih];
|
||||
}
|
||||
|
||||
ps4[hiisg*D4 + tph*i + tiih] = 0.0h;
|
||||
ps4[hiisg*D4 + tph*i + tiih] = make_half4(0.0);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
half S = 0.0h;
|
||||
half M = -INFINITY;
|
||||
half S(0.0);
|
||||
half M(-INFINITY);
|
||||
|
||||
for (int64_t ic = warp_id; ic < ne11; ic += nwraps) {
|
||||
const half mv = mp ? mp[ic] : 0.0h;
|
||||
if (mv == -INFINITY) {
|
||||
const half mv = mp ? mp[ic] : 0.0;
|
||||
if (__hisinf(mv) == -1) { // mv == -INFINITY
|
||||
continue;
|
||||
}
|
||||
|
||||
const half4 * pk4 = (const half4 *) ((char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13));
|
||||
const half4 * pv4 = (const half4 *) ((char *) v + (ic*nb11 + iv2*nb12 + iv3*nb13)); // assumes V same shape of K
|
||||
const half4 * pv4 = (const half4 *) ((char *) v + (ic*nb11 + ik2*nb12 + ik3*nb13)); // assumes V same shape of K
|
||||
|
||||
half4 s4 = 0.0h;
|
||||
half4 s4 = make_half4(0.0);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < D4/tph; ++i) {
|
||||
s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih];
|
||||
s4 = __h4fma(pq4[hiisg*D4 + tph*i + tiih], pk4[tph*i + tiih], s4);
|
||||
}
|
||||
|
||||
ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w);
|
||||
|
@ -6218,23 +6254,23 @@ static __global__ void flash_attn_ext_f16(
|
|||
__syncthreads();
|
||||
|
||||
if (tiih == 0) {
|
||||
half s = 0.0h;
|
||||
half s = 0.0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < tph; ++i) {
|
||||
s += ss[hiisg*tph + i];
|
||||
}
|
||||
|
||||
s = s*scale + mv;
|
||||
s = __hfma(s, __float2half(scale), mv); // s*scale + mv
|
||||
|
||||
const half m = M;
|
||||
|
||||
M = max(M, s);
|
||||
M = __hmax(M, s);
|
||||
|
||||
const half ms = exp(m - M);
|
||||
const half vs = exp(s - M);
|
||||
const half ms = hexp(m - M);
|
||||
const half vs = hexp(s - M);
|
||||
|
||||
S = S*ms + vs;
|
||||
S = __hfma(S, ms, vs);
|
||||
|
||||
ss[2*hiisg + 0] = ms;
|
||||
ss[2*hiisg + 1] = vs;
|
||||
|
@ -6247,7 +6283,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < D4/tph; ++i) {
|
||||
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs;
|
||||
ps4[hiisg*D4 + tph*i + tiih] = __h4fma(ps4[hiisg*D4 + tph*i + tiih], ms, __h4mul(pv4[tph*i + tiih], vs));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -6267,12 +6303,12 @@ static __global__ void flash_attn_ext_f16(
|
|||
const half M0 = ss[ 2*hiisg + 1];
|
||||
const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1];
|
||||
|
||||
M = max(M0, M1);
|
||||
M = __hmax(M0, M1);
|
||||
|
||||
const half ms0 = exp(M0 - M);
|
||||
const half ms1 = exp(M1 - M);
|
||||
const half ms0 = hexp(M0 - M);
|
||||
const half ms1 = hexp(M1 - M);
|
||||
|
||||
S = S0*ms0 + S1*ms1;
|
||||
S = __hfma(S0, ms0, __hmul(S1, ms1));
|
||||
|
||||
if (tiih == 0) {
|
||||
ss[2*hiisg + 0] = S;
|
||||
|
@ -6280,12 +6316,12 @@ static __global__ void flash_attn_ext_f16(
|
|||
}
|
||||
|
||||
for (int i = 0; i < D4/tph; ++i) {
|
||||
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih]*ms1;
|
||||
ps4[hiisg*D4 + tph*i + tiih] = __h4fma(ps4[hiisg*D4 + tph*i + tiih], ms0, __h4mul(ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih], ms1));
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < D4/tph; ++i) {
|
||||
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]/S;
|
||||
ps4[hiisg*D4 + tph*i + tiih] = __h4div(ps4[hiisg*D4 + tph*i + tiih], S);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue