Multiple parallel blocks for batch size 1

This commit is contained in:
Johannes Gäßler 2024-04-01 16:41:56 +02:00 committed by Georgi Gerganov
parent 68d793bee8
commit 3f777acf06

View file

@ -29,14 +29,17 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
} }
template<int D> // D == head size #define FATTN_KQ_STRIDE 256
__launch_bounds__(D, 1)
template<int D, int parallel_blocks> // D == head size
__launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1)
static __global__ void flash_attn_vec_ext_f16( static __global__ void flash_attn_vec_ext_f16(
const char * __restrict__ Q, const char * __restrict__ Q,
const char * __restrict__ K, const char * __restrict__ K,
const char * __restrict__ V, const char * __restrict__ V,
const char * __restrict__ mask, const char * __restrict__ mask,
float * __restrict__ dst, float * __restrict__ dst,
half2 * __restrict__ dst_meta,
const float scale, const float scale,
const int ne00, const int ne00,
const int ne01, const int ne01,
@ -60,20 +63,25 @@ static __global__ void flash_attn_vec_ext_f16(
const int ne3) { const int ne3) {
//In this kernel Q, K, V are matrices while i, j, k are matrix indices. //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*blockIdx.x); const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y);
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) mask + ne31*blockIdx.x; const half * maskh = (const half *) mask;
if (parallel_blocks == 1) {
Q_f2 += blockIdx.x*nb01/sizeof(float2);
maskh += blockIdx.x*ne11;
}
const int stride_KV = nb11 / sizeof(half); const int stride_KV = nb11 / sizeof(half);
const int stride_KV2 = nb11 / sizeof(half2); const int stride_KV2 = nb11 / sizeof(half2);
constexpr int nwarps = D/WARP_SIZE; constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
__builtin_assume(tid < D); __builtin_assume(tid < nwarps*WARP_SIZE);
__shared__ half KQ[D]; __shared__ half KQ[nwarps*WARP_SIZE];
KQ[tid] = 0.0f; KQ[tid] = -INFINITY;
half2 * KQ2 = (half2 *) KQ; half2 * KQ2 = (half2 *) KQ;
half kqmax = -INFINITY; half kqmax = -INFINITY;
@ -85,7 +93,6 @@ static __global__ void flash_attn_vec_ext_f16(
kqmax_shared[threadIdx.x] = -INFINITY; kqmax_shared[threadIdx.x] = -INFINITY;
kqsum_shared[threadIdx.x] = 0.0f; kqsum_shared[threadIdx.x] = 0.0f;
} }
__syncthreads(); __syncthreads();
// Convert Q to half2 and store in registers: // Convert Q to half2 and store in registers:
@ -102,14 +109,15 @@ static __global__ void flash_attn_vec_ext_f16(
half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value. half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value.
for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += D) { const int k_start = parallel_blocks == 1 ? 0 : blockIdx.x*D;
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
// Calculate KQ tile and keep track of new maximum KQ values: // Calculate KQ tile and keep track of new maximum KQ values:
half kqmax_new = kqmax; half kqmax_new = kqmax;
#pragma unroll #pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
const int i_KQ = i_KQ_0 + threadIdx.y; const int i_KQ = i_KQ_0 + threadIdx.y;
if (256 % D != 0 && k_VKQ_0 + i_KQ >= ne11) { if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
break; break;
} }
@ -153,19 +161,25 @@ static __global__ void flash_attn_vec_ext_f16(
__syncthreads(); __syncthreads();
if (tid < D) {
#pragma unroll #pragma unroll
for (int k0 = 0; k0 < D; k0 += 2) { for (int k0 = 0; k0 < D; k0 += 2) {
if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) { if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) {
break; break;
} }
half2 V_k; half2 V_k;
reinterpret_cast<half&>(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid]; reinterpret_cast<half&>(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid];
reinterpret_cast<half&>(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid]; reinterpret_cast<half&>(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid];
VKQ += V_k*KQ2[k0/2]; VKQ += V_k*KQ2[k0/2];
}
} }
} }
if (tid >= D) {
kqsum = 0.0f;
}
kqsum = warp_reduce_sum(kqsum); kqsum = warp_reduce_sum(kqsum);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
kqsum_shared[threadIdx.y] = kqsum; kqsum_shared[threadIdx.y] = kqsum;
@ -174,12 +188,22 @@ static __global__ void flash_attn_vec_ext_f16(
kqsum = kqsum_shared[threadIdx.x]; kqsum = kqsum_shared[threadIdx.x];
kqsum = warp_reduce_sum(kqsum); kqsum = warp_reduce_sum(kqsum);
dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)) / kqsum; if (tid >= D) {
return;
}
if (parallel_blocks == 1) {
dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)) / kqsum;
} else {
dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ));
if (tid == 0) {
dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(kqmax, kqsum);
}
}
} }
#define FATTN_KQ_STRIDE 256 template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks> // D == head size, VKQ_stride == num VKQ rows calculated in parallel
template<int D, int ncols, int nwarps, int VKQ_stride> // D == head size, VKQ_stride == num VKQ rows calculated in parallel
__launch_bounds__(nwarps*WARP_SIZE, 1) __launch_bounds__(nwarps*WARP_SIZE, 1)
static __global__ void flash_attn_ext_f16( static __global__ void flash_attn_ext_f16(
const char * __restrict__ Q, const char * __restrict__ Q,
@ -187,6 +211,7 @@ static __global__ void flash_attn_ext_f16(
const char * __restrict__ V, const char * __restrict__ V,
const char * __restrict__ mask, const char * __restrict__ mask,
float * __restrict__ dst, float * __restrict__ dst,
half2 * __restrict__ dst_meta,
const float scale, const float scale,
const int ne00, const int ne00,
const int ne01, const int ne01,
@ -228,10 +253,15 @@ static __global__ void flash_attn_ext_f16(
constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float * Q_f = (const float *) (Q + nb02* blockIdx.y + ncols*nb01*blockIdx.x); const float * Q_f = (const float *) (Q + nb02* blockIdx.y);
const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
const half2 * mask2 = (half2 *) mask + ncols*ne11*blockIdx.x/2; const half2 * mask2 = (half2 *) mask;
if (parallel_blocks == 1) {
Q_f += blockIdx.x * ncols*nb01/sizeof(float);
mask2 += blockIdx.x * ncols*ne11/2;
}
const int stride_Q = nb01 / sizeof(float); const int stride_Q = nb01 / sizeof(float);
const int stride_KV = nb11 / sizeof(half); const int stride_KV = nb11 / sizeof(half);
@ -273,7 +303,11 @@ static __global__ void flash_attn_ext_f16(
if (i0 + WARP_SIZE > D && i >= D) { if (i0 + WARP_SIZE > D && i >= D) {
break; break;
} }
KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; if (parallel_blocks == 1) {
KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
} else {
KQ[j*D_padded + i] = j == 0 ? Q_f[j*stride_Q + i] * scale : 0.0f;
}
} }
} }
@ -291,7 +325,8 @@ static __global__ void flash_attn_ext_f16(
__syncthreads(); __syncthreads();
// Iterate over ne11 == previous tokens: // Iterate over ne11 == previous tokens:
for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += FATTN_KQ_STRIDE) { const int k_start = parallel_blocks == 1 ? 0 : blockIdx.x*FATTN_KQ_STRIDE;
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
// Calculate tile of KQ: // Calculate tile of KQ:
#pragma unroll #pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
@ -420,22 +455,75 @@ static __global__ void flash_attn_ext_f16(
__syncthreads(); __syncthreads();
} }
if (parallel_blocks == 1) {
#pragma unroll #pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) { for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y; const int j = j0 + threadIdx.y;
if (ncols*blockIdx.x + j >= ne01) { if (ncols*blockIdx.x + j >= ne01) {
return;
}
const float KQ_rowsum_j = __low2float(KQ_rowsum[j0/nwarps]) + __high2float(KQ_rowsum[j0/nwarps]);
#pragma unroll
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D && i >= D) {
break;
}
dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j;
}
}
return;
}
#pragma unroll
for (int i0 = 0; i0 < D; i0 += nwarps*WARP_SIZE) {
const int i = i0 + threadIdx.y*WARP_SIZE + threadIdx.x;
if (i0 + nwarps*WARP_SIZE > D && i >= D) {
return; return;
} }
const float KQ_rowsum_j = __low2float(KQ_rowsum[j0/nwarps]) + __high2float(KQ_rowsum[j0/nwarps]); dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + i] = VKQ[i];
#pragma unroll
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D && i >= D) {
break;
}
dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j;
}
} }
if (threadIdx.y == 0 && threadIdx.x == 0) {
dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(
__low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0]));
}
}
template<int D, int parallel_blocks> // D == head size
__launch_bounds__(D, 1)
static __global__ void flash_attn_combine_results(
const float * __restrict__ VKQ_parts,
const half2 * __restrict__ VKQ_meta,
float * __restrict__ dst) {
const int tid = threadIdx.x;
__builtin_assume(tid < D);
__shared__ half2 meta[parallel_blocks];
if (tid < parallel_blocks) {
meta[threadIdx.x] = VKQ_meta[blockIdx.y*parallel_blocks + tid];
}
__syncthreads();
half kqmax = __low2half(meta[0]);
#pragma unroll
for (int l = 1; l < parallel_blocks; ++l) {
kqmax = __hmax(kqmax, __low2half(meta[l]));
}
float VKQ_numerator = 0.0f;
float VKQ_denominator = 0.0f;
#pragma unroll
for (int l = 0; l < parallel_blocks; ++l) {
float KQ_max_scale = hexp(__low2half(meta[l]) - kqmax);
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
VKQ_denominator += KQ_max_scale * __high2float(meta[l]);
}
dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
} }
constexpr int get_max_power_of_2(int x) { constexpr int get_max_power_of_2(int x) {
@ -462,26 +550,26 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
#define FATTN_SWITCH_CASE(D, ncols, nwarps) \ #define FATTN_SWITCH_CASE(D, ncols, nwarps) \
case ncols: { \ case ncols: { \
constexpr int frag_m = (ncols) == 8 && (D) % 32 == 0 ? 32 : 16; \ constexpr int frag_m = (ncols) == 8 && (D) % 32 == 0 ? 32 : 16; \
flash_attn_ext_f16<D, ncols, nwarps, get_VKQ_stride(D, nwarps, frag_m)> \ flash_attn_ext_f16<D, ncols, nwarps, get_VKQ_stride(D, nwarps, frag_m), 1> \
<<<blocks_num, block_dim, shmem, main_stream>>> ( \ <<<blocks_num, block_dim, shmem, main_stream>>> ( \
(const char *) Q->data, \ (const char *) Q->data, \
(const char *) K->data, \ (const char *) K->data, \
(const char *) V->data, \ (const char *) V->data, \
mask ? ((const char *) mask->data) : nullptr, \ mask ? ((const char *) mask->data) : nullptr, \
(float *) KQV->data, \ (float *) KQV->data, nullptr, \
scale, \ scale, \
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], \ Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], \
K->ne[0], K->ne[1], K->ne[2], K->ne[3], \ K->ne[0], K->ne[1], K->ne[2], K->ne[3], \
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, \ mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, \
Q->nb[1], Q->nb[2], Q->nb[3], \ Q->nb[1], Q->nb[2], Q->nb[3], \
K->nb[1], K->nb[2], K->nb[3], \ K->nb[1], K->nb[2], K->nb[3], \
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] \ KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] \
); \ ); \
} \ } \
break; \ break; \
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0]; const ggml_tensor * Q = dst->src[0];
@ -508,88 +596,39 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
float scale; float scale;
memcpy(&scale, KQV->op_params, sizeof(float)); memcpy(&scale, KQV->op_params, sizeof(float));
if (Q->ne[0] % WARP_SIZE == 0 && Q->ne[0] >= 128 && Q->ne[1] == 1) { if (Q->ne[1] == 1) {
const int nwarps = Q->ne[0] / WARP_SIZE; constexpr int parallel_blocks = 4;
const dim3 blocks_num(Q->ne[1], Q->ne[2], Q->ne[3]);
ggml_cuda_pool_alloc<float> dst_tmp(ctx.pool());
ggml_cuda_pool_alloc<half2> dst_tmp_meta(ctx.pool());
const int nwarps = (Q->ne[0] + WARP_SIZE - 1) / WARP_SIZE;
const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]);
const dim3 block_dim(WARP_SIZE, nwarps, 1); const dim3 block_dim(WARP_SIZE, nwarps, 1);
const int shmem = 0; const int shmem = 0;
// Performance of the vector kernel is very bad for head sizes 80 and 112, use the tensor core kernel instead:
constexpr int nwarps_tc = 4;
constexpr dim3 block_dim_tc(WARP_SIZE, nwarps_tc, 1);
const dim3 blocks_num_combine(1, blocks_num.y, blocks_num.z);
const dim3 block_dim_combine(Q->ne[0], 1, 1);
const int shmem_combine = 0;
if (parallel_blocks > 1) {
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
}
switch (Q->ne[0]) { switch (Q->ne[0]) {
// case 64: case 64:
// flash_attn_vec_ext_f16<64> flash_attn_vec_ext_f16<64, parallel_blocks>
// <<<blocks_num, block_dim, shmem, main_stream>>> (
// (const char *) Q->data, // Query
// (const char *) K->data, // Key
// (const char *) V->data, // Value
// mask ? ((const char *) mask->data) : nullptr, // Mask
// (float *) KQV->data, // dst
// scale,
// Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
// K->ne[0], K->ne[1], K->ne[2], K->ne[3],
// mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
// Q->nb[1], Q->nb[2], Q->nb[3],
// K->nb[1], K->nb[2], K->nb[3],
// KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
// );
// break;
// case 80:
// flash_attn_vec_ext_f16<80>
// <<<blocks_num, block_dim, shmem, main_stream>>> (
// (const char *) Q->data, // Query
// (const char *) K->data, // Key
// (const char *) V->data, // Value
// mask ? ((const char *) mask->data) : nullptr, // Mask
// (float *) KQV->data, // dst
// scale,
// Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
// K->ne[0], K->ne[1], K->ne[2], K->ne[3],
// mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
// Q->nb[1], Q->nb[2], Q->nb[3],
// K->nb[1], K->nb[2], K->nb[3],
// KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
// );
// break;
// case 96:
// flash_attn_vec_ext_f16<96>
// <<<blocks_num, block_dim, shmem, main_stream>>> (
// (const char *) Q->data, // Query
// (const char *) K->data, // Key
// (const char *) V->data, // Value
// mask ? ((const char *) mask->data) : nullptr, // Mask
// (float *) KQV->data, // dst
// scale,
// Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
// K->ne[0], K->ne[1], K->ne[2], K->ne[3],
// mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
// Q->nb[1], Q->nb[2], Q->nb[3],
// K->nb[1], K->nb[2], K->nb[3],
// KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
// );
// break;
// case 112:
// flash_attn_vec_ext_f16<112>
// <<<blocks_num, block_dim, shmem, main_stream>>> (
// (const char *) Q->data, // Query
// (const char *) K->data, // Key
// (const char *) V->data, // Value
// mask ? ((const char *) mask->data) : nullptr, // Mask
// (float *) KQV->data, // dst
// scale,
// Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
// K->ne[0], K->ne[1], K->ne[2], K->ne[3],
// mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
// Q->nb[1], Q->nb[2], Q->nb[3],
// K->nb[1], K->nb[2], K->nb[3],
// KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
// );
// break;
case 128:
flash_attn_vec_ext_f16<128>
<<<blocks_num, block_dim, shmem, main_stream>>> ( <<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data, // Query (const char *) Q->data, // Query
(const char *) K->data, // Key (const char *) K->data, // Key
(const char *) V->data, // Value (const char *) V->data, // Value
mask ? ((const char *) mask->data) : nullptr, // Mask mask ? ((const char *) mask->data) : nullptr, // Mask
(float *) KQV->data, // dst parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale, scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3],
@ -598,15 +637,118 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
K->nb[1], K->nb[2], K->nb[3], K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
); );
if (parallel_blocks == 1) {
break;
}
CUDA_CHECK(cudaGetLastError());
flash_attn_combine_results<64, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
break;
case 80:
flash_attn_ext_f16<80, 16, nwarps_tc, get_VKQ_stride(80, nwarps_tc, 16), parallel_blocks>
<<<blocks_num, block_dim_tc, shmem, main_stream>>> (
(const char *) Q->data, // Query
(const char *) K->data, // Key
(const char *) V->data, // Value
mask ? ((const char *) mask->data) : nullptr, // Mask
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
if (parallel_blocks == 1) {
break;
}
CUDA_CHECK(cudaGetLastError());
flash_attn_combine_results<80, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
break;
case 96:
flash_attn_vec_ext_f16<96, parallel_blocks>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data, // Query
(const char *) K->data, // Key
(const char *) V->data, // Value
mask ? ((const char *) mask->data) : nullptr, // Mask
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
if (parallel_blocks == 1) {
break;
}
CUDA_CHECK(cudaGetLastError());
flash_attn_combine_results<96, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
break;
case 112:
flash_attn_vec_ext_f16<112, parallel_blocks>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data, // Query
(const char *) K->data, // Key
(const char *) V->data, // Value
mask ? ((const char *) mask->data) : nullptr, // Mask
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
if (parallel_blocks == 1) {
break;
}
CUDA_CHECK(cudaGetLastError());
flash_attn_combine_results<112, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
break;
case 128:
flash_attn_vec_ext_f16<128, parallel_blocks>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data, // Query
(const char *) K->data, // Key
(const char *) V->data, // Value
mask ? ((const char *) mask->data) : nullptr, // Mask
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
if (parallel_blocks == 1) {
break;
}
CUDA_CHECK(cudaGetLastError());
flash_attn_combine_results<128, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
break; break;
case 256: case 256:
flash_attn_vec_ext_f16<256> flash_attn_vec_ext_f16<256, parallel_blocks>
<<<blocks_num, block_dim, shmem, main_stream>>> ( <<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data, // Query (const char *) Q->data, // Query
(const char *) K->data, // Key (const char *) K->data, // Key
(const char *) V->data, // Value (const char *) V->data, // Value
mask ? ((const char *) mask->data) : nullptr, // Mask mask ? ((const char *) mask->data) : nullptr, // Mask
(float *) KQV->data, // dst parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale, scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3],
@ -615,6 +757,13 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
K->nb[1], K->nb[2], K->nb[3], K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
); );
if (parallel_blocks == 1) {
break;
}
CUDA_CHECK(cudaGetLastError());
flash_attn_combine_results<256, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
break; break;
default: default:
GGML_ASSERT(false); GGML_ASSERT(false);
@ -633,7 +782,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
cols_per_block = 8; cols_per_block = 8;
} }
const int frag_m = cols_per_block == 8 ? 32 : 16; const int frag_m = cols_per_block == 8 ? 32 : 16;
const int nwarps = 4; constexpr int nwarps = 4;
const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
const dim3 block_dim(WARP_SIZE, nwarps, 1); const dim3 block_dim(WARP_SIZE, nwarps, 1);
const size_t shmem = 0; const size_t shmem = 0;