Multiple parallel blocks for batch size 1

parent 68d793bee8
commit 3f777acf06

1 changed file with 285 additions and 136 deletions
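At batch size 1 each K/V head is now processed by `parallel_blocks` thread blocks instead of one: every block walks a strided subset of the KV tiles and writes an unnormalized partial output plus a (max, sum) pair to temporary buffers, and a new flash_attn_combine_results kernel merges the partials into the final result. A minimal host-side sketch of that merge, assuming per-block partials numer[l], maxima m[l] and denominators s[l] (all names here are illustrative, not from the commit):

#include <math.h>

// Merge partial flash-attention results the way flash_attn_combine_results does:
// rescale every block's numerator and denominator to the global softmax maximum.
static float combine_partials(const float * numer, const float * m, const float * s, int parallel_blocks) {
    float m_max = m[0];
    for (int l = 1; l < parallel_blocks; ++l) {
        m_max = fmaxf(m_max, m[l]); // global softmax maximum across blocks
    }
    float numerator   = 0.0f;
    float denominator = 0.0f;
    for (int l = 0; l < parallel_blocks; ++l) {
        const float scale = expf(m[l] - m_max); // rescale block l to the global maximum
        numerator   += scale * numer[l];
        denominator += scale * s[l];
    }
    return numerator / denominator;
}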
@@ -29,14 +29,17 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
 }
 
-template<int D> // D == head size
-__launch_bounds__(D, 1)
+#define FATTN_KQ_STRIDE 256
+
+template<int D, int parallel_blocks> // D == head size
+__launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1)
 static __global__ void flash_attn_vec_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
         float * __restrict__ dst,
+        half2 * __restrict__ dst_meta,
         const float scale,
         const int ne00,
         const int ne01,
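The new launch bound above rounds the head size up to a whole number of warps, so head sizes that are not multiples of WARP_SIZE (such as 80 or 112) can still launch full warps. A quick illustrative check of the arithmetic (WARP_SIZE_EX is a stand-in for the real WARP_SIZE macro):

constexpr int WARP_SIZE_EX = 32; // illustrative stand-in for WARP_SIZE

constexpr int block_size(int D) {
    return ((D + WARP_SIZE_EX - 1) / WARP_SIZE_EX) * WARP_SIZE_EX; // round up to whole warps
}

static_assert(block_size( 80) ==  96, "80 -> 3 warps (96 threads)");
static_assert(block_size(112) == 128, "112 -> 4 warps (128 threads)");
static_assert(block_size(128) == 128, "128 -> 4 warps (128 threads)");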
@@ -60,20 +63,25 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne3) {
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q + nb02* blockIdx.y + nb01*blockIdx.x);
+    const float2 * Q_f2  = (const float2 *) (Q + nb02* blockIdx.y);
     const half2  * K_h2  = (const half2  *) (K + nb12*(blockIdx.y / gqa_ratio));
     const half   * V_h   = (const half   *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *) mask + ne31*blockIdx.x;
+    const half   * maskh = (const half   *) mask;
+
+    if (parallel_blocks == 1) {
+        Q_f2  += blockIdx.x*nb01/sizeof(float2);
+        maskh += blockIdx.x*ne11;
+    }
 
     const int stride_KV  = nb11 / sizeof(half);
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    constexpr int nwarps = D/WARP_SIZE;
+    constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
     const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
-    __builtin_assume(tid < D);
+    __builtin_assume(tid < nwarps*WARP_SIZE);
 
-    __shared__ half KQ[D];
-    KQ[tid] = 0.0f;
+    __shared__ half KQ[nwarps*WARP_SIZE];
+    KQ[tid] = -INFINITY;
     half2 * KQ2 = (half2 *) KQ;
 
     half kqmax = -INFINITY;
@@ -85,7 +93,6 @@ static __global__ void flash_attn_vec_ext_f16(
         kqmax_shared[threadIdx.x] = -INFINITY;
         kqsum_shared[threadIdx.x] = 0.0f;
     }
 
     __syncthreads();
 
     // Convert Q to half2 and store in registers:
@@ -102,14 +109,15 @@ static __global__ void flash_attn_vec_ext_f16(
 
     half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value.
 
-    for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += D) {
+    const int k_start = parallel_blocks == 1 ? 0 : blockIdx.x*D;
+    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
         half kqmax_new = kqmax;
 #pragma unroll
         for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
             const int i_KQ = i_KQ_0 + threadIdx.y;
 
-            if (256 % D != 0 && k_VKQ_0 + i_KQ >= ne11) {
+            if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
                 break;
             }
 
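With parallel_blocks > 1 the blocks interleave over the KV sequence: block b starts at tile b (blockIdx.x*D) and then strides by parallel_blocks*D, so no two blocks touch the same rows. A small host-side model of the assignment (hypothetical helper, not part of the commit; tile is D for the vector kernel and FATTN_KQ_STRIDE for the tensor core kernel):

#include <cstdio>

// Prints which KV rows each block processes under the strided split.
static void print_tile_assignment(int ne11, int tile, int parallel_blocks) {
    for (int b = 0; b < parallel_blocks; ++b) {
        const int k_start = parallel_blocks == 1 ? 0 : b*tile;
        for (int k = k_start; k < ne11; k += parallel_blocks*tile) {
            std::printf("block %d: KV rows [%d, %d)\n", b, k, k + tile);
        }
    }
}

// print_tile_assignment(1024, 128, 4) gives block 0 the rows [0, 128) and
// [512, 640), block 1 the rows [128, 256) and [640, 768), and so on.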
@@ -153,6 +161,7 @@ static __global__ void flash_attn_vec_ext_f16(
 
         __syncthreads();
 
+        if (tid < D) {
 #pragma unroll
         for (int k0 = 0; k0 < D; k0 += 2) {
             if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) {
@@ -165,6 +174,11 @@ static __global__ void flash_attn_vec_ext_f16(
             VKQ += V_k*KQ2[k0/2];
         }
     }
+    }
+
+    if (tid >= D) {
+        kqsum = 0.0f;
+    }
 
     kqsum = warp_reduce_sum(kqsum);
     if (threadIdx.x == 0) {
@@ -174,12 +188,22 @@ static __global__ void flash_attn_vec_ext_f16(
     kqsum = kqsum_shared[threadIdx.x];
     kqsum = warp_reduce_sum(kqsum);
 
+    if (tid >= D) {
+        return;
+    }
+
+    if (parallel_blocks == 1) {
     dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)) / kqsum;
+    } else {
+        dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ));
+
+        if (tid == 0) {
+            dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(kqmax, kqsum);
+        }
+    }
 }
 
-#define FATTN_KQ_STRIDE 256
-
-template<int D, int ncols, int nwarps, int VKQ_stride> // D == head size, VKQ_stride == num VKQ rows calculated in parallel
+template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks> // D == head size, VKQ_stride == num VKQ rows calculated in parallel
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
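The metadata written above is laid out row-major as one half2 per (row, block) pair; the same blockIdx.y*parallel_blocks + blockIdx.x indexing reappears in the combine kernel further down. As a sketch (hypothetical helper, not from the commit):

// Entry for (row, block) lives at row*parallel_blocks + block in dst_meta.
constexpr int meta_index(int row, int block, int parallel_blocks) {
    return row*parallel_blocks + block;
}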
@@ -187,6 +211,7 @@ static __global__ void flash_attn_ext_f16(
         const char * __restrict__ V,
         const char * __restrict__ mask,
         float * __restrict__ dst,
+        half2 * __restrict__ dst_meta,
         const float scale,
         const int ne00,
         const int ne01,
@@ -228,10 +253,15 @@ static __global__ void flash_attn_ext_f16(
     constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float * Q_f   = (const float *) (Q + nb02* blockIdx.y + ncols*nb01*blockIdx.x);
+    const float * Q_f   = (const float *) (Q + nb02* blockIdx.y);
     const half  * K_h   = (const half  *) (K + nb12*(blockIdx.y / gqa_ratio));
     const half  * V_h   = (const half  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
-    const half2 * mask2 = (half2 *) mask + ncols*ne11*blockIdx.x/2;
+    const half2 * mask2 = (half2 *) mask;
+
+    if (parallel_blocks == 1) {
+        Q_f   += blockIdx.x * ncols*nb01/sizeof(float);
+        mask2 += blockIdx.x * ncols*ne11/2;
+    }
 
     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
@@ -273,7 +303,11 @@ static __global__ void flash_attn_ext_f16(
             if (i0 + WARP_SIZE > D && i >= D) {
                 break;
             }
+            if (parallel_blocks == 1) {
             KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
+            } else {
+                KQ[j*D_padded + i] = j == 0 ? Q_f[j*stride_Q + i] * scale : 0.0f;
+            }
         }
     }
 
@@ -291,7 +325,8 @@ static __global__ void flash_attn_ext_f16(
     __syncthreads();
 
     // Iterate over ne11 == previous tokens:
-    for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += FATTN_KQ_STRIDE) {
+    const int k_start = parallel_blocks == 1 ? 0 : blockIdx.x*FATTN_KQ_STRIDE;
+    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
         // Calculate tile of KQ:
 #pragma unroll
         for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
@@ -420,6 +455,7 @@ static __global__ void flash_attn_ext_f16(
         __syncthreads();
     }
 
+    if (parallel_blocks == 1) {
 #pragma unroll
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
@@ -436,6 +472,58 @@ static __global__ void flash_attn_ext_f16(
             dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j;
         }
     }
+        return;
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < D; i0 += nwarps*WARP_SIZE) {
+        const int i = i0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+        if (i0 + nwarps*WARP_SIZE > D && i >= D) {
+            return;
+        }
+        dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + i] = VKQ[i];
+    }
+
+    if (threadIdx.y == 0 && threadIdx.x == 0) {
+        dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(
+            __low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0]));
+    }
+}
+
+template<int D, int parallel_blocks> // D == head size
+__launch_bounds__(D, 1)
+static __global__ void flash_attn_combine_results(
+        const float * __restrict__ VKQ_parts,
+        const half2 * __restrict__ VKQ_meta,
+        float * __restrict__ dst) {
+
+    const int tid = threadIdx.x;
+    __builtin_assume(tid < D);
+
+    __shared__ half2 meta[parallel_blocks];
+    if (tid < parallel_blocks) {
+        meta[threadIdx.x] = VKQ_meta[blockIdx.y*parallel_blocks + tid];
+    }
+
+    __syncthreads();
+
+    half kqmax = __low2half(meta[0]);
+#pragma unroll
+    for (int l = 1; l < parallel_blocks; ++l) {
+        kqmax = __hmax(kqmax, __low2half(meta[l]));
+    }
+
+    float VKQ_numerator   = 0.0f;
+    float VKQ_denominator = 0.0f;
+#pragma unroll
+    for (int l = 0; l < parallel_blocks; ++l) {
+        float KQ_max_scale = hexp(__low2half(meta[l]) - kqmax);
+
+        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
+        VKQ_denominator += KQ_max_scale * __high2float(meta[l]);
+    }
+
+    dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
+}
 
 constexpr int get_max_power_of_2(int x) {
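Each block packs its local KQ maximum into the low half and its softmax denominator into the high half of a single half2; the combine kernel above unpacks them with __low2half/__high2float. A minimal device-side illustration (not from the commit):

#include <cuda_fp16.h>

// Pack a (max, sum) pair the way dst_meta does, then read both halves back.
__device__ void meta_roundtrip(const half kqmax, const half kqsum, half * m, float * s) {
    const half2 meta = make_half2(kqmax, kqsum); // low half = max, high half = sum
    *m = __low2half(meta);                       // local softmax maximum
    *s = __high2float(meta);                     // local softmax denominator
}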
@@ -465,13 +553,13 @@ static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
 #define FATTN_SWITCH_CASE(D, ncols, nwarps)                                  \
     case ncols: {                                                            \
         constexpr int frag_m = (ncols) == 8 && (D) % 32 == 0 ? 32 : 16;      \
-        flash_attn_ext_f16<D, ncols, nwarps, get_VKQ_stride(D, nwarps, frag_m)>    \
+        flash_attn_ext_f16<D, ncols, nwarps, get_VKQ_stride(D, nwarps, frag_m), 1> \
             <<<blocks_num, block_dim, shmem, main_stream>>> (                \
                 (const char *) Q->data,                                      \
                 (const char *) K->data,                                      \
                 (const char *) V->data,                                      \
                 mask ? ((const char *) mask->data) : nullptr,                \
-                (float *) KQV->data,                                         \
+                (float *) KQV->data, nullptr,                                \
                 scale,                                                       \
                 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],                      \
                 K->ne[0], K->ne[1], K->ne[2], K->ne[3],                      \
@@ -508,88 +596,39 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float scale;
     memcpy(&scale, KQV->op_params, sizeof(float));
 
-    if (Q->ne[0] % WARP_SIZE == 0 && Q->ne[0] >= 128 && Q->ne[1] == 1) {
-        const int nwarps = Q->ne[0] / WARP_SIZE;
-        const dim3 blocks_num(Q->ne[1], Q->ne[2], Q->ne[3]);
+    if (Q->ne[1] == 1) {
+        constexpr int parallel_blocks = 4;
+
+        ggml_cuda_pool_alloc<float> dst_tmp(ctx.pool());
+        ggml_cuda_pool_alloc<half2> dst_tmp_meta(ctx.pool());
+
+        const int nwarps = (Q->ne[0] + WARP_SIZE - 1) / WARP_SIZE;
+        const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]);
         const dim3 block_dim(WARP_SIZE, nwarps, 1);
         const int  shmem = 0;
 
+        // Performance of the vector kernel is very bad for head sizes 80 and 112, use the tensor core kernel instead:
+        constexpr int  nwarps_tc = 4;
+        constexpr dim3 block_dim_tc(WARP_SIZE, nwarps_tc, 1);
+
+        const dim3 blocks_num_combine(1, blocks_num.y, blocks_num.z);
+        const dim3 block_dim_combine(Q->ne[0], 1, 1);
+        const int  shmem_combine = 0;
+
+        if (parallel_blocks > 1) {
+            dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+            dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+        }
+
         switch (Q->ne[0]) {
-            // case 64:
-            //     flash_attn_vec_ext_f16<64>
-            //             <<<blocks_num, block_dim, shmem, main_stream>>> (
-            //                     (const char *) Q->data, // Query
-            //                     (const char *) K->data, // Key
-            //                     (const char *) V->data, // Value
-            //                     mask ? ((const char *) mask->data) : nullptr, // Mask
-            //                     (float *) KQV->data, // dst
-            //                     scale,
-            //                     Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-            //                     K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-            //                     mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-            //                     Q->nb[1], Q->nb[2], Q->nb[3],
-            //                     K->nb[1], K->nb[2], K->nb[3],
-            //                     KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-            //             );
-            //     break;
-            // case 80:
-            //     flash_attn_vec_ext_f16<80>
-            //             <<<blocks_num, block_dim, shmem, main_stream>>> (
-            //                     (const char *) Q->data, // Query
-            //                     (const char *) K->data, // Key
-            //                     (const char *) V->data, // Value
-            //                     mask ? ((const char *) mask->data) : nullptr, // Mask
-            //                     (float *) KQV->data, // dst
-            //                     scale,
-            //                     Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-            //                     K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-            //                     mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-            //                     Q->nb[1], Q->nb[2], Q->nb[3],
-            //                     K->nb[1], K->nb[2], K->nb[3],
-            //                     KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-            //             );
-            //     break;
-            // case 96:
-            //     flash_attn_vec_ext_f16<96>
-            //             <<<blocks_num, block_dim, shmem, main_stream>>> (
-            //                     (const char *) Q->data, // Query
-            //                     (const char *) K->data, // Key
-            //                     (const char *) V->data, // Value
-            //                     mask ? ((const char *) mask->data) : nullptr, // Mask
-            //                     (float *) KQV->data, // dst
-            //                     scale,
-            //                     Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-            //                     K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-            //                     mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-            //                     Q->nb[1], Q->nb[2], Q->nb[3],
-            //                     K->nb[1], K->nb[2], K->nb[3],
-            //                     KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-            //             );
-            //     break;
-            // case 112:
-            //     flash_attn_vec_ext_f16<112>
-            //             <<<blocks_num, block_dim, shmem, main_stream>>> (
-            //                     (const char *) Q->data, // Query
-            //                     (const char *) K->data, // Key
-            //                     (const char *) V->data, // Value
-            //                     mask ? ((const char *) mask->data) : nullptr, // Mask
-            //                     (float *) KQV->data, // dst
-            //                     scale,
-            //                     Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-            //                     K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-            //                     mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-            //                     Q->nb[1], Q->nb[2], Q->nb[3],
-            //                     K->nb[1], K->nb[2], K->nb[3],
-            //                     KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-            //             );
-            //     break;
-            case 128:
-                flash_attn_vec_ext_f16<128>
+            case 64:
+                flash_attn_vec_ext_f16<64, parallel_blocks>
                     <<<blocks_num, block_dim, shmem, main_stream>>> (
                         (const char *) Q->data, // Query
                         (const char *) K->data, // Key
                         (const char *) V->data, // Value
                         mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
+                        parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
                         scale,
                         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
@@ -598,15 +637,118 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
                         K->nb[1], K->nb[2], K->nb[3],
                         KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
                     );
+                if (parallel_blocks == 1) {
+                    break;
+                }
+                CUDA_CHECK(cudaGetLastError());
+                flash_attn_combine_results<64, parallel_blocks>
+                    <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+                    (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+                break;
+            case 80:
+                flash_attn_ext_f16<80, 16, nwarps_tc, get_VKQ_stride(80, nwarps_tc, 16), parallel_blocks>
+                    <<<blocks_num, block_dim_tc, shmem, main_stream>>> (
+                        (const char *) Q->data, // Query
+                        (const char *) K->data, // Key
+                        (const char *) V->data, // Value
+                        mask ? ((const char *) mask->data) : nullptr, // Mask
+                        parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+                        scale,
+                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+                        Q->nb[1], Q->nb[2], Q->nb[3],
+                        K->nb[1], K->nb[2], K->nb[3],
+                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                    );
+                if (parallel_blocks == 1) {
+                    break;
+                }
+                CUDA_CHECK(cudaGetLastError());
+                flash_attn_combine_results<80, parallel_blocks>
+                    <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+                    (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+                break;
+            case 96:
+                flash_attn_vec_ext_f16<96, parallel_blocks>
+                    <<<blocks_num, block_dim, shmem, main_stream>>> (
+                        (const char *) Q->data, // Query
+                        (const char *) K->data, // Key
+                        (const char *) V->data, // Value
+                        mask ? ((const char *) mask->data) : nullptr, // Mask
+                        parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+                        scale,
+                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+                        Q->nb[1], Q->nb[2], Q->nb[3],
+                        K->nb[1], K->nb[2], K->nb[3],
+                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                    );
+                if (parallel_blocks == 1) {
+                    break;
+                }
+                CUDA_CHECK(cudaGetLastError());
+                flash_attn_combine_results<96, parallel_blocks>
+                    <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+                    (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+                break;
+            case 112:
+                flash_attn_vec_ext_f16<112, parallel_blocks>
+                    <<<blocks_num, block_dim, shmem, main_stream>>> (
+                        (const char *) Q->data, // Query
+                        (const char *) K->data, // Key
+                        (const char *) V->data, // Value
+                        mask ? ((const char *) mask->data) : nullptr, // Mask
+                        parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+                        scale,
+                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+                        Q->nb[1], Q->nb[2], Q->nb[3],
+                        K->nb[1], K->nb[2], K->nb[3],
+                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                    );
+                if (parallel_blocks == 1) {
+                    break;
+                }
+                CUDA_CHECK(cudaGetLastError());
+                flash_attn_combine_results<112, parallel_blocks>
+                    <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+                    (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+                break;
+            case 128:
+                flash_attn_vec_ext_f16<128, parallel_blocks>
+                    <<<blocks_num, block_dim, shmem, main_stream>>> (
+                        (const char *) Q->data, // Query
+                        (const char *) K->data, // Key
+                        (const char *) V->data, // Value
+                        mask ? ((const char *) mask->data) : nullptr, // Mask
+                        parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+                        scale,
+                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+                        Q->nb[1], Q->nb[2], Q->nb[3],
+                        K->nb[1], K->nb[2], K->nb[3],
+                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                    );
+                if (parallel_blocks == 1) {
+                    break;
+                }
+                CUDA_CHECK(cudaGetLastError());
+                flash_attn_combine_results<128, parallel_blocks>
+                    <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+                    (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
                 break;
             case 256:
-                flash_attn_vec_ext_f16<256>
+                flash_attn_vec_ext_f16<256, parallel_blocks>
                     <<<blocks_num, block_dim, shmem, main_stream>>> (
                         (const char *) Q->data, // Query
                         (const char *) K->data, // Key
                         (const char *) V->data, // Value
                         mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
+                        parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
                         scale,
                         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
@@ -615,6 +757,13 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
                         K->nb[1], K->nb[2], K->nb[3],
                         KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
                     );
+                if (parallel_blocks == 1) {
+                    break;
+                }
+                CUDA_CHECK(cudaGetLastError());
+                flash_attn_combine_results<256, parallel_blocks>
+                    <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+                    (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
                 break;
             default:
                 GGML_ASSERT(false);
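The point of the split: at batch size 1 the old grid had only Q->ne[2]*Q->ne[3] blocks (one per head), far too few to occupy a large GPU, and parallel_blocks == 4 quadruples that at the cost of one cheap reduction kernel per launch. Illustrative numbers (not from the commit):

constexpr int n_heads         = 32; // e.g. a 7B-class model
constexpr int parallel_blocks = 4;

constexpr int blocks_old = 1*n_heads;               // dim3(Q->ne[1], ne[2], ne[3]) at batch size 1
constexpr int blocks_new = parallel_blocks*n_heads; // dim3(parallel_blocks*Q->ne[1], ne[2], ne[3])

static_assert(blocks_new == 128, "4x more thread blocks to spread across the SMs");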
@@ -633,7 +782,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         cols_per_block = 8;
     }
     const int frag_m = cols_per_block == 8 ? 32 : 16;
-    const int nwarps = 4;
+    constexpr int nwarps = 4;
     const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
     const size_t shmem = 0;