CUDA: refactor host code, dyn. par. blocks

Johannes Gäßler 2024-04-09 11:39:16 +02:00
parent 5668c79ea0
commit 34f93bbb39
3 changed files with 258 additions and 311 deletions


@@ -36,18 +36,17 @@ static __global__ void flash_attn_vec_ext_f16(
const int ne1,
const int ne2,
const int ne3) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
#if FP16_AVAILABLE
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic = blockIdx.x / parallel_blocks; // Index of the Q/QKV column to work on.
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y);
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic);
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) mask;
if (parallel_blocks == 1) {
Q_f2 += blockIdx.x*nb01/sizeof(float2);
maskh += blockIdx.x*ne11;
}
const half * maskh = (const half *) mask + ne11*ic;
const int stride_KV = nb11 / sizeof(half);
const int stride_KV2 = nb11 / sizeof(half2);
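The split of blockIdx.x above drives the rest of the refactor; as a minimal sketch (assuming the grid is launched with parallel_blocks*ne01 blocks along x, as launch_fattn_vec_f16 below does), the mapping between the flat block index and the (column, parallel slot) pair is:

// Hypothetical helper, not part of the commit: invert the flat x index used by
// flash_attn_vec_ext_f16, assuming gridDim.x == parallel_blocks*ne01.
__host__ __device__ inline void fattn_decompose_block_x(int block_x, int parallel_blocks,
                                                        int * ic, int * ip) {
    *ic = block_x / parallel_blocks; // which Q/QKV column this block contributes to
    *ip = block_x % parallel_blocks; // which slice of the KV sequence it covers
}
// e.g. parallel_blocks == 4: block_x 0..3 -> column 0, slots 0..3; block_x 4..7 -> column 1, ...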
@@ -85,7 +84,7 @@ static __global__ void flash_attn_vec_ext_f16(
half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value.
const int k_start = parallel_blocks == 1 ? 0 : blockIdx.x*D;
const int k_start = parallel_blocks == 1 ? 0 : ip*D;
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
// Calculate KQ tile and keep track of new maximum KQ values:
half kqmax_new = kqmax;
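With the new k_start = ip*D the cooperating blocks tile the KV sequence without overlap; a worked example with hypothetical sizes (D = 64, parallel_blocks = 4, ne11 = 512):

// Each block strides by parallel_blocks*D == 256 starting at ip*D:
//   ip == 0: k_VKQ_0 = 0,   256 -> KV rows [0,64)    and [256,320)
//   ip == 1: k_VKQ_0 = 64,  320 -> KV rows [64,128)  and [320,384)
//   ip == 2: k_VKQ_0 = 128, 384 -> KV rows [128,192) and [384,448)
//   ip == 3: k_VKQ_0 = 192, 448 -> KV rows [192,256) and [448,512)
// The partial VKQ sums are later merged by flash_attn_combine_results.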
@@ -168,18 +167,19 @@ static __global__ void flash_attn_vec_ext_f16(
return;
}
half dst_val = (__low2half(VKQ) + __high2half(VKQ));
if (parallel_blocks == 1) {
dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)) / kqsum;
} else {
dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ));
if (tid == 0) {
dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(kqmax, kqsum);
}
dst_val /= kqsum;
}
dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = dst_val;
if (parallel_blocks == 1 || tid != 0) {
return;
}
dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_half2(kqmax, kqsum);
#else
NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
#endif // FP16_AVAILABLE
}
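To make the vector kernel's new output layout concrete, a hedged sketch of the indexing it uses (helper names are invented; it assumes gridDim.y equals the number of heads ne02 and blockIdx.x == ic*parallel_blocks + ip, as in the launch below):

// Illustrative index helpers matching the writes in flash_attn_vec_ext_f16 above.
static __device__ __forceinline__ int fattn_dst_idx(int ic, int ip, int head, int n_heads,
                                                    int parallel_blocks, int D, int tid) {
    // one unnormalized D-sized VKQ vector per (column, parallel slot, head)
    return D*n_heads*(ic*parallel_blocks + ip) + D*head + tid;
}
static __device__ __forceinline__ int fattn_dst_meta_idx(int ic, int ip, int head, int n_heads,
                                                         int parallel_blocks) {
    // one half2 {kqmax, kqsum} per (column, head, parallel slot)
    return ic*n_heads*parallel_blocks + head*parallel_blocks + ip;
}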
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks> // D == head size, VKQ_stride == num VKQ rows calculated in parallel
@@ -212,8 +212,12 @@ static __global__ void flash_attn_ext_f16(
const int ne1,
const int ne2,
const int ne3) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
#if FP16_MMA_AVAILABLE
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
constexpr int frag_m = ncols == 8 ? 32 : 16;
@@ -233,15 +237,10 @@ static __global__ void flash_attn_ext_f16(
constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float * Q_f = (const float *) (Q + nb02* blockIdx.y);
const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
const half2 * mask2 = (half2 *) mask;
if (parallel_blocks == 1) {
Q_f += blockIdx.x * ncols*nb01/sizeof(float);
mask2 += blockIdx.x * ncols*ne11/2;
}
const half2 * mask2 = (const half2 *) mask + ne11*(ic0/2);
const int stride_Q = nb01 / sizeof(float);
const int stride_KV = nb11 / sizeof(half);
@@ -283,11 +282,7 @@ static __global__ void flash_attn_ext_f16(
if (i0 + WARP_SIZE > D && i >= D) {
break;
}
if (parallel_blocks == 1) {
KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
} else {
KQ[j*D_padded + i] = j == 0 ? Q_f[j*stride_Q + i] * scale : 0.0f;
}
KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
}
}
@@ -305,8 +300,7 @@ static __global__ void flash_attn_ext_f16(
__syncthreads();
// Iterate over ne11 == previous tokens:
const int k_start = parallel_blocks == 1 ? 0 : blockIdx.x*FATTN_KQ_STRIDE;
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
// Calculate tile of KQ:
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
@@ -439,41 +433,39 @@ static __global__ void flash_attn_ext_f16(
__syncthreads();
}
if (parallel_blocks == 1) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (ncols*blockIdx.x + j >= ne01) {
return;
}
const float KQ_rowsum_j = __low2float(KQ_rowsum[j0/nwarps]) + __high2float(KQ_rowsum[j0/nwarps]);
#pragma unroll
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D && i >= D) {
break;
}
dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j;
}
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j_VKQ = j0 + threadIdx.y;
if (ic0 + j_VKQ >= ne01) {
return;
}
} else {
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
const half KQ_rowsum_j = __low2half(KQ_rowsum[j0/nwarps]) + __high2half(KQ_rowsum[j0/nwarps]);
#pragma unroll
for (int i0 = 0; i0 < D; i0 += nwarps*WARP_SIZE) {
const int i = i0 + threadIdx.y*WARP_SIZE + threadIdx.x;
if (i0 + nwarps*WARP_SIZE > D && i >= D) {
return;
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D && i >= D) {
break;
}
dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + i] = VKQ[i];
half dst_val = VKQ[j_VKQ*D_padded + i];
if (parallel_blocks == 1) {
dst_val /= KQ_rowsum_j;
}
dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
}
if (threadIdx.y == 0 && threadIdx.x == 0) {
dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(
__low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0]));
if (parallel_blocks == 1 || threadIdx.x != 0) {
continue;
}
half2 dst_meta_val = KQ_max[j0/nwarps];
reinterpret_cast<half&>(dst_meta_val.y) = KQ_rowsum_j;
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
}
#else
NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
#endif // FP16_MMA_AVAILABLE
}
template<int D, int parallel_blocks> // D == head size
@@ -482,7 +474,10 @@ static __global__ void flash_attn_combine_results(
const float * __restrict__ VKQ_parts,
const half2 * __restrict__ VKQ_meta,
float * __restrict__ dst) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
#if FP16_AVAILABLE
VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x;
dst += D * gridDim.y*blockIdx.x;
const int tid = threadIdx.x;
__builtin_assume(tid < D);
@@ -513,7 +508,7 @@ static __global__ void flash_attn_combine_results(
dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
#else
NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
#endif // FP16_AVAILABLE
}
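Only fragments of flash_attn_combine_results are visible in this diff; as a hedged host-side reference for what the {kqmax, kqsum} metadata enables (float is used instead of half2 to keep the sketch self-contained, and the names are invented), the merge of the partial results is a softmax rescaling of this form:

#include <math.h>

// Illustrative only: combine parallel_blocks partial results for one output element.
// parts[p] is the unnormalized VKQ contribution of slice p; meta_max[p]/meta_sum[p]
// correspond to the kqmax/kqsum halves written by the kernels above.
static float fattn_combine_one(const float * parts, const float * meta_max,
                               const float * meta_sum, int parallel_blocks) {
    float kqmax = -INFINITY;
    for (int p = 0; p < parallel_blocks; ++p) {
        kqmax = fmaxf(kqmax, meta_max[p]);         // global softmax maximum across slices
    }
    float numerator   = 0.0f;
    float denominator = 0.0f;
    for (int p = 0; p < parallel_blocks; ++p) {
        const float w = expf(meta_max[p] - kqmax); // rescale each slice to the global max
        numerator   += w*parts[p];
        denominator += w*meta_sum[p];
    }
    return numerator / denominator;
}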
constexpr int get_max_power_of_2(int x) {
@@ -540,26 +535,124 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
#define FATTN_SWITCH_CASE(D, ncols, nwarps) \
case ncols: { \
constexpr int frag_m = (ncols) == 8 && (D) % 32 == 0 ? 32 : 16; \
flash_attn_ext_f16<D, ncols, nwarps, get_VKQ_stride(D, nwarps, frag_m), 1> \
<<<blocks_num, block_dim, shmem, main_stream>>> ( \
(const char *) Q->data, \
(const char *) K->data, \
(const char *) V->data, \
mask ? ((const char *) mask->data) : nullptr, \
(float *) KQV->data, nullptr, \
scale, \
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], \
K->ne[0], K->ne[1], K->ne[2], K->ne[3], \
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, \
Q->nb[1], Q->nb[2], Q->nb[3], \
K->nb[1], K->nb[2], K->nb[3], \
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] \
); \
} \
break; \
template <int D, int parallel_blocks> void launch_fattn_vec_f16(
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
ggml_cuda_pool & pool, cudaStream_t main_stream
) {
ggml_cuda_pool_alloc<float> dst_tmp(pool);
ggml_cuda_pool_alloc<half2> dst_tmp_meta(pool);
if (parallel_blocks > 1) {
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
}
constexpr int nwarps = ((D) + WARP_SIZE - 1) / WARP_SIZE;
constexpr dim3 block_dim(WARP_SIZE, nwarps, 1);
const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]);
const int shmem = 0;
float scale;
memcpy(&scale, KQV->op_params, sizeof(float));
flash_attn_vec_ext_f16<D, parallel_blocks>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data,
(const char *) K->data,
(const char *) V->data,
mask ? ((const char *) mask->data) : nullptr,
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
CUDA_CHECK(cudaGetLastError());
if ((parallel_blocks) == 1) {
return;
}
constexpr dim3 block_dim_combine(D, 1, 1);
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
const int shmem_combine = 0;
flash_attn_combine_results<D, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
CUDA_CHECK(cudaGetLastError());
}
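A minimal usage sketch for the new helper, mirroring the head-size-128 case in the dispatch further down (ctx is assumed to be the ggml_backend_cuda_context passed to ggml_cuda_flash_attn_ext):

// Single-column decode step, head size 128, 4 cooperating blocks per column:
launch_fattn_vec_f16<128, 4>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());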
template <int D, int cols_per_block, int nwarps, int parallel_blocks> void launch_fattn_f16_impl(
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
ggml_cuda_pool & pool, cudaStream_t main_stream
) {
ggml_cuda_pool_alloc<float> dst_tmp(pool);
ggml_cuda_pool_alloc<half2> dst_tmp_meta(pool);
if (parallel_blocks > 1) {
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
}
constexpr int frag_m = (cols_per_block) == 8 && (D) % 32 == 0 ? 32 : 16;
constexpr dim3 block_dim(WARP_SIZE, nwarps, 1);
const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
const int shmem = 0;
float scale;
memcpy(&scale, KQV->op_params, sizeof(float));
flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data,
(const char *) K->data,
(const char *) V->data,
mask ? ((const char *) mask->data) : nullptr,
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
CUDA_CHECK(cudaGetLastError());
if ((parallel_blocks) == 1) {
return;
}
constexpr dim3 block_dim_combine(D, 1, 1);
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
const int shmem_combine = 0;
flash_attn_combine_results<D, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
CUDA_CHECK(cudaGetLastError());
}
template <int D, int cols_per_block, int nwarps> void launch_fattn_f16(
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
const int nsm, ggml_cuda_pool & pool, cudaStream_t main_stream
) {
const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
if (4*blocks_num_pb1 < 2*nsm) {
launch_fattn_f16_impl<D, cols_per_block, nwarps, 4>(Q, K, V, KQV, mask, pool, main_stream);
return;
}
if (2*blocks_num_pb1 < 2*nsm) {
launch_fattn_f16_impl<D, cols_per_block, nwarps, 2>(Q, K, V, KQV, mask, pool, main_stream);
return;
}
launch_fattn_f16_impl<D, cols_per_block, nwarps, 1>(Q, K, V, KQV, mask, pool, main_stream);
}
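The branch above picks the largest parallel_blocks in {4, 2, 1} whose unsplit grid would otherwise be too small to fill the SMs; a worked example with hypothetical numbers:

// Q->ne[1] = 16, cols_per_block = 16, Q->ne[2] = 32 heads, Q->ne[3] = 1, nsm = 80:
//   blocks_num_pb1 = ceil(16/16)*32*1 = 32
//   4*32 = 128 < 2*80 = 160  -> parallel_blocks = 4 (split each column group four ways)
// With Q->ne[1] = 512 instead: blocks_num_pb1 = 1024, both tests fail -> parallel_blocks = 1.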
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
@@ -583,259 +676,106 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
ggml_cuda_set_device(ctx.device);
const cudaStream_t main_stream = ctx.stream();
float scale;
memcpy(&scale, KQV->op_params, sizeof(float));
if (Q->ne[1] == 1) {
if (Q->ne[1] == 1 && Q->ne[0] % WARP_SIZE == 0) {
constexpr int parallel_blocks = 4;
ggml_cuda_pool_alloc<float> dst_tmp(ctx.pool());
ggml_cuda_pool_alloc<half2> dst_tmp_meta(ctx.pool());
const int nwarps = (Q->ne[0] + WARP_SIZE - 1) / WARP_SIZE;
const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]);
const dim3 block_dim(WARP_SIZE, nwarps, 1);
const int shmem = 0;
// Performance of the vector kernel is very bad for head sizes 80 and 112, use the tensor core kernel instead:
constexpr int nwarps_tc = 4;
constexpr dim3 block_dim_tc(WARP_SIZE, nwarps_tc, 1);
const dim3 blocks_num_combine(1, blocks_num.y, blocks_num.z);
const dim3 block_dim_combine(Q->ne[0], 1, 1);
const int shmem_combine = 0;
if (parallel_blocks > 1) {
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
}
switch (Q->ne[0]) {
case 64:
flash_attn_vec_ext_f16<64, parallel_blocks>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data, // Query
(const char *) K->data, // Key
(const char *) V->data, // Value
mask ? ((const char *) mask->data) : nullptr, // Mask
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
if (parallel_blocks == 1) {
break;
}
CUDA_CHECK(cudaGetLastError());
flash_attn_combine_results<64, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
break;
case 80:
flash_attn_ext_f16<80, 16, nwarps_tc, get_VKQ_stride(80, nwarps_tc, 16), parallel_blocks>
<<<blocks_num, block_dim_tc, shmem, main_stream>>> (
(const char *) Q->data, // Query
(const char *) K->data, // Key
(const char *) V->data, // Value
mask ? ((const char *) mask->data) : nullptr, // Mask
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
if (parallel_blocks == 1) {
break;
}
CUDA_CHECK(cudaGetLastError());
flash_attn_combine_results<80, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
launch_fattn_vec_f16< 64, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
case 96:
flash_attn_vec_ext_f16<96, parallel_blocks>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data, // Query
(const char *) K->data, // Key
(const char *) V->data, // Value
mask ? ((const char *) mask->data) : nullptr, // Mask
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
if (parallel_blocks == 1) {
break;
}
CUDA_CHECK(cudaGetLastError());
flash_attn_combine_results<96, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
break;
case 112:
flash_attn_vec_ext_f16<112, parallel_blocks>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data, // Query
(const char *) K->data, // Key
(const char *) V->data, // Value
mask ? ((const char *) mask->data) : nullptr, // Mask
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
if (parallel_blocks == 1) {
break;
}
CUDA_CHECK(cudaGetLastError());
flash_attn_combine_results<112, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
launch_fattn_vec_f16< 96, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
case 128:
flash_attn_vec_ext_f16<128, parallel_blocks>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data, // Query
(const char *) K->data, // Key
(const char *) V->data, // Value
mask ? ((const char *) mask->data) : nullptr, // Mask
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
if (parallel_blocks == 1) {
break;
}
CUDA_CHECK(cudaGetLastError());
flash_attn_combine_results<128, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
launch_fattn_vec_f16<128, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
case 256:
flash_attn_vec_ext_f16<256, parallel_blocks>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data, // Query
(const char *) K->data, // Key
(const char *) V->data, // Value
mask ? ((const char *) mask->data) : nullptr, // Mask
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
if (parallel_blocks == 1) {
break;
}
CUDA_CHECK(cudaGetLastError());
flash_attn_combine_results<256, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
launch_fattn_vec_f16<256, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
default:
GGML_ASSERT(false);
break;
}
CUDA_CHECK(cudaGetLastError());
return;
}
int cols_per_block;
if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) {
cols_per_block = 32;
} else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) {
cols_per_block = 16;
} else {
cols_per_block = 8;
}
constexpr int nwarps = 4;
const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
const dim3 block_dim(WARP_SIZE, nwarps, 1);
const size_t shmem = 0;
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
constexpr int cols_per_block = 8;
constexpr int nwarps = 4;
switch (Q->ne[0]) {
case 64:
launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
break;
case 96:
launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
break;
case 128:
launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
break;
case 256:
launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
break;
default:
GGML_ASSERT(false);
break;
}
return;
}
if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 16;
constexpr int nwarps = 4;
switch (Q->ne[0]) {
case 64:
launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
break;
case 80:
launch_fattn_f16< 80, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
break;
case 96:
launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
break;
case 112:
launch_fattn_f16<112, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
break;
case 128:
launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
break;
case 256:
launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
break;
default:
GGML_ASSERT(false);
break;
}
return;
}
constexpr int cols_per_block = 32;
constexpr int nwarps = 4;
switch (Q->ne[0]) {
case 64: switch (cols_per_block) {
FATTN_SWITCH_CASE(64, 8, nwarps);
FATTN_SWITCH_CASE(64, 16, nwarps);
FATTN_SWITCH_CASE(64, 32, nwarps);
default:
fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
GGML_ASSERT(false);
break;
} break;
case 80: switch (cols_per_block) {
// FATTN_SWITCH_CASE(80, 8, nwarps);
FATTN_SWITCH_CASE(80, 16, nwarps);
FATTN_SWITCH_CASE(80, 32, nwarps);
default:
fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
GGML_ASSERT(false);
break;
} break;
case 96: switch (cols_per_block) {
FATTN_SWITCH_CASE(96, 8, nwarps);
FATTN_SWITCH_CASE(96, 16, nwarps);
FATTN_SWITCH_CASE(96, 32, nwarps);
default:
fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
GGML_ASSERT(false);
break;
} break;
case 112: switch (cols_per_block) {
// FATTN_SWITCH_CASE(112, 8, nwarps);
FATTN_SWITCH_CASE(112, 16, nwarps);
FATTN_SWITCH_CASE(112, 32, nwarps);
default:
fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
GGML_ASSERT(false);
break;
} break;
case 128: switch (cols_per_block) {
FATTN_SWITCH_CASE(128, 8, nwarps);
FATTN_SWITCH_CASE(128, 16, nwarps);
FATTN_SWITCH_CASE(128, 32, nwarps);
default:
fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
GGML_ASSERT(false);
break;
} break;
case 256: switch (cols_per_block) {
FATTN_SWITCH_CASE(256, 8, nwarps);
FATTN_SWITCH_CASE(256, 16, nwarps);
FATTN_SWITCH_CASE(256, 32, nwarps);
default:
fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
GGML_ASSERT(false);
break;
} break;
case 64:
launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
break;
case 80:
launch_fattn_f16< 80, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
break;
case 96:
launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
break;
case 112:
launch_fattn_f16<112, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
break;
case 128:
launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
break;
case 256:
launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
break;
default:
GGML_ASSERT(false);
break;
}
CUDA_CHECK(cudaGetLastError());
return;
}
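Summarizing the tensor-core dispatch above as a hedged sketch (the helper name is invented; the Q->ne[1] == 1 vector-kernel path is handled separately before this choice):

// Illustrative only: the cols_per_block choice made by ggml_cuda_flash_attn_ext
// for the tensor-core kernel, given head size ne0 and number of Q columns ne1.
static int fattn_cols_per_block(int ne0, int ne1) {
    if (ne1 <= 8 && ne0 % 32 == 0) { // WARP_SIZE == 32
        return 8;
    }
    if (ne1 <= 32) {
        return 16;
    }
    return 32;
}
// launch_fattn_f16 then picks parallel_blocks in {4, 2, 1} from the occupancy heuristic.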