diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu
index 19108044e..4b51f1b74 100644
--- a/ggml-cuda/fattn.cu
+++ b/ggml-cuda/fattn.cu
@@ -29,14 +29,17 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
 }
 
-template<int D> // D == head size
-__launch_bounds__(D, 1)
+#define FATTN_KQ_STRIDE 256
+
+template<int D, int parallel_blocks> // D == head size
+__launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1)
 static __global__ void flash_attn_vec_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
         float      * __restrict__ dst,
+        half2      * __restrict__ dst_meta,
         const float scale,
         const int ne00,
         const int ne01,
@@ -60,20 +63,25 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne3) {
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.y + nb01*blockIdx.x);
+    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.y);
     const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.y / gqa_ratio));
     const half   * V_h   = (const half   *) (V    + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *)  mask + ne31*blockIdx.x;
+    const half   * maskh = (const half   *)  mask;
+
+    if (parallel_blocks == 1) {
+        Q_f2  += blockIdx.x*nb01/sizeof(float2);
+        maskh += blockIdx.x*ne11;
+    }
 
     const int stride_KV  = nb11 / sizeof(half);
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    constexpr int nwarps = D/WARP_SIZE;
+    constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
     const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
-    __builtin_assume(tid < D);
+    __builtin_assume(tid < nwarps*WARP_SIZE);
 
-    __shared__ half KQ[D];
-    KQ[tid] = 0.0f;
+    __shared__ half KQ[nwarps*WARP_SIZE];
+    KQ[tid] = -INFINITY;
     half2 * KQ2 = (half2 *) KQ;
 
     half kqmax = -INFINITY;
@@ -85,7 +93,6 @@ static __global__ void flash_attn_vec_ext_f16(
         kqmax_shared[threadIdx.x] = -INFINITY;
         kqsum_shared[threadIdx.x] = 0.0f;
     }
-
     __syncthreads();
 
     // Convert Q to half2 and store in registers:
@@ -102,14 +109,15 @@ static __global__ void flash_attn_vec_ext_f16(
 
     half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value.
 
-    for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += D) {
+    const int k_start = parallel_blocks == 1 ? 0 : blockIdx.x*D;
+    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
         half kqmax_new = kqmax;
 #pragma unroll
         for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
             const int i_KQ = i_KQ_0 + threadIdx.y;
 
-            if (256 % D != 0 && k_VKQ_0 + i_KQ >= ne11) {
+            if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
                 break;
             }
@@ -153,19 +161,25 @@ static __global__ void flash_attn_vec_ext_f16(
 
         __syncthreads();
 
+        if (tid < D) {
 #pragma unroll
-        for (int k0 = 0; k0 < D; k0 += 2) {
-            if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) {
-                break;
-            }
+            for (int k0 = 0; k0 < D; k0 += 2) {
+                if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) {
+                    break;
+                }
 
-            half2 V_k;
-            reinterpret_cast<half&>(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid];
-            reinterpret_cast<half&>(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid];
-            VKQ += V_k*KQ2[k0/2];
+                half2 V_k;
+                reinterpret_cast<half&>(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid];
+                reinterpret_cast<half&>(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid];
+                VKQ += V_k*KQ2[k0/2];
+            }
         }
     }
 
+    if (tid >= D) {
+        kqsum = 0.0f;
+    }
+
     kqsum = warp_reduce_sum(kqsum);
     if (threadIdx.x == 0) {
         kqsum_shared[threadIdx.y] = kqsum;
@@ -174,12 +188,22 @@ static __global__ void flash_attn_vec_ext_f16(
     kqsum = kqsum_shared[threadIdx.x];
     kqsum = warp_reduce_sum(kqsum);
 
-    dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)) / kqsum;
+    if (tid >= D) {
+        return;
+    }
+
+    if (parallel_blocks == 1) {
+        dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)) / kqsum;
+    } else {
+        dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ));
+
+        if (tid == 0) {
+            dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(kqmax, kqsum);
+        }
+    }
 }
 
-#define FATTN_KQ_STRIDE 256
-
-template<int D, int ncols, int nwarps, int VKQ_stride> // D == head size, VKQ_stride == num VKQ rows calculated in parallel
+template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks> // D == head size, VKQ_stride == num VKQ rows calculated in parallel
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
@@ -187,6 +211,7 @@ static __global__ void flash_attn_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
         float      * __restrict__ dst,
+        half2      * __restrict__ dst_meta,
         const float scale,
         const int ne00,
         const int ne01,
@@ -228,10 +253,15 @@ static __global__ void flash_attn_ext_f16(
     constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float * Q_f   = (const float *) (Q + nb02* blockIdx.y + ncols*nb01*blockIdx.x);
+    const float * Q_f   = (const float *) (Q + nb02* blockIdx.y);
     const half  * K_h   = (const half  *) (K + nb12*(blockIdx.y / gqa_ratio));
     const half  * V_h   = (const half  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
-    const half2 * mask2 = (half2       *)  mask + ncols*ne11*blockIdx.x/2;
+    const half2 * mask2 = (half2       *)  mask;
+
+    if (parallel_blocks == 1) {
+        Q_f   += blockIdx.x * ncols*nb01/sizeof(float);
+        mask2 += blockIdx.x * ncols*ne11/2;
+    }
 
     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
@@ -273,7 +303,11 @@ static __global__ void flash_attn_ext_f16(
             if (i0 + WARP_SIZE > D && i >= D) {
                 break;
             }
-            KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
+            if (parallel_blocks == 1) {
+                KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
+            } else {
+                KQ[j*D_padded + i] = j == 0 ? Q_f[j*stride_Q + i] * scale : 0.0f;
+            }
         }
     }
 
@@ -291,7 +325,8 @@ static __global__ void flash_attn_ext_f16(
     __syncthreads();
 
     // Iterate over ne11 == previous tokens:
-    for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += FATTN_KQ_STRIDE) {
+    const int k_start = parallel_blocks == 1 ? 0 : blockIdx.x*FATTN_KQ_STRIDE;
+    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
         // Calculate tile of KQ:
 #pragma unroll
         for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
@@ -420,22 +455,75 @@ static __global__ void flash_attn_ext_f16(
         __syncthreads();
     }
 
+    if (parallel_blocks == 1) {
 #pragma unroll
-    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-        const int j = j0 + threadIdx.y;
-        if (ncols*blockIdx.x + j >= ne01) {
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+            if (ncols*blockIdx.x + j >= ne01) {
+                return;
+            }
+            const float KQ_rowsum_j = __low2float(KQ_rowsum[j0/nwarps]) + __high2float(KQ_rowsum[j0/nwarps]);
+#pragma unroll
+            for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+                if (i0 + WARP_SIZE > D && i >= D) {
+                    break;
+                }
+                dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j;
+            }
+        }
+        return;
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < D; i0 += nwarps*WARP_SIZE) {
+        const int i = i0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+        if (i0 + nwarps*WARP_SIZE > D && i >= D) {
             return;
         }
-        const float KQ_rowsum_j = __low2float(KQ_rowsum[j0/nwarps]) + __high2float(KQ_rowsum[j0/nwarps]);
-#pragma unroll
-        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D && i >= D) {
-                break;
-            }
-            dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j;
-        }
+        dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + i] = VKQ[i];
     }
+
+    if (threadIdx.y == 0 && threadIdx.x == 0) {
+        dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(
+            __low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0]));
+    }
+}
+
+template<int D, int parallel_blocks> // D == head size
+__launch_bounds__(D, 1)
+static __global__ void flash_attn_combine_results(
+        const float * __restrict__ VKQ_parts,
+        const half2 * __restrict__ VKQ_meta,
+        float       * __restrict__ dst) {
+
+    const int tid = threadIdx.x;
+    __builtin_assume(tid < D);
+
+    __shared__ half2 meta[parallel_blocks];
+    if (tid < parallel_blocks) {
+        meta[threadIdx.x] = VKQ_meta[blockIdx.y*parallel_blocks + tid];
+    }
+
+    __syncthreads();
+
+    half kqmax = __low2half(meta[0]);
+#pragma unroll
+    for (int l = 1; l < parallel_blocks; ++l) {
+        kqmax = __hmax(kqmax, __low2half(meta[l]));
+    }
+
+    float VKQ_numerator   = 0.0f;
+    float VKQ_denominator = 0.0f;
+#pragma unroll
+    for (int l = 0; l < parallel_blocks; ++l) {
+        float KQ_max_scale = hexp(__low2half(meta[l]) - kqmax);
+
+        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
+        VKQ_denominator += KQ_max_scale * __high2float(meta[l]);
+    }
+
+    dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
 }
 
 constexpr int get_max_power_of_2(int x) {
@@ -462,26 +550,26 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
 static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
 static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
 
-#define FATTN_SWITCH_CASE(D, ncols, nwarps)                                      \
-    case ncols: {                                                                \
-        constexpr int frag_m = (ncols) == 8 && (D) % 32 == 0 ? 32 : 16;          \
-        flash_attn_ext_f16<D, ncols, nwarps, get_VKQ_stride(D, nwarps, frag_m)>  \
-            <<<blocks_num, block_dim, shmem, main_stream>>> (                    \
-                (const char *) Q->data,                                          \
-                (const char *) K->data,                                          \
-                (const char *) V->data,                                          \
-                mask ? ((const char *) mask->data) : nullptr,                    \
-                (float *) KQV->data,                                             \
-                scale,                                                           \
-                Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],                          \
-                K->ne[0], K->ne[1], K->ne[2], K->ne[3],                          \
-                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,                  \
-                Q->nb[1], Q->nb[2], Q->nb[3],                                    \
-                K->nb[1], K->nb[2], K->nb[3],                                    \
-                KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]                   \
-        );                                                                       \
-    }                                                                            \
-        break;                                                                   \
+#define FATTN_SWITCH_CASE(D, ncols, nwarps)                                          \
+    case ncols: {                                                                    \
+        constexpr int frag_m = (ncols) == 8 && (D) % 32 == 0 ? 32 : 16;              \
+        flash_attn_ext_f16<D, ncols, nwarps, get_VKQ_stride(D, nwarps, frag_m), 1>   \
+            <<<blocks_num, block_dim, shmem, main_stream>>> (                        \
+                (const char *) Q->data,                                              \
+                (const char *) K->data,                                              \
+                (const char *) V->data,                                              \
+                mask ? ((const char *) mask->data) : nullptr,                        \
+                (float *) KQV->data, nullptr,                                        \
+                scale,                                                               \
+                Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],                              \
+                K->ne[0], K->ne[1], K->ne[2], K->ne[3],                              \
+                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,                      \
+                Q->nb[1], Q->nb[2], Q->nb[3],                                        \
+                K->nb[1], K->nb[2], K->nb[3],                                        \
+                KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]                       \
+        );                                                                           \
+    }                                                                                \
+        break;                                                                       \
 
 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
@@ -508,88 +596,39 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     float scale;
     memcpy(&scale, KQV->op_params, sizeof(float));
 
-    if (Q->ne[0] % WARP_SIZE == 0 && Q->ne[0] >= 128 && Q->ne[1] == 1) {
-        const int nwarps = Q->ne[0] / WARP_SIZE;
-        const dim3 blocks_num(Q->ne[1], Q->ne[2], Q->ne[3]);
+    if (Q->ne[1] == 1) {
+        constexpr int parallel_blocks = 4;
+
+        ggml_cuda_pool_alloc<float> dst_tmp(ctx.pool());
+        ggml_cuda_pool_alloc<half2> dst_tmp_meta(ctx.pool());
+
+        const int nwarps = (Q->ne[0] + WARP_SIZE - 1) / WARP_SIZE;
+        const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]);
         const dim3 block_dim(WARP_SIZE, nwarps, 1);
         const int shmem = 0;
+
+        // Performance of the vector kernel is very bad for head sizes 80 and 112, use the tensor core kernel instead:
+        constexpr int  nwarps_tc = 4;
+        constexpr dim3 block_dim_tc(WARP_SIZE, nwarps_tc, 1);
+
+        const dim3 blocks_num_combine(1, blocks_num.y, blocks_num.z);
+        const dim3 block_dim_combine(Q->ne[0], 1, 1);
+        const int  shmem_combine = 0;
+
+        if (parallel_blocks > 1) {
+            dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+            dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+        }
+
         switch (Q->ne[0]) {
-            // case 64:
-            //     flash_attn_vec_ext_f16<64>
-            //         <<<blocks_num, block_dim, shmem, main_stream>>> (
-            //                 (const char *) Q->data, // Query
-            //                 (const char *) K->data, // Key
-            //                 (const char *) V->data, // Value
-            //                 mask ? ((const char *) mask->data) : nullptr, // Mask
-            //                 (float *) KQV->data, // dst
-            //                 scale,
-            //                 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-            //                 K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-            //                 mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-            //                 Q->nb[1], Q->nb[2], Q->nb[3],
-            //                 K->nb[1], K->nb[2], K->nb[3],
-            //                 KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-            //         );
-            //     break;
-            // case 80:
-            //     flash_attn_vec_ext_f16<80>
-            //         <<<blocks_num, block_dim, shmem, main_stream>>> (
-            //                 (const char *) Q->data, // Query
-            //                 (const char *) K->data, // Key
-            //                 (const char *) V->data, // Value
-            //                 mask ? ((const char *) mask->data) : nullptr, // Mask
-            //                 (float *) KQV->data, // dst
-            //                 scale,
-            //                 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-            //                 K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-            //                 mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-            //                 Q->nb[1], Q->nb[2], Q->nb[3],
-            //                 K->nb[1], K->nb[2], K->nb[3],
-            //                 KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-            //         );
-            //     break;
-            // case 96:
-            //     flash_attn_vec_ext_f16<96>
-            //         <<<blocks_num, block_dim, shmem, main_stream>>> (
-            //                 (const char *) Q->data, // Query
-            //                 (const char *) K->data, // Key
-            //                 (const char *) V->data, // Value
-            //                 mask ? ((const char *) mask->data) : nullptr, // Mask
-            //                 (float *) KQV->data, // dst
-            //                 scale,
-            //                 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-            //                 K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-            //                 mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-            //                 Q->nb[1], Q->nb[2], Q->nb[3],
-            //                 K->nb[1], K->nb[2], K->nb[3],
-            //                 KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-            //         );
-            //     break;
-            // case 112:
-            //     flash_attn_vec_ext_f16<112>
-            //         <<<blocks_num, block_dim, shmem, main_stream>>> (
-            //                 (const char *) Q->data, // Query
-            //                 (const char *) K->data, // Key
-            //                 (const char *) V->data, // Value
-            //                 mask ? ((const char *) mask->data) : nullptr, // Mask
-            //                 (float *) KQV->data, // dst
-            //                 scale,
-            //                 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-            //                 K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-            //                 mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-            //                 Q->nb[1], Q->nb[2], Q->nb[3],
-            //                 K->nb[1], K->nb[2], K->nb[3],
-            //                 KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-            //         );
-            //     break;
-            case 128:
-                flash_attn_vec_ext_f16<128>
+            case 64:
+                flash_attn_vec_ext_f16<64, parallel_blocks>
                     <<<blocks_num, block_dim, shmem, main_stream>>> (
                         (const char *) Q->data, // Query
                         (const char *) K->data, // Key
                         (const char *) V->data, // Value
                         mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
+                        parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
                         scale,
                         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
@@ -598,15 +637,118 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
                         K->nb[1], K->nb[2], K->nb[3],
                         KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
                 );
+                if (parallel_blocks == 1) {
+                    break;
+                }
+                CUDA_CHECK(cudaGetLastError());
+                flash_attn_combine_results<64, parallel_blocks>
+                    <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+                    (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+                break;
+            case 80:
+                flash_attn_ext_f16<80, 16, nwarps_tc, get_VKQ_stride(80, nwarps_tc, 16), parallel_blocks>
+                    <<<blocks_num, block_dim_tc, shmem, main_stream>>> (
+                        (const char *) Q->data, // Query
+                        (const char *) K->data, // Key
+                        (const char *) V->data, // Value
+                        mask ? ((const char *) mask->data) : nullptr, // Mask
+                        parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+                        scale,
+                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+                        Q->nb[1], Q->nb[2], Q->nb[3],
+                        K->nb[1], K->nb[2], K->nb[3],
+                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                );
+                if (parallel_blocks == 1) {
+                    break;
+                }
+                CUDA_CHECK(cudaGetLastError());
+                flash_attn_combine_results<80, parallel_blocks>
+                    <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+                    (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+                break;
+            case 96:
+                flash_attn_vec_ext_f16<96, parallel_blocks>
+                    <<<blocks_num, block_dim, shmem, main_stream>>> (
+                        (const char *) Q->data, // Query
+                        (const char *) K->data, // Key
+                        (const char *) V->data, // Value
+                        mask ? ((const char *) mask->data) : nullptr, // Mask
+                        parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+                        scale,
+                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+                        Q->nb[1], Q->nb[2], Q->nb[3],
+                        K->nb[1], K->nb[2], K->nb[3],
+                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                );
+                if (parallel_blocks == 1) {
+                    break;
+                }
+                CUDA_CHECK(cudaGetLastError());
+                flash_attn_combine_results<96, parallel_blocks>
+                    <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+                    (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+                break;
+            case 112:
+                flash_attn_vec_ext_f16<112, parallel_blocks>
+                    <<<blocks_num, block_dim, shmem, main_stream>>> (
+                        (const char *) Q->data, // Query
+                        (const char *) K->data, // Key
+                        (const char *) V->data, // Value
+                        mask ? ((const char *) mask->data) : nullptr, // Mask
+                        parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+                        scale,
+                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+                        Q->nb[1], Q->nb[2], Q->nb[3],
+                        K->nb[1], K->nb[2], K->nb[3],
+                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                );
+                if (parallel_blocks == 1) {
+                    break;
+                }
+                CUDA_CHECK(cudaGetLastError());
+                flash_attn_combine_results<112, parallel_blocks>
+                    <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+                    (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+                break;
+            case 128:
+                flash_attn_vec_ext_f16<128, parallel_blocks>
+                    <<<blocks_num, block_dim, shmem, main_stream>>> (
+                        (const char *) Q->data, // Query
+                        (const char *) K->data, // Key
+                        (const char *) V->data, // Value
+                        mask ? ((const char *) mask->data) : nullptr, // Mask
+                        parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+                        scale,
+                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+                        Q->nb[1], Q->nb[2], Q->nb[3],
+                        K->nb[1], K->nb[2], K->nb[3],
+                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                );
+                if (parallel_blocks == 1) {
+                    break;
+                }
+                CUDA_CHECK(cudaGetLastError());
+                flash_attn_combine_results<128, parallel_blocks>
+                    <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+                    (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
                 break;
             case 256:
-                flash_attn_vec_ext_f16<256>
+                flash_attn_vec_ext_f16<256, parallel_blocks>
                     <<<blocks_num, block_dim, shmem, main_stream>>> (
                         (const char *) Q->data, // Query
                         (const char *) K->data, // Key
                         (const char *) V->data, // Value
                         mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
+                        parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
                        scale,
                         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
@@ -615,6 +757,13 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
                         K->nb[1], K->nb[2], K->nb[3],
                         KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
                 );
+                if (parallel_blocks == 1) {
+                    break;
+                }
+                CUDA_CHECK(cudaGetLastError());
+                flash_attn_combine_results<256, parallel_blocks>
+                    <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+                    (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
                 break;
             default:
                 GGML_ASSERT(false);
@@ -633,7 +782,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         cols_per_block = 8;
     }
     const int frag_m = cols_per_block == 8 ? 32 : 16;
-    const int nwarps = 4;
+    constexpr int nwarps = 4;
    const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
     const size_t shmem = 0;
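
Note on the combination step (not part of the patch itself): the split-K decomposition above works because softmax partials can be merged after the fact. Each of the parallel_blocks thread blocks processes an interleaved slice of the KV cache and writes an unnormalized output plus a (max, sum) pair to dst_meta; flash_attn_combine_results then rescales every partial by exp(m_l - m) against the global maximum m and divides once by the combined denominator. The standalone C++ sketch below demonstrates that this merge reproduces an ordinary softmax-weighted sum; all names in it are illustrative and it is independent of ggml (it uses contiguous slices for simplicity, whereas the kernels interleave, which does not change the math).

// Standalone sketch of the numerically stable merge performed by
// flash_attn_combine_results. Illustrative only; not ggml code.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // One query attending to 8 key positions, split across 2 "parallel blocks".
    const std::vector<float> kq = {0.5f, 2.0f, -1.0f, 3.0f, 1.5f, -0.5f, 2.5f, 0.0f};  // scaled KQ scores
    const std::vector<float> v  = {1.0f, -2.0f, 0.5f, 3.0f, -1.0f, 2.0f, 0.25f, 4.0f}; // V values (head size 1)

    // Reference: softmax(KQ) . V over all positions at once.
    float m_ref = -INFINITY;
    for (const float x : kq) { m_ref = fmaxf(m_ref, x); }
    float num_ref = 0.0f;
    float den_ref = 0.0f;
    for (size_t i = 0; i < kq.size(); ++i) {
        const float w = expf(kq[i] - m_ref);
        num_ref += w*v[i];
        den_ref += w;
    }

    // Split-K pass: per block, the local max m[l], the local softmax sum s[l],
    // and the unnormalized output o[l] -- the same triple the patched kernels
    // write to dst and dst_meta (dst_meta packs m and s into one half2).
    constexpr int parallel_blocks = 2;
    float m[parallel_blocks], s[parallel_blocks], o[parallel_blocks];
    for (int l = 0; l < parallel_blocks; ++l) {
        const size_t begin = (l + 0)*kq.size()/parallel_blocks;
        const size_t end   = (l + 1)*kq.size()/parallel_blocks;
        m[l] = -INFINITY;
        for (size_t i = begin; i < end; ++i) { m[l] = fmaxf(m[l], kq[i]); }
        s[l] = 0.0f;
        o[l] = 0.0f;
        for (size_t i = begin; i < end; ++i) {
            const float w = expf(kq[i] - m[l]);
            s[l] += w;
            o[l] += w*v[i];
        }
    }

    // Combine: rescale each partial by exp(m[l] - m_all) so all partials are
    // expressed relative to one shared maximum, then normalize exactly once.
    const float m_all = fmaxf(m[0], m[1]);
    float num = 0.0f;
    float den = 0.0f;
    for (int l = 0; l < parallel_blocks; ++l) {
        const float scale = expf(m[l] - m_all);
        num += scale*o[l];
        den += scale*s[l];
    }

    printf("reference: %.6f  combined: %.6f\n", num_ref/den_ref, num/den); // identical up to rounding
    return 0;
}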