diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 11adbabd6..2cf6c8d98 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -141,6 +141,7 @@ static ggml_cuda_device_info ggml_cuda_init() { info.devices[id].cc = 100*prop.major + 10*prop.minor; #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) info.devices[id].smpb = prop.sharedMemPerBlock; + info.devices[id].nsm = prop.multiProcessorCount; } for (int id = 0; id < info.device_count; ++id) { diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index b0149b7be..989780dbc 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -390,6 +390,11 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { } #endif // defined(GGML_USE_HIPBLAS) +#define FP16_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \ + defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL +#define FP16_MMA_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \ + defined(RDNA3) : __CUDA_ARCH__ >= CC_VOLTA + // TODO: move to ggml-common.h static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; @@ -403,6 +408,7 @@ struct ggml_cuda_device_info { struct cuda_device_info { int cc; // compute capability + int nsm; // number of streaming multiprocessors size_t smpb; // max. shared memory per block bool vmm; // virtual memory support size_t vmm_granularity; // granularity of virtual memory diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 91ef5551e..5f1345a7f 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -36,18 +36,17 @@ static __global__ void flash_attn_vec_ext_f16( const int ne1, const int ne2, const int ne3) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#if FP16_AVAILABLE //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic = blockIdx.x / parallel_blocks; // Index of the Q/QKV column to work on. + const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y); + const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic); const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask; - - if (parallel_blocks == 1) { - Q_f2 += blockIdx.x*nb01/sizeof(float2); - maskh += blockIdx.x*ne11; - } + const half * maskh = (const half *) mask + ne11*ic; const int stride_KV = nb11 / sizeof(half); const int stride_KV2 = nb11 / sizeof(half2); @@ -85,7 +84,7 @@ static __global__ void flash_attn_vec_ext_f16( half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value. - const int k_start = parallel_blocks == 1 ? 0 : blockIdx.x*D; + const int k_start = parallel_blocks == 1 ? 0 : ip*D; for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { // Calculate KQ tile and keep track of new maximum KQ values: half kqmax_new = kqmax; @@ -168,18 +167,19 @@ static __global__ void flash_attn_vec_ext_f16( return; } + half dst_val = (__low2half(VKQ) + __high2half(VKQ)); if (parallel_blocks == 1) { - dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)) / kqsum; - } else { - dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)); - - if (tid == 0) { - dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(kqmax, kqsum); - } + dst_val /= kqsum; } + dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = dst_val; + + if (parallel_blocks == 1 || tid != 0) { + return; + } + dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_half2(kqmax, kqsum); #else NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#endif // FP16_AVAILABLE } template // D == head size, VKQ_stride == num VKQ rows calculated in parallel @@ -212,8 +212,12 @@ static __global__ void flash_attn_ext_f16( const int ne1, const int ne2, const int ne3) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA +#if FP16_MMA_AVAILABLE //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. + const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); constexpr int frag_m = ncols == 8 ? 32 : 16; @@ -233,15 +237,10 @@ static __global__ void flash_attn_ext_f16( constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float * Q_f = (const float *) (Q + nb02* blockIdx.y); + const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0); const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half2 * mask2 = (half2 *) mask; - - if (parallel_blocks == 1) { - Q_f += blockIdx.x * ncols*nb01/sizeof(float); - mask2 += blockIdx.x * ncols*ne11/2; - } + const half2 * mask2 = (const half2 *) mask + ne11*(ic0/2); const int stride_Q = nb01 / sizeof(float); const int stride_KV = nb11 / sizeof(half); @@ -283,11 +282,7 @@ static __global__ void flash_attn_ext_f16( if (i0 + WARP_SIZE > D && i >= D) { break; } - if (parallel_blocks == 1) { - KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; - } else { - KQ[j*D_padded + i] = j == 0 ? Q_f[j*stride_Q + i] * scale : 0.0f; - } + KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; } } @@ -305,8 +300,7 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); // Iterate over ne11 == previous tokens: - const int k_start = parallel_blocks == 1 ? 0 : blockIdx.x*FATTN_KQ_STRIDE; - for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { + for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { // Calculate tile of KQ: #pragma unroll for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { @@ -439,41 +433,39 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); } - if (parallel_blocks == 1) { #pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - if (ncols*blockIdx.x + j >= ne01) { - return; - } - const float KQ_rowsum_j = __low2float(KQ_rowsum[j0/nwarps]) + __high2float(KQ_rowsum[j0/nwarps]); -#pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { - break; - } - dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j; - } + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j_VKQ = j0 + threadIdx.y; + if (ic0 + j_VKQ >= ne01) { + return; } - } else { + const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; + + const half KQ_rowsum_j = __low2half(KQ_rowsum[j0/nwarps]) + __high2half(KQ_rowsum[j0/nwarps]); #pragma unroll - for (int i0 = 0; i0 < D; i0 += nwarps*WARP_SIZE) { - const int i = i0 + threadIdx.y*WARP_SIZE + threadIdx.x; - if (i0 + nwarps*WARP_SIZE > D && i >= D) { - return; + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; } - dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + i] = VKQ[i]; + half dst_val = VKQ[j_VKQ*D_padded + i]; + if (parallel_blocks == 1) { + dst_val /= KQ_rowsum_j; + } + dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val; } - if (threadIdx.y == 0 && threadIdx.x == 0) { - dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2( - __low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0])); + if (parallel_blocks == 1 || threadIdx.x != 0) { + continue; } + + half2 dst_meta_val = KQ_max[j0/nwarps]; + reinterpret_cast(dst_meta_val.y) = KQ_rowsum_j; + dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val; } #else NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA +#endif // FP16_MMA_AVAILABLE } template // D == head size @@ -482,7 +474,10 @@ static __global__ void flash_attn_combine_results( const float * __restrict__ VKQ_parts, const half2 * __restrict__ VKQ_meta, float * __restrict__ dst) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#if FP16_AVAILABLE + VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x; + VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x; + dst += D * gridDim.y*blockIdx.x; const int tid = threadIdx.x; __builtin_assume(tid < D); @@ -513,7 +508,7 @@ static __global__ void flash_attn_combine_results( dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; #else NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#endif // FP16_AVAILABLE } constexpr int get_max_power_of_2(int x) { @@ -540,26 +535,124 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); -#define FATTN_SWITCH_CASE(D, ncols, nwarps) \ - case ncols: { \ - constexpr int frag_m = (ncols) == 8 && (D) % 32 == 0 ? 32 : 16; \ - flash_attn_ext_f16 \ - <<>> ( \ - (const char *) Q->data, \ - (const char *) K->data, \ - (const char *) V->data, \ - mask ? ((const char *) mask->data) : nullptr, \ - (float *) KQV->data, nullptr, \ - scale, \ - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], \ - K->ne[0], K->ne[1], K->ne[2], K->ne[3], \ - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, \ - Q->nb[1], Q->nb[2], Q->nb[3], \ - K->nb[1], K->nb[2], K->nb[3], \ - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] \ - ); \ - } \ - break; \ +template void launch_fattn_vec_f16( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, + ggml_cuda_pool & pool, cudaStream_t main_stream +) { + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + + constexpr int nwarps = ((D) + WARP_SIZE - 1) / WARP_SIZE; + constexpr dim3 block_dim(WARP_SIZE, nwarps, 1); + const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]); + const int shmem = 0; + + float scale; + memcpy(&scale, KQV->op_params, sizeof(float)); + + flash_attn_vec_ext_f16 + <<>> ( + (const char *) Q->data, + (const char *) K->data, + (const char *) V->data, + mask ? ((const char *) mask->data) : nullptr, + (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + CUDA_CHECK(cudaGetLastError()); + + if ((parallel_blocks) == 1) { + return; + } + + constexpr dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); + const int shmem_combine = 0; + + flash_attn_combine_results + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + CUDA_CHECK(cudaGetLastError()); +} + +template void launch_fattn_f16_impl( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, + ggml_cuda_pool & pool, cudaStream_t main_stream +) { + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + + constexpr int frag_m = (cols_per_block) == 8 && (D) % 32 == 0 ? 32 : 16; + constexpr dim3 block_dim(WARP_SIZE, nwarps, 1); + const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); + const int shmem = 0; + + float scale; + memcpy(&scale, KQV->op_params, sizeof(float)); + + flash_attn_ext_f16 + <<>> ( + (const char *) Q->data, + (const char *) K->data, + (const char *) V->data, + mask ? ((const char *) mask->data) : nullptr, + (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + CUDA_CHECK(cudaGetLastError()); + + if ((parallel_blocks) == 1) { + return; + } + + constexpr dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); + const int shmem_combine = 0; + + flash_attn_combine_results + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + CUDA_CHECK(cudaGetLastError()); +} + +template void launch_fattn_f16( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, + const int nsm, ggml_cuda_pool & pool, cudaStream_t main_stream +) { + const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; + + if (4*blocks_num_pb1 < 2*nsm) { + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + return; + } + if (2*blocks_num_pb1 < 2*nsm) { + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + return; + } + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); +} void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; @@ -583,259 +676,106 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst ggml_cuda_set_device(ctx.device); - const cudaStream_t main_stream = ctx.stream(); - - float scale; - memcpy(&scale, KQV->op_params, sizeof(float)); - - if (Q->ne[1] == 1) { + if (Q->ne[1] == 1 && Q->ne[0] % WARP_SIZE == 0) { constexpr int parallel_blocks = 4; - - ggml_cuda_pool_alloc dst_tmp(ctx.pool()); - ggml_cuda_pool_alloc dst_tmp_meta(ctx.pool()); - - const int nwarps = (Q->ne[0] + WARP_SIZE - 1) / WARP_SIZE; - const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]); - const dim3 block_dim(WARP_SIZE, nwarps, 1); - const int shmem = 0; - - // Performance of the vector kernel is very bad for head sizes 80 and 112, use the tensor core kernel instead: - constexpr int nwarps_tc = 4; - constexpr dim3 block_dim_tc(WARP_SIZE, nwarps_tc, 1); - - const dim3 blocks_num_combine(1, blocks_num.y, blocks_num.z); - const dim3 block_dim_combine(Q->ne[0], 1, 1); - const int shmem_combine = 0; - - if (parallel_blocks > 1) { - dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); - dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); - } - switch (Q->ne[0]) { case 64: - flash_attn_vec_ext_f16<64, parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<64, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); - break; - case 80: - flash_attn_ext_f16<80, 16, nwarps_tc, get_VKQ_stride(80, nwarps_tc, 16), parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<80, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + launch_fattn_vec_f16< 64, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; case 96: - flash_attn_vec_ext_f16<96, parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<96, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); - break; - case 112: - flash_attn_vec_ext_f16<112, parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<112, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + launch_fattn_vec_f16< 96, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; case 128: - flash_attn_vec_ext_f16<128, parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<128, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + launch_fattn_vec_f16<128, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; case 256: - flash_attn_vec_ext_f16<256, parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<256, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + launch_fattn_vec_f16<256, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; default: GGML_ASSERT(false); break; } - CUDA_CHECK(cudaGetLastError()); return; } - int cols_per_block; - if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) { - cols_per_block = 32; - } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { - cols_per_block = 16; - } else { - cols_per_block = 8; - } - constexpr int nwarps = 4; - const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); - const dim3 block_dim(WARP_SIZE, nwarps, 1); - const size_t shmem = 0; + const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; + if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { + constexpr int cols_per_block = 8; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + return; + } + + if (Q->ne[1] <= 32) { + constexpr int cols_per_block = 16; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + return; + } + + constexpr int cols_per_block = 32; + constexpr int nwarps = 4; switch (Q->ne[0]) { - case 64: switch (cols_per_block) { - FATTN_SWITCH_CASE(64, 8, nwarps); - FATTN_SWITCH_CASE(64, 16, nwarps); - FATTN_SWITCH_CASE(64, 32, nwarps); - default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); - GGML_ASSERT(false); - break; - } break; - case 80: switch (cols_per_block) { - // FATTN_SWITCH_CASE(80, 8, nwarps); - FATTN_SWITCH_CASE(80, 16, nwarps); - FATTN_SWITCH_CASE(80, 32, nwarps); - default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); - GGML_ASSERT(false); - break; - } break; - case 96: switch (cols_per_block) { - FATTN_SWITCH_CASE(96, 8, nwarps); - FATTN_SWITCH_CASE(96, 16, nwarps); - FATTN_SWITCH_CASE(96, 32, nwarps); - default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); - GGML_ASSERT(false); - break; - } break; - case 112: switch (cols_per_block) { - // FATTN_SWITCH_CASE(112, 8, nwarps); - FATTN_SWITCH_CASE(112, 16, nwarps); - FATTN_SWITCH_CASE(112, 32, nwarps); - default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); - GGML_ASSERT(false); - break; - } break; - case 128: switch (cols_per_block) { - FATTN_SWITCH_CASE(128, 8, nwarps); - FATTN_SWITCH_CASE(128, 16, nwarps); - FATTN_SWITCH_CASE(128, 32, nwarps); - default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); - GGML_ASSERT(false); - break; - } break; - case 256: switch (cols_per_block) { - FATTN_SWITCH_CASE(256, 8, nwarps); - FATTN_SWITCH_CASE(256, 16, nwarps); - FATTN_SWITCH_CASE(256, 32, nwarps); - default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); - GGML_ASSERT(false); - break; - } break; + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; default: GGML_ASSERT(false); break; } - CUDA_CHECK(cudaGetLastError()); + return; }