diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 93fb7e80d..8bad28064 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -7235,34 +7235,35 @@ typedef nvcuda::wmma::fragment
 #endif

 // based on metal version
-template<int D, int Q, int C> // D head size, Q queries per block, C cache items per block
+template<int D, int Q, int C, bool use_mask> // D head size, Q queries per block, C cache items per block
+__launch_bounds__(8*WARP_SIZE, 1) // tells the compiler to avoid register spilling even if it reduces occupancy
 static __global__ void flash_attn_ext_f16(
-        const char* __restrict__ q,
-        const char* __restrict__ k,
-        const char* __restrict__ v,
-        const char* __restrict__ mask,
-        float* __restrict__ dst,
-        float scale,
-        int ne00,
-        int ne01,
-        int ne02,
-        int ne03,
-        int ne10,
-        int ne11,
-        int ne12,
-        int ne13,
-        int ne31,
-        int nb31,
-        int nb01,
-        int nb02,
-        int nb03,
-        int nb11,
-        int nb12,
-        int nb13,
-        int ne0,
-        int ne1,
-        int ne2,
-        int ne3) {
+        const char * __restrict__ q,
+        const char * __restrict__ k,
+        const char * __restrict__ v,
+        const char * __restrict__ mask,
+        float * __restrict__ dst,
+        const float scale,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3) {
 #if __CUDA_ARCH__ >= CC_VOLTA
     const int warp_id = threadIdx.y;
     const int lane_id = threadIdx.x;
@@ -7319,24 +7320,28 @@ static __global__ void flash_attn_ext_f16(
         }
     }

-    nvcuda::wmma::fill_fragment(zr, 0.0);
+    nvcuda::wmma::fill_fragment(zr, 0.0f);

     // zero out lo
+#pragma unroll
     for (int j = 0; j < Q16; ++j) {
+#pragma unroll
         for (int i = 0; i < D16; ++i) {
-            nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
+            nvcuda::wmma::fill_fragment(lo[j][i], 0.0f);
         }
     }

     // zero out shared memory SH
+#pragma unroll
     for (int j = 0; j < Q; ++j) {
+#pragma unroll
         for (int i0 = 0; i0 < SH; i0 += NW) {
             const int i = i0 + lane_id;
             if (i >= SH) {
                 break;
             }

-            ss[j*T + i] = 0.0;
+            ss[j*T + i] = 0.0f;
         }
     }

@@ -7346,6 +7351,7 @@ static __global__ void flash_attn_ext_f16(
         half S = __float2half(0.0f);
         half M[Q];

+#pragma unroll
         for (int i = 0; i < Q; ++i) {
             M[i] = CUDART_MIN_DENORM_FP16;
         }
@@ -7375,17 +7381,20 @@ static __global__ void flash_attn_ext_f16(
         // load the queries from shared memory into local memory
         half16x16_a mq[Q16][D16];

+#pragma unroll
         for (int j = 0; j < Q16; ++j) {
+#pragma unroll
             for (int i = 0; i < D16; ++i) {
                 nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T);
             }
         }

         // pointer to the mask
-        const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr;
+        const half * mp = use_mask ? (const half *) (mask + iq1*nb31) : nullptr;

         // prepare diagonal scale matrix
         half16x16_b mscale;
+#pragma unroll
         for (int i = 0; i < 16; ++i) {
             ss[i*T + i] = __float2half(scale);
         }
@@ -7404,27 +7413,31 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
                 for (int cc = 0; cc < C16; ++cc) {
                     half16x16_acc mqk[Q16];
+#pragma unroll
                     for (int j = 0; j < Q16; ++j) {
-                        nvcuda::wmma::fill_fragment(mqk[j], 0);
+                        nvcuda::wmma::fill_fragment(mqk[j], 0.0f);
                     }

                     const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));

+#pragma unroll
                     for (int i = 0; i < D16; ++i) {
                         half16x16_bT mk; // transposed key
                         nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));

+#pragma unroll
                         for (int j = 0; j < Q16; ++j) {
                             nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
                         }
                     }

                     // mqk = mqk*scale + mask
+#pragma unroll
                     for (int j = 0; j < Q16; ++j) {
                         half16x16_a mqka;
                         half16x16_acc mm;

-                        if (mp) {
+                        if (use_mask) {
                             nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
                         }

@@ -7432,7 +7445,7 @@ static __global__ void flash_attn_ext_f16(

                         nvcuda::wmma::store_matrix_sync(      ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
                         nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T);
-                        nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mp ? mm : zr);
+                        nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, use_mask ? mm : zr);
                         nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
                     }
                 }
@@ -7442,9 +7455,11 @@ static __global__ void flash_attn_ext_f16(
             half2 smax = make_half2(-INFINITY, -INFINITY);

             // online softmax
+#pragma unroll
             for (int j = 0; j < Q; ++j) {
                 const half m = M[j];

+#pragma unroll
                 for (int p0 = 0; p0 < C2; p0 += NW) {
                     const int p = p0 + lane_id;

@@ -7460,6 +7475,7 @@ static __global__ void flash_attn_ext_f16(
                 half2 ls = make_half2(0.0f, 0.0f);
                 half2 M2 = make_half2(M[j], M[j]);

+#pragma unroll
                 for (int p0 = 0; p0 < C2; p0 += NW) {
                     const int p = p0 + lane_id;

@@ -7493,12 +7509,14 @@ static __global__ void flash_attn_ext_f16(
             }

             // O = diag(ms)*O
+#pragma unroll
             for (int j = 0; j < Q16; ++j) {
                 half16x16_a mm;
                 half16x16_b lob;

                 nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);

+#pragma unroll
                 for (int i = 0; i < D16; ++i) {
                     // convert accumulator to matrix_b
                     nvcuda::wmma::store_matrix_sync(     ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
@@ -7509,26 +7527,32 @@ static __global__ void flash_attn_ext_f16(
             }

             // restore zeros
+#pragma unroll
             for (int j = 0; j < Q16; ++j) {
                 nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major);
             }

             // O = O + (Q*K^T)*V
             {
+#pragma unroll
                 for (int cc = 0; cc < C16; ++cc) {
                     const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));

                     half16x16_b mv[D16];
+#pragma unroll
                     for (int i = 0; i < D16; ++i) {
                         nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half));
                     }
                     half16x16_a ms[Q16];
+#pragma unroll
                     for (int j = 0; j < Q16; ++j) {
                         nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T);
                     }

+#pragma unroll
                     for (int j = 0; j < Q16; ++j) {
+#pragma unroll
                         for (int i = 0; i < D16; ++i) {
                             nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]);
                         }
                     }
@@ -7545,12 +7569,15 @@ static __global__ void flash_attn_ext_f16(
     }

     // reduce the warps sequentially
+#pragma unroll
     for (int sg = 1; sg < num_warps; ++sg) {
         __syncthreads();
         // each simdgroup stores its output to shared memory, reusing sq
         if (warp_id == sg) {
+#pragma unroll
             for (int j = 0; j < Q16; ++j) {
+#pragma unroll
                 for (int i = 0; i < D16; ++i) {
                     nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
                 }
             }
@@ -7561,6 +7588,7 @@ static __global__ void flash_attn_ext_f16(

         // the first simdgroup accumulates the results from the other simdgroups
         if (warp_id == 0) {
+#pragma unroll
             for (int j = lane_id; j < Q; j += NW) {
                 const half S0 = ss[j*T + 0];
                 const half S1 = ss[j*T + sg*SH + 0];
@@ -7583,6 +7611,7 @@ static __global__ void flash_attn_ext_f16(
             }

             // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+#pragma unroll
             for (int j = 0; j < Q16; ++j) {
                 half16x16_a ms0;
                 half16x16_a ms1;
@@ -7592,6 +7621,7 @@ static __global__ void flash_attn_ext_f16(
                 nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T);
                 nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);

+#pragma unroll
                 for (int i = 0; i < D16; ++i) {
                     nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
                     nvcuda::wmma::mma_sync(t2, ms1, t, zr);
@@ -7608,7 +7638,9 @@ static __global__ void flash_attn_ext_f16(
     // store result to shared memory (reuse sq)
     if (warp_id == 0) {
+#pragma unroll
         for (int j = 0; j < Q16; ++j) {
+#pragma unroll
             for (int i = 0; i < D16; ++i) {
                 nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
             }
         }
@@ -7617,9 +7649,11 @@ static __global__ void flash_attn_ext_f16(

     // final rescale with 1/S and store to global memory
     if (warp_id == 0) {
+#pragma unroll
         for (int j = 0; j < Q && iq1 + j < ne01; ++j) {
             const half S = ss[j*T + 0];

+#pragma unroll
             for (int i0 = 0; i0 < D; i0 += NW) {
                 const int i = i0 + lane_id;
                 if (i >= D) {
@@ -11927,9 +11961,10 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     //const size_t shmem_max = 96*1024;
     //cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);

+    GGML_ASSERT(mask); // FIXME case without mask
     switch (Q->ne[0]) {
         case 64:
-            flash_attn_ext_f16<64, NQPB, NCPW>
+            flash_attn_ext_f16<64, NQPB, NCPW, true>
                 <<>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
@@ -11946,7 +11981,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 80:
-            flash_attn_ext_f16<80, NQPB, NCPW>
+            flash_attn_ext_f16<80, NQPB, NCPW, true>
                 <<>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
@@ -11963,7 +11998,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 96:
-            flash_attn_ext_f16<96, NQPB, NCPW>
+            flash_attn_ext_f16<96, NQPB, NCPW, true>
                 <<>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
@@ -11980,7 +12015,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 112:
-            flash_attn_ext_f16<112, NQPB, NCPW>
+            flash_attn_ext_f16<112, NQPB, NCPW, true>
                 <<>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
@@ -11997,7 +12032,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 128:
-            flash_attn_ext_f16<128, NQPB, NCPW>
+            flash_attn_ext_f16<128, NQPB, NCPW, true>
                 <<>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
@@ -12014,7 +12049,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 256:
-            flash_attn_ext_f16<256, NQPB, NCPW>
+            flash_attn_ext_f16<256, NQPB, NCPW, true>
                 <<>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
@@ -12031,6 +12066,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         default:
+            GGML_ASSERT(false);
            break;
    }
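
Note: the "// online softmax" section of the kernel implements the standard streaming-softmax recurrence: a running maximum M, a running denominator S, and an output accumulator that is rescaled by diag(ms) whenever M grows. A minimal scalar sketch of that recurrence for a single query row is given below; it is an illustration in plain C++ for reference only, not code from ggml-cuda.cu, and the function name is hypothetical.

// Scalar reference for the online softmax recurrence used by flash_attn_ext_f16
// (illustrative only; the kernel works on half2 values and 16x16 WMMA tiles).
#include <cmath>
#include <vector>

static void online_softmax_row(const std::vector<float> & s,              // scaled + masked scores for one query row
                               const std::vector<std::vector<float>> & v, // value rows, one per score
                               std::vector<float> & o,                    // output accumulator, size D
                               float & M,                                 // running maximum
                               float & S) {                               // running softmax denominator
    for (size_t c = 0; c < s.size(); ++c) {
        const float Mold = M;
        M = fmaxf(M, s[c]);

        const float ms = expf(Mold - M); // rescale factor for the existing state: diag(ms)
        const float vs = expf(s[c] - M); // weight of the new column: the P matrix entry

        S = S*ms + vs;

        for (size_t d = 0; d < o.size(); ++d) {
            o[d] = o[d]*ms + vs*v[c][d]; // O = diag(ms)*O + P*V
        }
    }
    // after all columns, the attention output is o[d]/S (the kernel's "final rescale with 1/S")
}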
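
The new use_mask template parameter is only ever instantiated with true here, hence the GGML_ASSERT(mask) FIXME on the host side. Sketched below, under the assumption that the no-mask path gets wired up later, is the compile-time dispatch pattern this parameter enables; the kernel and launcher are toy stand-ins, not the real flash_attn_ext_f16 launch code.

// Toy illustration of dispatching a bool template parameter from a runtime pointer,
// so the false instantiation is compiled without the masked code path at all.
#include <cuda_runtime.h>

template <bool use_mask>
__global__ void toy_masked_scale(const float * x, const float * mask, float * dst, float scale, int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i >= n) {
        return;
    }

    float acc = x[i]*scale;
    if (use_mask) { // use_mask is a compile-time constant, so the dead branch is eliminated
        acc += mask[i];
    }
    dst[i] = acc;
}

static void toy_masked_scale_launch(const float * x, const float * mask, float * dst, float scale, int n, cudaStream_t stream) {
    const dim3 block_dim(256, 1, 1);
    const dim3 blocks_num((n + block_dim.x - 1)/block_dim.x, 1, 1);

    if (mask) {
        toy_masked_scale<true><<<blocks_num, block_dim, 0, stream>>>(x, mask, dst, scale, n);
    } else {
        toy_masked_scale<false><<<blocks_num, block_dim, 0, stream>>>(x, mask, dst, scale, n);
    }
}

In the kernel above, use_mask plays the same role: because it is a template parameter, the "if (use_mask)" check and the "use_mask ? mm : zr" selection are resolved at compile time rather than per thread at runtime.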