fix kernel

FSSRepo 2024-01-31 12:28:48 -05:00
parent 3b0f74b428
commit b1479dfbc5
2 changed files with 56 additions and 49 deletions


@@ -6158,9 +6158,9 @@ static __global__ void flash_attn_f32(
 }

 #if __CUDA_ARCH__ >= CC_VOLTA
-typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    16, 16, 16, half, nvcuda::wmma::col_major> half16x16_a;
-typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    16, 16, 16, half, nvcuda::wmma::col_major> half16x16_b;
-typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    16, 16, 16, half, nvcuda::wmma::row_major> half16x16_bT;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    16, 16, 16, half, nvcuda::wmma::col_major> half16x16_bT;
 typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> half16x16_acc;

 // based on metal version
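Note on the layout swap above: a matrix_b fragment declared col_major reads a row-major buffer as its transpose, which is what lets the kernel form Q*K^T (the half16x16_bT typedef) without writing a transposed copy of K. A minimal standalone sketch of that trick (illustrative, not code from this commit):

#include <cuda_fp16.h>
#include <mma.h>

using namespace nvcuda;

// out = q * k^T for a single 16x16 tile; launch with one warp
__global__ void qkT_16x16(const half * q, const half * k, half * out) {
    wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> a;
    wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::col_major> bT;
    wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc;

    wmma::fill_fragment(acc, __float2half(0.0f));
    wmma::load_matrix_sync(a,  q, 16); // row-major load of Q, leading dim 16
    wmma::load_matrix_sync(bT, k, 16); // same bytes read column-major => K^T
    wmma::mma_sync(acc, a, bT, acc);
    wmma::store_matrix_sync(out, acc, 16, wmma::mem_row_major);
}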
@@ -6204,15 +6204,15 @@ static __global__ void flash_attn_ext_f16(
     const int D16 = D/16;
     const int Q16 = Q/16;
     const int NW  = WARP_SIZE;
-    const int SH  = (C + D); // shared memory per simdgroup in (half)
+    const int SH  = (C + Q); // shared memory per simdgroup in (half)

     const int T  = D + num_warps*SH; // shared memory size per query in (half)
     const int T2 = T/2;              // shared memory size per query in (half2)

     extern __shared__ half __flash_attn_f16_shmem[];
     // pq
     half  * sq  = (half  *) (__flash_attn_f16_shmem +              0*D); // holds the query data
     half2 * sq2 = (half2 *) (__flash_attn_f16_shmem +              0*D); // same as above but in half2
     half  * ss  = (half  *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix

     half16x16_acc lo[Q16][D16];
@@ -6249,7 +6249,7 @@ static __global__ void flash_attn_ext_f16(
     float S[Q];
     float M[Q];

-    for(int i = 0; i < Q;i ++) {
+    for(int i = 0; i < Q; i++) {
         S[i] = 0.0f;
         M[i] = -INFINITY;
     }
@@ -6288,7 +6288,7 @@ static __global__ void flash_attn_ext_f16(
     const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;

     // pointer to the mask
-    const float * mp = (const float *) (mask + (ir%ne31)*nb31);
+    const float * mp = mask ? (const float *) (mask + (ir%ne31)*nb31) : nullptr;

     // loop over the KV cache
     // each simdgroup handles blocks of Q rows and C columns
@@ -6305,7 +6305,7 @@ static __global__ void flash_attn_ext_f16(
                 for (int64_t i = 0; i < D16; ++i) {
                     half16x16_bT mk; // transposed key
-                    nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); // transpose
+                    nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));

                     for (int64_t j = 0; j < Q16; ++j) {
                         nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
@@ -6314,14 +6314,14 @@ static __global__ void flash_attn_ext_f16(

                 // mqk = mqk*scale + mask
                 for (int64_t j = 0; j < Q16; ++j) {
-                    const float* msk_p = mp + 16*j*(nb31/sizeof(float)) + ic + 16*cc;
-                    int64_t msk_ne_row = nb31/sizeof(float);
+                    // const float* msk_p = mp + 16*j*(nb31/sizeof(float)) + ic + 16*cc;
+                    // int64_t msk_ne_row = nb31/sizeof(float);
                     for (uint32_t i = 0; i < mqk[j].num_elements; i++) {
-                        int msk_col = i % 16;
-                        int msk_row = i / 16;
-                        mqk[j].x[i] = __float2half(scale * __half2float(mqk[j].x[i]) + msk_p[msk_col + msk_row*msk_ne_row]);
+                        // int msk_col = i % 16;
+                        // int msk_row = i / 16;
+                        mqk[j].x[i] = __float2half(scale) * mqk[j].x[i]; // __half2float() + msk_p[msk_col + msk_row*msk_ne_row]);
                     }
-                    nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_col_major);
+                    nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
                 }
             }
         }
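Why the element-wise mask add is disabled above: CUDA leaves unspecified which (row, col) of the 16x16 tile each fragment element x[i] corresponds to, so recovering msk_col/msk_row from i is not portable across architectures; a uniform scale is safe because it does not depend on element position. A sketch of the position-independent pattern (illustrative helper, not in the file):

#include <cuda_fp16.h>
#include <mma.h>

// scale every element of an accumulator fragment; valid regardless of the
// (unspecified) element-to-coordinate mapping
__device__ void scale_fragment(
        nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> & acc,
        float scale) {
    const half s = __float2half(scale);
    for (uint32_t i = 0; i < acc.num_elements; i++) {
        acc.x[i] = s * acc.x[i]; // same per-element form the kernel uses
    }
}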
@@ -6370,11 +6370,11 @@ static __global__ void flash_attn_ext_f16(

                 // create a QxQ diagonal matrix for rescaling the output
                 if (lane_id == j) {
-                    ss[j*T + C + j] = ms;
+                    ss[j*T + C + j] = __float2half(ms);
                 }

                 for (int64_t p = lane_id; p < C; p += NW) {
-                    const float s = ss[j*T + p];
+                    const float s = __half2float(ss[j*T + p]);

                     const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);
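For context, ms and vs implement the streaming-softmax update flash attention uses to process the KV cache in blocks; in the kernel's notation (a sketch of the math, not text from the file):

\begin{aligned}
M_j^{new} &= \max\bigl(M_j^{old},\; \max_p s_{jp}\bigr) \\
ms_j &= e^{M_j^{old} - M_j^{new}}, \qquad vs_{jp} = e^{s_{jp} - M_j^{new}} \\
S_j &\leftarrow ms_j\, S_j + \sum_p vs_{jp}, \qquad O_j \leftarrow \mathrm{diag}(ms)\, O_j + \sum_p vs_{jp}\, V_p
\end{aligned}

The ss[j*T + C + j] writes build exactly the diag(ms) factor consumed in the next hunk.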
@@ -6393,14 +6393,18 @@ static __global__ void flash_attn_ext_f16(

                 // O = diag(ms)*O
                 for (int64_t j = 0; j < Q16; ++j) {
-                    half16x16_a mm;
-                    half16x16_b zro;
-                    nvcuda::wmma::fill_fragment(zro, 0.0);
-                    nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
+                    // half16x16_a mm;
+                    // half16x16_b zro;
+                    // nvcuda::wmma::fill_fragment(zro, 0.0);
+                    // nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);

                     for (int64_t i = 0; i < D16; ++i) {
-                        nvcuda::wmma::mma_sync(lo[j][i], mm, zro, lo[j][i]);
+                        //nvcuda::wmma::mma_sync(lo[j][i], mm, zro, lo[j][i]);
+                        for (uint32_t k = 0; k < 16*16; k++) {
+                            half tmp = ss[(16*j + k%16)*T + C + 16*j + k%16];
+                            lo[j][i].x[k] = tmp * lo[j][i].x[k];
+                        }
                     }
                 }
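The replacement loop relies on left-multiplication by a diagonal matrix reducing to row scaling: (diag(d) * O)[r][c] = d[r] * O[r][c]. The same rescale written against a plain row-major tile with a given leading dimension (hypothetical helper, not in the file; the in-tree version indexes the fragment directly):

#include <cuda_fp16.h>

// scale row r of a 16x16 tile by the r-th diagonal entry of diag
__device__ void diag_scale_tile(half * tile, const half * diag, int ld) {
    for (int r = 0; r < 16; r++) {
        for (int c = 0; c < 16; c++) {
            tile[r*ld + c] = diag[r] * tile[r*ld + c];
        }
    }
}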
@@ -6444,7 +6448,7 @@ static __global__ void flash_attn_ext_f16(
         if (warp_id == sg) {
             for (int64_t j = 0; j < Q16; ++j) {
                 for (int64_t i = 0; i < D16; ++i) {
-                    nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_col_major);
+                    nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
                 }
             }
         }
@@ -6487,13 +6491,13 @@ static __global__ void flash_attn_ext_f16(
                 nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);

                 for (int64_t i = 0; i < D16; ++i) {
                     nvcuda::wmma::fill_fragment(t2, 0.0);
                     nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
                     nvcuda::wmma::mma_sync(t2, ms1, t, t2);

-                    // t <- lo
-                    for (uint32_t k = 0; k < t.num_elements; k++) {
-                        t.x[k] = lo[j][i].x[k];
-                    }
+                    // store temporally 'lo' data
+                    nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
+                    // load 'lo' data into t
+                    nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);

                     nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2);
@@ -6504,22 +6508,20 @@ static __global__ void flash_attn_ext_f16(
     if (warp_id == 0) {
         for (int64_t j = 0; j < Q16; ++j) {
             for (int64_t i = 0; i < D16; ++i) {
-                nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_col_major);
+                nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
             }
         }
     }

-    float2 * dst2 = (float2 *) dst;
+    // float2 * dst2 = (float2 *) dst;

     // final rescale with 1/S and store to global memory
     if (warp_id == 0) {
         for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
             const float S = __half2float(ss[j*T + 0]);

-            for (int64_t i = lane_id; i < D2; i += NW) {
-                dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i] = __half22float2(sq2[j*T2 + i]);
-                dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i].x /= S;
-                dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i].y /= S;
+            for (int64_t i = lane_id; i < D; i += NW) {
+                dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i]) / S;
             }
         }
     }
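For reference, the removed float2 path could also have been kept by folding the 1/S divide into the vectorized store rather than dropping it; a sketch in the kernel's own variables (hypothetical, not what the commit does):

// inside the same j loop, keeping the old half2/float2 view of the buffers
for (int64_t i = lane_id; i < D2; i += NW) {
    const float2 v = __half22float2(sq2[j*T2 + i]);
    dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i] = make_float2(v.x / S, v.y / S);
}

The scalar loop the commit lands on trades that vectorized bandwidth for simplicity while debugging.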
@@ -10526,13 +10528,17 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
     GGML_ASSERT(K->type == GGML_TYPE_F16);
     GGML_ASSERT(V->type == GGML_TYPE_F16);
-    GGML_ASSERT(mask->type == GGML_TYPE_F32);
+    if(mask) {
+        GGML_ASSERT(mask->type == GGML_TYPE_F32);
+    }
     GGML_ASSERT(KQV->type == GGML_TYPE_F32);

     GGML_ASSERT(Q->backend == GGML_BACKEND_GPU);
     GGML_ASSERT(K->backend == GGML_BACKEND_GPU);
     GGML_ASSERT(V->backend == GGML_BACKEND_GPU);
-    GGML_ASSERT(mask->backend == GGML_BACKEND_GPU);
+    if(mask) {
+        GGML_ASSERT(mask->backend == GGML_BACKEND_GPU);
+    }
     GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU);

     ggml_cuda_set_device(g_main_device);
@@ -10541,7 +10547,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra;
     ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra;
     ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra;
-    ggml_tensor_extra_gpu * src3_extra = (ggml_tensor_extra_gpu *) mask->extra;
+    ggml_tensor_extra_gpu * src3_extra = mask ? (ggml_tensor_extra_gpu *) mask->extra : nullptr;
     ggml_tensor_extra_gpu * dst_extra  = (ggml_tensor_extra_gpu *) KQV->extra;

     float scale;
@@ -10549,13 +10555,14 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     const int nqpb = 16; // queries per block
     const int ncpw = 32; // cache values per warp (does not work for other values)

-    const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, 32)) : 4;
+    // const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, 32)) : 4;
+    const int nwarps = 1;

     dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
     dim3 block_dim(32, nwarps, 1);

-    int shmem = nqpb*(Q->ne[0] + nwarps*(Q->ne[0] + 1*ncpw))*(sizeof(float)/2);
-    printf("shared memory: %d bytes [%i, %i, %i]\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2]);
+    int shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
+    printf("shared memory: %d bytes [%i, %i, %i] scale = %f\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2], scale);
     switch (Q->ne[0])
     {
         case 16:
@@ -10564,12 +10571,12 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 (const char *) src0_extra->data_device[g_main_device], // Query
                 (const char *) src1_extra->data_device[g_main_device], // Key
                 (const char *) src2_extra->data_device[g_main_device], // Value
-                (const char *) src3_extra->data_device[g_main_device], // Mask
+                mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
                 (float *) dst_extra->data_device[g_main_device], // dst
                 scale,
                 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                 K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask->ne[1], mask->nb[1],
+                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
                 Q->nb[1], Q->nb[2], Q->nb[3],
                 K->nb[1], K->nb[2], K->nb[3],
                 KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
@@ -10581,12 +10588,12 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 (const char *) src0_extra->data_device[g_main_device], // Query
                 (const char *) src1_extra->data_device[g_main_device], // Key
                 (const char *) src2_extra->data_device[g_main_device], // Value
-                (const char *) src3_extra->data_device[g_main_device], // Mask
+                mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
                 (float *) dst_extra->data_device[g_main_device], // dst
                 scale,
                 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                 K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask->ne[1], mask->nb[1],
+                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
                 Q->nb[1], Q->nb[2], Q->nb[3],
                 K->nb[1], K->nb[2], K->nb[3],
                 KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
@@ -10598,12 +10605,12 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 (const char *) src0_extra->data_device[g_main_device], // Query
                 (const char *) src1_extra->data_device[g_main_device], // Key
                 (const char *) src2_extra->data_device[g_main_device], // Value
-                (const char *) src3_extra->data_device[g_main_device], // Mask
+                mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
                 (float *) dst_extra->data_device[g_main_device], // dst
                 scale,
                 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                 K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask->ne[1], mask->nb[1],
+                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
                 Q->nb[1], Q->nb[2], Q->nb[3],
                 K->nb[1], K->nb[2], K->nb[3],
                 KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
@@ -10615,12 +10622,12 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 (const char *) src0_extra->data_device[g_main_device], // Query
                 (const char *) src1_extra->data_device[g_main_device], // Key
                 (const char *) src2_extra->data_device[g_main_device], // Value
-                (const char *) src3_extra->data_device[g_main_device], // Mask
+                mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
                 (float *) dst_extra->data_device[g_main_device], // dst
                 scale,
                 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                 K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask->ne[1], mask->nb[1],
+                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
                 Q->nb[1], Q->nb[2], Q->nb[3],
                 K->nb[1], K->nb[2], K->nb[3],
                 KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]


@@ -201,7 +201,7 @@ struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * a
     struct ggml_cgraph * gf = ggml_new_graph(ctx0);

     if(!model.naive_attn) {
-        struct ggml_tensor* result = ggml_flash_attn_ext(ctx0, model.q, model.k, model.v, model.msk, 1.0f / sqrtf(model.q->ne[0]));
+        struct ggml_tensor* result = ggml_flash_attn_ext(ctx0, model.q, model.k, model.v, nullptr, 1.0f / sqrtf(model.q->ne[0]));
         ggml_build_forward_expand(gf, result);
     } else {
         struct ggml_tensor* kq = ggml_mul_mat(ctx0, model.k, model.q);
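For reference, the call the test now exercises: this commit makes the mask argument of ggml_flash_attn_ext optional, and passing nullptr selects the unmasked kernel path (usage sketch; tensor setup as elsewhere in the test, types per the asserts above):

struct ggml_tensor * result = ggml_flash_attn_ext(
    ctx0,
    model.q,                        // query, GGML_TYPE_F32
    model.k,                        // key,   GGML_TYPE_F16
    model.v,                        // value, GGML_TYPE_F16
    nullptr,                        // mask is optional after this commit
    1.0f / sqrtf(model.q->ne[0]));  // scale = 1/sqrt(head dim)
ggml_build_forward_expand(gf, result);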