From b1479dfbc574dc2b0ea8a7426f44011f73a118fc Mon Sep 17 00:00:00 2001
From: FSSRepo
Date: Wed, 31 Jan 2024 12:28:48 -0500
Subject: [PATCH] fix kernel

---
 ggml-cuda.cu                   | 103 ++++++++++++++++++---------
 tests/test-flash-attention.cpp |   2 +-
 2 files changed, 56 insertions(+), 49 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 5229e15d2..fe24935a4 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -6158,9 +6158,9 @@ static __global__ void flash_attn_f32(
 }

 #if __CUDA_ARCH__ >= CC_VOLTA
-typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    16, 16, 16, half, nvcuda::wmma::col_major> half16x16_a;
-typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    16, 16, 16, half, nvcuda::wmma::col_major> half16x16_b;
-typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    16, 16, 16, half, nvcuda::wmma::row_major> half16x16_bT;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    16, 16, 16, half, nvcuda::wmma::col_major> half16x16_bT;
 typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half>                          half16x16_acc;

 // based on metal version
@@ -6204,15 +6204,15 @@ static __global__ void flash_attn_ext_f16(
     const int D16 = D/16;
     const int Q16 = Q/16;
     const int NW  = WARP_SIZE;
-    const int SH  = (C + D); // shared memory per simdgroup in (half)
+    const int SH  = (C + Q); // shared memory per simdgroup in (half)

     const int T  = D + num_warps*SH; // shared memory size per query in (half)
     const int T2 = T/2;              // shared memory size per query in (half2)

     extern __shared__ half __flash_attn_f16_shmem[];
     // pq
-    half  * sq  = (half  *) (__flash_attn_f16_shmem + 0*D); // holds the query data
-    half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2
+    half  * sq  = (half  *) (__flash_attn_f16_shmem + 0*D); // holds the query data
+    half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2
     half  * ss  = (half  *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix

     half16x16_acc lo[Q16][D16];
@@ -6249,7 +6249,7 @@ static __global__ void flash_attn_ext_f16(
     float S[Q];
     float M[Q];

-    for(int i = 0; i < Q;i ++) {
+    for(int i = 0; i < Q; i++) {
         S[i] = 0.0f;
         M[i] = -INFINITY;
     }
@@ -6288,7 +6288,7 @@ static __global__ void flash_attn_ext_f16(
     const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;

     // pointer to the mask
-    const float * mp = (const float *) (mask + (ir%ne31)*nb31);
+    const float * mp = mask ? (const float *) (mask + (ir%ne31)*nb31) : nullptr;

     // loop over the KV cache
     // each simdgroup handles blocks of Q rows and C columns
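The shared-memory accounting above is easier to audit host-side. Below is a minimal sketch under this patch's constants (the helper name flash_attn_shmem_bytes is hypothetical, not ggml API): each of the Q queries in a block owns D halves of query data (sq) plus, per warp, C halves of attention scores and Q halves for the QxQ diagonal rescaling matrix, which is exactly the new SH = C + Q.

    #include <cstddef>
    #include <cstdint>

    // D = head size (Q->ne[0]), Q = queries per block (nqpb = 16),
    // C = cache values per warp (ncpw = 32), mirroring the kernel's SH/T layout.
    static size_t flash_attn_shmem_bytes(int D, int Q, int C, int num_warps) {
        const int SH = C + Q;            // per-warp scratch, in halves
        const int T  = D + num_warps*SH; // halves per query row
        return (size_t) Q * (size_t) T * sizeof(uint16_t); // half is 2 bytes
    }

For D = 128, Q = 16, C = 32 and a single warp this yields 16*(128 + 48)*2 = 5632 bytes, matching the corrected host-side expression introduced later in this patch.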
@@ -6305,7 +6305,7 @@ static __global__ void flash_attn_ext_f16(
                 for (int64_t i = 0; i < D16; ++i) {
                     half16x16_bT mk; // transposed key
-                    nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); // transpose
+                    nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));

                     for (int64_t j = 0; j < Q16; ++j) {
                         nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
@@ -6314,14 +6314,14 @@ static __global__ void flash_attn_ext_f16(

                 // mqk = mqk*scale + mask
                 for (int64_t j = 0; j < Q16; ++j) {
-                    const float* msk_p = mp + 16*j*(nb31/sizeof(float)) + ic + 16*cc;
-                    int64_t msk_ne_row = nb31/sizeof(float);
+                    // const float* msk_p = mp + 16*j*(nb31/sizeof(float)) + ic + 16*cc;
+                    // int64_t msk_ne_row = nb31/sizeof(float);
                     for (uint32_t i = 0; i < mqk[j].num_elements; i++) {
-                        int msk_col = i % 16;
-                        int msk_row = i / 16;
-                        mqk[j].x[i] = __float2half(scale * __half2float(mqk[j].x[i]) + msk_p[msk_col + msk_row*msk_ne_row]);
+                        // int msk_col = i % 16;
+                        // int msk_row = i / 16;
+                        mqk[j].x[i] = __float2half(scale) * mqk[j].x[i]; // __half2float() + msk_p[msk_col + msk_row*msk_ne_row]);
                     }
-                    nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_col_major);
+                    nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
                 }
             }
         }
@@ -6370,11 +6370,11 @@ static __global__ void flash_attn_ext_f16(

                 // create a QxQ diagonal matrix for rescaling the output
                 if (lane_id == j) {
-                    ss[j*T + C + j] = ms;
+                    ss[j*T + C + j] = __float2half(ms);
                 }

                 for (int64_t p = lane_id; p < C; p += NW) {
-                    const float s = ss[j*T + p];
+                    const float s = __half2float(ss[j*T + p]);

                     const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);
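This block is the streaming ("online") softmax carried over from the Metal kernel: M[j] tracks the running row maximum, ms is the factor that rescales everything accumulated so far, and vs weighs each new score, with an explicit guard so fully-masked (-INFINITY) scores contribute exactly zero. A scalar reference of the per-score update, as a sketch for exposition only (the function and its name are mine, not kernel code):

    #include <cmath>

    // One new scaled+masked score s for a query row.
    // M = running max of scores, S = running sum of exp(score - M).
    static void online_softmax_step(float s, float & M, float & S, float & ms) {
        const float Mold = M;
        M  = fmaxf(M, s);
        ms = Mold == -INFINITY ? 1.0f : expf(Mold - M);       // rescale prior O and S
        const float vs = s == -INFINITY ? 0.0f : expf(s - M); // guarded, as in the kernel
        S  = S*ms + vs;
    }

The kernel stores ms on the diagonal at ss[j*T + C + j] so the accumulated output O can be rescaled with a single 16x16 matrix multiply, and uses vs as the attention weight in the subsequent P*V accumulation.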
@@ -6393,14 +6393,18 @@ static __global__ void flash_attn_ext_f16(

             // O = diag(ms)*O
             for (int64_t j = 0; j < Q16; ++j) {
-                half16x16_a mm;
-                half16x16_b zro;
+                // half16x16_a mm;
+                // half16x16_b zro;

-                nvcuda::wmma::fill_fragment(zro, 0.0);
-                nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
+                // nvcuda::wmma::fill_fragment(zro, 0.0);
+                // nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);

                 for (int64_t i = 0; i < D16; ++i) {
-                    nvcuda::wmma::mma_sync(lo[j][i], mm, zro, lo[j][i]);
+                    // nvcuda::wmma::mma_sync(lo[j][i], mm, zro, lo[j][i]);
+                    for (uint32_t k = 0; k < 16*16; k++) {
+                        half tmp = ss[(16*j + k%16)*T + C + 16*j + k%16];
+                        lo[j][i].x[k] = tmp * lo[j][i].x[k];
+                    }
                 }
             }
@@ -6444,7 +6448,7 @@ static __global__ void flash_attn_ext_f16(
             if (warp_id == sg) {
                 for (int64_t j = 0; j < Q16; ++j) {
                     for (int64_t i = 0; i < D16; ++i) {
-                        nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_col_major);
+                        nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
                     }
                 }
             }
@@ -6487,13 +6491,13 @@ static __global__ void flash_attn_ext_f16(
                 nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);

                 for (int64_t i = 0; i < D16; ++i) {
+                    nvcuda::wmma::fill_fragment(t2, 0.0);
                     nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
                     nvcuda::wmma::mma_sync(t2, ms1, t, t2);
-
-                    // t <- lo
-                    for (uint32_t k = 0; k < t.num_elements; k++) {
-                        t.x[k] = lo[j][i].x[k];
-                    }
+                    // temporarily store 'lo' data to shared memory
+                    nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
+                    // load 'lo' data into t
+                    nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);

                     nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2);
                 }
             }
@@ -6504,22 +6508,20 @@ static __global__ void flash_attn_ext_f16(
     if (warp_id == 0) {
         for (int64_t j = 0; j < Q16; ++j) {
             for (int64_t i = 0; i < D16; ++i) {
-                nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_col_major);
+                nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
             }
         }
     }

-    float2 * dst2 = (float2 *) dst;
+    // float2 * dst2 = (float2 *) dst;

     // final rescale with 1/S and store to global memory
     if (warp_id == 0) {
         for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
             const float S = __half2float(ss[j*T + 0]);

-            for (int64_t i = lane_id; i < D2; i += NW) {
-                dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i] = __half22float2(sq2[j*T2 + i]);
-                dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i].x /= S;
-                dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i].y /= S;
+            for (int64_t i = lane_id; i < D; i += NW) {
+                dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i]) / S;
             }
         }
     }
@@ -10526,13 +10528,17 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
     GGML_ASSERT(K->type == GGML_TYPE_F16);
     GGML_ASSERT(V->type == GGML_TYPE_F16);
-    GGML_ASSERT(mask->type == GGML_TYPE_F32);
+    if(mask) {
+        GGML_ASSERT(mask->type == GGML_TYPE_F32);
+    }
     GGML_ASSERT(KQV->type == GGML_TYPE_F32);

     GGML_ASSERT(Q->backend == GGML_BACKEND_GPU);
     GGML_ASSERT(K->backend == GGML_BACKEND_GPU);
     GGML_ASSERT(V->backend == GGML_BACKEND_GPU);
-    GGML_ASSERT(mask->backend == GGML_BACKEND_GPU);
+    if(mask) {
+        GGML_ASSERT(mask->backend == GGML_BACKEND_GPU);
+    }
     GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU);

     ggml_cuda_set_device(g_main_device);
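With these guards the mask becomes genuinely optional on the host side: every mask dereference below sits behind a null check, and the kernel receives ne31 = 0, nb31 = 0 and a null Mask pointer when none is supplied. The pattern, condensed into a sketch (mask_args and make_mask_args are illustrative names, not ggml API):

    #include <cstddef>
    #include <cstdint>

    struct mask_args {
        const char * data; // device pointer, or nullptr when absent
        int64_t      ne1;  // mask rows, 0 when absent
        size_t       nb1;  // row stride in bytes, 0 when absent
    };

    static mask_args make_mask_args(const ggml_tensor * mask,
                                    const ggml_tensor_extra_gpu * extra, int device) {
        if (!mask) {
            return { nullptr, 0, 0 };
        }
        return { (const char *) extra->data_device[device], mask->ne[1], mask->nb[1] };
    }

Note that the in-kernel mask addition is commented out earlier in this patch, so for now the guarded arguments mainly keep the maskless test path from dereferencing a null tensor.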
@@ -10541,7 +10547,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra;
     ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra;
     ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra;
-    ggml_tensor_extra_gpu * src3_extra = (ggml_tensor_extra_gpu *) mask->extra;
+    ggml_tensor_extra_gpu * src3_extra = mask ? (ggml_tensor_extra_gpu *) mask->extra : nullptr;
     ggml_tensor_extra_gpu * dst_extra  = (ggml_tensor_extra_gpu *) KQV->extra;

     float scale;
@@ -10549,13 +10555,14 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     const int nqpb = 16; // queries per block
     const int ncpw = 32; // cache values per warp (does not work for other values)

-    const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, 32)) : 4;
+    // const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, 32)) : 4;
+    const int nwarps = 1;

     dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
     dim3 block_dim(32, nwarps, 1);

-    int shmem = nqpb*(Q->ne[0] + nwarps*(Q->ne[0] + 1*ncpw))*(sizeof(float)/2);
-    printf("shared memory: %d bytes [%i, %i, %i]\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2]);
+    int shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
+    printf("shared memory: %d bytes [%i, %i, %i] scale = %f\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2], scale);

     switch (Q->ne[0]) {
         case 16:
@@ -10564,12 +10571,12 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 (const char *) src0_extra->data_device[g_main_device], // Query
                 (const char *) src1_extra->data_device[g_main_device], // Key
                 (const char *) src2_extra->data_device[g_main_device], // Value
-                (const char *) src3_extra->data_device[g_main_device], // Mask
+                mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
                 (float *) dst_extra->data_device[g_main_device], // dst
                 scale,
                 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                 K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask->ne[1], mask->nb[1],
+                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
                 Q->nb[1], Q->nb[2], Q->nb[3],
                 K->nb[1], K->nb[2], K->nb[3],
                 KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
@@ -10581,12 +10588,12 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 (const char *) src0_extra->data_device[g_main_device], // Query
                 (const char *) src1_extra->data_device[g_main_device], // Key
                 (const char *) src2_extra->data_device[g_main_device], // Value
-                (const char *) src3_extra->data_device[g_main_device], // Mask
+                mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
                 (float *) dst_extra->data_device[g_main_device], // dst
                 scale,
                 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                 K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask->ne[1], mask->nb[1],
+                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
                 Q->nb[1], Q->nb[2], Q->nb[3],
                 K->nb[1], K->nb[2], K->nb[3],
                 KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
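For reference, the launch geometry above: the grid tiles Q's row dimension in blocks of nqpb = 16 queries, one block per (query tile, head, batch) triple, while blockDim.x holds the 32 lanes of a warp and blockDim.y the warp count, pinned to 1 here (the commented-out heuristic picked 4 to 32 warps). A sketch of how the kernel-side indices line up with that launch, under my reading of the code (the exact derivations are assumptions, not lines from this patch):

    // with dim3 blocks_num((ne01 + 16 - 1)/16, ne02, ne03) and dim3 block_dim(32, nwarps, 1):
    const int lane_id = threadIdx.x;     // 0..31, lane within the warp
    const int warp_id = threadIdx.y;     // 0..nwarps-1, the "simdgroup" index used for SH offsets
    const int iq1     = blockIdx.x * 16; // first query row of this tile (nqpb = 16)
    const int iq2     = blockIdx.y;      // head index  (Q->ne[2])
    const int iq3     = blockIdx.z;      // batch index (Q->ne[3])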
@@ -10598,12 +10605,12 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 (const char *) src0_extra->data_device[g_main_device], // Query
                 (const char *) src1_extra->data_device[g_main_device], // Key
                 (const char *) src2_extra->data_device[g_main_device], // Value
-                (const char *) src3_extra->data_device[g_main_device], // Mask
+                mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
                 (float *) dst_extra->data_device[g_main_device], // dst
                 scale,
                 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                 K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask->ne[1], mask->nb[1],
+                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
                 Q->nb[1], Q->nb[2], Q->nb[3],
                 K->nb[1], K->nb[2], K->nb[3],
                 KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
@@ -10615,12 +10622,12 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 (const char *) src0_extra->data_device[g_main_device], // Query
                 (const char *) src1_extra->data_device[g_main_device], // Key
                 (const char *) src2_extra->data_device[g_main_device], // Value
-                (const char *) src3_extra->data_device[g_main_device], // Mask
+                mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
                 (float *) dst_extra->data_device[g_main_device], // dst
                 scale,
                 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                 K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask->ne[1], mask->nb[1],
+                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
                 Q->nb[1], Q->nb[2], Q->nb[3],
                 K->nb[1], K->nb[2], K->nb[3],
                 KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
diff --git a/tests/test-flash-attention.cpp b/tests/test-flash-attention.cpp
index 5d83eeabd..d4457a53e 100644
--- a/tests/test-flash-attention.cpp
+++ b/tests/test-flash-attention.cpp
@@ -201,7 +201,7 @@ struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * a
     struct ggml_cgraph * gf = ggml_new_graph(ctx0);

     if(!model.naive_attn) {
-        struct ggml_tensor* result = ggml_flash_attn_ext(ctx0, model.q, model.k, model.v, model.msk, 1.0f / sqrtf(model.q->ne[0]));
+        struct ggml_tensor* result = ggml_flash_attn_ext(ctx0, model.q, model.k, model.v, nullptr, 1.0f / sqrtf(model.q->ne[0]));
         ggml_build_forward_expand(gf, result);
     } else {
         struct ggml_tensor* kq = ggml_mul_mat(ctx0, model.k, model.q);
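The test now drives the maskless path end to end. Both call shapes are accepted after this patch; the sketch below reuses the test's own model fields and the scale from build_graph:

    const float scale = 1.0f / sqrtf(model.q->ne[0]);

    // maskless: the host code forwards ne31 = 0, nb31 = 0 and a null Mask pointer
    struct ggml_tensor * out_nomask = ggml_flash_attn_ext(ctx0, model.q, model.k, model.v, nullptr, scale);

    // masked: requires an F32 mask tensor resident on the GPU backend (asserted above)
    struct ggml_tensor * out_masked = ggml_flash_attn_ext(ctx0, model.q, model.k, model.v, model.msk, scale);

With the in-kernel mask addition currently commented out, the two graphs compute the same result; the masked path will diverge again once that block is re-enabled.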