latest kernel update, wrong values

FSSRepo 2024-01-30 14:57:12 -05:00
parent 7980178a17
commit 3b0f74b428
2 changed files with 302 additions and 199 deletions

@@ -125,6 +125,11 @@
#include "ggml.h"
#include "ggml-backend-impl.h"
#undef MIN
#undef MAX
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
#define CC_PASCAL 600
@@ -679,7 +684,6 @@ static __device__ __forceinline__ half warp_reduce_max(half x) {
return x;
#else
(void) x;
bad_arch();
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
}
@@ -6156,16 +6160,17 @@ static __global__ void flash_attn_f32(
#if __CUDA_ARCH__ >= CC_VOLTA
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_a;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_b;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_bT;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> half16x16_acc;
// based on metal version
template<int D, int Q, int C> // D head size, Q queries per block, C cache items per block
static __global__ void flash_attn_ext_f16(
const char* __restrict__ q,
const char* __restrict__ k,
const char* __restrict__ v,
const char* __restrict__ mask,
float* __restrict__ dst,
float scale,
int ne00,
int ne01,
@@ -6190,57 +6195,64 @@ static __global__ void flash_attn_ext_f16(
const int warp_id = threadIdx.y;
const int lane_id = threadIdx.x;
const int num_warps = blockDim.y; // number of warps
const int iq3 = blockIdx.z;
const int iq2 = blockIdx.y;
const int iq1 = blockIdx.x * Q;
const int D2 = D/2;
const int D16 = D/16;
const int Q16 = Q/16;
const int NW = WARP_SIZE;
const int SH = (C + D); // shared memory per simdgroup in (half)
const int T = D + num_warps*SH; // shared memory size per query in (half)
const int T2 = T/2; // shared memory size per query in (half2)
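// Shared-memory layout per query row, as inferred from the offsets used below (a sketch, not authoritative):
//   [0, D)                   sq : the query data, shared by all warps
//   [D + w*SH, D + (w+1)*SH) per-warp scratch: C attention scores, followed by room for the
//                            QxQ rescaling diagonals and the per-warp S/M reduction values
// so one full row spans T = D + num_warps*SH halves.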
extern __shared__ half __flash_attn_f16_shmem[];
half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data
half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2
half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix
half16x16_acc lo[Q16][D16];
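// per-warp output accumulator: the QxD output tile of this warp, kept in registers as Q16 x D16 wmma fragments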
// load heads from Q to shared memory
for (int64_t j = warp_id; j < Q; j += num_warps) {
const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
for (int64_t i = lane_id; i < D2; i += NW) {
if (iq1 + j < ne01) {
sq2[j*T2 + i] = __float22half2_rn(q2[i]);
} else {
sq2[j*T2 + i] = make_half2(0.0, 0.0);
}
}
}
// zero out lo
for (int64_t j = 0; j < Q16; ++j) {
for (int64_t i = 0; i < D16; ++i) {
nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
}
}
// zero out shared memory SH
for (int64_t j = 0; j < Q; ++j) {
for (int64_t i = lane_id; i < SH; i += NW) {
ss[j*T + i] = 0.0;
}
}
__syncthreads();
{
float S[Q];
float M[Q];
for (int i = 0; i < Q; ++i) {
S[i] = 0.0f;
M[i] = -INFINITY;
}
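// S[j] and M[j] are the running softmax denominator and running maximum for query row j;
// the loop below applies the usual online-softmax recurrence (sketch):
//   M' = max(M, s); rescale the accumulated sum and output by exp(M - M'); then add the new exp(s - M') terms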
// assume K and V are same shape
const int ne22 = ne12;
@@ -6265,162 +6277,252 @@ static __global__ void flash_attn_ext_f16(
const int iv2 = iq2 / rv2;
const int iv3 = iq3 / rv3;
// load the queries from shared memory into local memory
half16x16_a mq[Q16][D16];
for (int64_t j = 0; j < Q16; ++j) {
for (int64_t i = 0; i < D16; ++i) {
nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T);
}
}
const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
// pointer to the mask
const float * mp = (const float *) (mask + (ir%ne31)*nb31);
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) {
// Q*K^T
{
for (int cc = 0; cc < C/16; ++cc) {
half16x16_acc mqk[Q16];
for (int64_t j = 0; j < Q16; ++j) {
nvcuda::wmma::fill_fragment(mqk[j], 0);
}
const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));
for (int64_t i = 0; i < D16; ++i) {
half16x16_bT mk; // transposed key
nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); // transpose
for (int64_t j = 0; j < Q16; ++j) {
nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
}
}
// mqk = mqk*scale + mask
for (int64_t j = 0; j < Q16; ++j) {
const float* msk_p = mp + 16*j*(nb31/sizeof(float)) + ic + 16*cc;
int64_t msk_ne_row = nb31/sizeof(float);
for (uint32_t i = 0; i < mqk[j].num_elements; i++) {
int msk_col = i % 16;
int msk_row = i / 16;
mqk[j].x[i] = __float2half(scale * __half2float(mqk[j].x[i]) + msk_p[msk_col + msk_row*msk_ne_row]);
}
nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_col_major);
}
}
}
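// note: indexing mqk[j].x[i] with an assumed (row, col) mapping is not guaranteed by CUDA;
// the element order of a wmma fragment is unspecified, so the scale+mask loop above may not
// pair each score with the mask value it expects on every architecture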
// used to detect blocks full of -INF
float smax = -INFINITY;
// online softmax
if (C == 32) {
for (int64_t j = 0; j < Q; ++j) {
const int64_t p = lane_id;
const float m = M[j];
const float s = __half2float(ss[j*T + p]);
smax = warp_reduce_max(max(smax, s));
M[j] = warp_reduce_max(max(M[j], s));
const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]);
const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);
S[j] = S[j]*ms + warp_reduce_sum(vs);
// create a QxQ diagonal matrix for rescaling the output
if (p == j) {
ss[j*T + C + j] = __float2half(ms);
}
// the P matrix from the paper (Q rows, C columns)
ss[j*T + p] = __float2half(vs);
}
} else {
for (int64_t j = 0; j < Q; ++j) {
const float m = M[j];
for (int64_t p = lane_id; p < C; p += NW) {
const float s = __half2float(ss[j*T + p]);
smax = warp_reduce_max(max(smax, s));
M[j] = warp_reduce_max(max(M[j], s));
}
const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]);
S[j] = S[j]*ms;
// create a QxQ diagonal matrix for rescaling the output
if (lane_id == j) {
ss[j*T + C + j] = __float2half(ms);
}
for (int64_t p = lane_id; p < C; p += NW) {
const float s = __half2float(ss[j*T + p]);
const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);
S[j] = S[j] + warp_reduce_sum(vs);
// the P matrix from the paper (Q rows, C columns)
ss[j*T + p] = __float2half(vs);
}
}
}
// skip -INF blocks
if (smax == -INFINITY) {
continue;
}
// O = diag(ms)*O
for (int64_t j = 0; j < Q16; ++j) {
half16x16_a mm;
half16x16_b zro;
nvcuda::wmma::fill_fragment(zro, 0.0);
nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
for (int64_t i = 0; i < D16; ++i) {
nvcuda::wmma::mma_sync(lo[j][i], mm, zro, lo[j][i]);
}
}
// O = O + (Q*K^T)*V
{
for (int cc = 0; cc < C/16; ++cc) {
const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
for (int64_t i = 0; i < D16; ++i) {
half16x16_b mk;
nvcuda::wmma::load_matrix_sync(mk, pv + i*16, nb21/sizeof(half));
for (int64_t j = 0; j < Q16; ++j) {
half16x16_a mv;
nvcuda::wmma::load_matrix_sync(mv, ss + 16*j*T + 16*cc, T);
nvcuda::wmma::mma_sync(lo[j][i], mv, mk, lo[j][i]);
}
}
}
}
}
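// at this point lo holds the still-unnormalized output of this warp's slice of the KV cache;
// the division by the softmax sum S happens once, after the warps are reduced below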
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
for (int64_t j = 0; j < Q; ++j) {
if (lane_id == 0) {
ss[j*T + 0] = __float2half(S[j]);
ss[j*T + 1] = __float2half(M[j]);
}
}
}
// reduce the warps sequentially
for (int64_t sg = 1; sg < num_warps; ++sg) {
float S = 0.0f;
float M = -INFINITY;
__syncthreads();
// each simdgroup stores its output to shared memory, reusing sq
if (warp_id == sg) {
for (int64_t j = 0; j < Q16; ++j) {
for (int64_t i = 0; i < D16; ++i) {
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_col_major);
}
}
}
__syncthreads();
// the first simdgroup accumulates the results from the other simdgroups
if (warp_id == 0) {
for (int64_t j = 0; j < Q; ++j) {
const float S0 = __half2float(ss[j*T + 0]);
const float S1 = __half2float(ss[j*T + sg*SH + 0]);
const float M0 = __half2float(ss[j*T + 1]);
const float M1 = __half2float(ss[j*T + sg*SH + 1]);
M = max(M0, M1);
const float ms0 = M0 == -INFINITY ? 0.0f : expf(M0 - M);
const float ms1 = M1 == -INFINITY ? 0.0f : expf(M1 - M);
S = S0*ms0 + S1*ms1;
if (lane_id == 0) {
ss[j*T + 0] = __float2half(S);
ss[j*T + 1] = __float2half(M);
ss[j*T + C + j ] = __float2half(ms0);
ss[j*T + C + j + sg*SH] = __float2half(ms1);
}
}
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
for (int64_t j = 0; j < Q16; ++j) {
half16x16_a ms0;
half16x16_a ms1;
half16x16_b t;
half16x16_acc t2;
nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T);
nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);
for (int64_t i = 0; i < D16; ++i) {
nvcuda::wmma::fill_fragment(t2, 0.0); // t2 must start from zero before accumulating ms1*t
nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
nvcuda::wmma::mma_sync(t2, ms1, t, t2);
// t <- lo
for (uint32_t k = 0; k < t.num_elements; k++) {
t.x[k] = lo[j][i].x[k];
}
nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2);
}
}
}
}
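// the merge above follows the standard two-block softmax combination (sketch):
//   M = max(M0, M1); ms_i = exp(M_i - M); S = S0*ms0 + S1*ms1; O = diag(ms0)*O_0 + diag(ms1)*O_1
// realized with wmma by multiplying the stored outputs against the QxQ diagonal fragments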
__syncthreads();
// store result to shared memory (reuse sq)
if (warp_id == 0) {
for (int64_t j = 0; j < Q16; ++j) {
for (int64_t i = 0; i < D16; ++i) {
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_col_major);
}
}
}
float2 * dst2 = (float2 *) dst;
// final rescale with 1/S and store to global memory
if (warp_id == 0) {
for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
const float S = __half2float(ss[j*T + 0]);
for (int64_t i = lane_id; i < D2; i += NW) {
dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i] = __half22float2(sq2[j*T2 + i]);
dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i].x /= S;
dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i].y /= S;
}
}
}
}
#else
template<int D, int Q, int C> // D head size, Q queries per block, C cache items per block
@@ -6451,7 +6553,6 @@ static __global__ void flash_attn_ext_f16(
int ne1,
int ne2,
int ne3) {
bad_arch();
}
#endif
@@ -10446,9 +10547,9 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
float scale;
memcpy(&scale, KQV->op_params, sizeof(float));
const int nqpb = 16; // queries per block
const int ncpw = 32; // cache values per warp (does not work for other values)
const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, 32)) : 4;
dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
dim3 block_dim(32, nwarps, 1);
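// launch shape (a sketch of the intent): one block per nqpb-query tile, gridDim = (ceil(ne01/nqpb), ne02, ne03);
// each block runs nwarps warps of 32 threads, and every warp walks the KV cache in steps of ncpw columns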
@@ -10457,6 +10558,23 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
printf("shared memory: %d bytes [%i, %i, %i]\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2]);
switch (Q->ne[0])
{
case 16:
flash_attn_ext_f16<16, 16, 32>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key
(const char *) src2_extra->data_device[g_main_device], // Value
(const char *) src3_extra->data_device[g_main_device], // Mask
(float *) dst_extra->data_device[g_main_device], // dst
scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask->ne[1], mask->nb[1],
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
break;
case 64:
flash_attn_ext_f16<64, 16, 32>
<<<blocks_num, block_dim, shmem, main_stream>>> (