fused attention kernel for batch size 1

Johannes Gäßler 2024-03-19 21:04:28 +01:00
parent 58c7f6167c
commit 82ae7f3357
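
For context: the kernel added below computes scaled dot-product attention for a single query column i of head h in one kernel launch, keeping the intermediate KQ row in shared memory instead of writing it back to global memory. As a rough sketch, with s the scale factor and M the optional mask:

    O_{i,h} = \mathrm{softmax}\left(s\,Q_{i,h} K_h^\top + M_i\right) V_h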


@@ -7227,6 +7227,157 @@ static __global__ void flash_attn_f32(
    }
}
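
// Fused attention kernel for a single query column (batch size 1).
// A sketch of the execution layout, as far as it can be read from the code below:
//  - one thread block per (query column, head): blockIdx.x = query column, blockIdx.y = head;
//  - nwarps warps of WARP_SIZE threads each; the warps split the ne11 K columns between them;
//  - dynamic shared memory holds the KQ row (ne11 halves) followed by a buffer of
//    WARP_SIZE*nwarps halves that is reused for the inter-warp reductions.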
template<int D, int nwarps, int need_check> // D == head size
__launch_bounds__(WARP_SIZE*nwarps, 1)
static __global__ void fused_attn_vec_ext_f16(
        const char * __restrict__ q,
        const char * __restrict__ k,
        const char * __restrict__ v,
        const char * __restrict__ mask,
        float * __restrict__ dst,
        const float scale,
        const int ne00,
        const int ne01,
        const int ne02,
        const int ne03,
        const int ne10,
        const int ne11,
        const int ne12,
        const int ne13,
        const int ne31,
        const int nb31,
        const int nb01,
        const int nb02,
        const int nb03,
        const int nb11,
        const int nb12,
        const int nb13,
        const int ne0,
        const int ne1,
        const int ne2,
        const int ne3) {
    const int & nb21 = nb11; // V is assumed to have the same strides as K
    const int & nb22 = nb12;

    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;

    extern __shared__ char data_flash_attn_vec_ext_f16[];

    half  * kq     = (half  *) data_flash_attn_vec_ext_f16; // KQ row for this query, ne11 values
    half2 * kq2    = (half2 *) kq;
    half  * buf_iw = kq + ne11;                              // buffer for inter-warp reductions

    if (threadIdx.y == 0) {
        buf_iw[threadIdx.x] = 0.0f;
    }
    __syncthreads();
    // Load the query column for this block into registers, converting F32 -> half2:
    half2 qh2[(D/2 + WARP_SIZE - 1) / WARP_SIZE];
#pragma unroll
    for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
        const int i = i0 + threadIdx.x;
        if ((D/2) % WARP_SIZE != 0 && i >= (D/2)) {
            break;
        }

        const float2 qf2 = ((const float2 *) (q + nb02*blockIdx.y + nb01*blockIdx.x))[i];
        qh2[i0/WARP_SIZE] = make_half2(qf2.x, qf2.y);
    }
    // KQ: each warp computes the dot product for one K column per iteration, reducing across
    // its lanes, and tracks the running maximum needed for the softmax.
    half kqmax = -INFINITY;
    for (int i0 = 0; i0 < ne11; i0 += nwarps) {
        const int i = i0 + threadIdx.y;
        if (need_check > 1 && i >= ne11) {
            break;
        }

        half2 sum2 = mask ? __half2half2(((half *) mask)[ne11*blockIdx.x + i]) : make_half2(0.0f, 0.0f);
#pragma unroll
        for (int j0 = 0; j0 < D/2; j0 += WARP_SIZE) {
            const int j = j0 + threadIdx.x;
            if ((D/2) % WARP_SIZE != 0 && j >= (D/2)) {
                break;
            }

            const half2 k2k = ((const half2 *) (k + nb12*blockIdx.y + nb11*i))[j];
            sum2 += k2k * qh2[j0/WARP_SIZE];
        }

        sum2 = warp_reduce_sum(sum2);
        const half sum = __float2half(scale) * (__low2half(sum2) + __high2half(sum2));
        kqmax = __hmax(kqmax, sum);

        if (threadIdx.x == 0) {
            kq[i] = sum;
        }
    }

    // Reduce the KQ maximum across all warps via shared memory:
    kqmax = warp_reduce_max(kqmax);
    if (threadIdx.x == 0) {
        buf_iw[threadIdx.y] = kqmax;
    }
    __syncthreads();
    kqmax = buf_iw[threadIdx.x];
    __syncthreads();
    kqmax = warp_reduce_max(kqmax);
    const half2 kqmax2 = __half2half2(kqmax);

    // Softmax: exponentiate the KQ row in place and accumulate the denominator.
    half2 kqsum2 = make_half2(0.0f, 0.0f);
    for (int i0 = 0; i0 < ne11/2; i0 += WARP_SIZE*nwarps) {
        const int i = i0 + tid;
        if (need_check > 0 && i >= ne11/2) {
            break;
        }

        const half2 val = h2exp(kq2[i] - kqmax2);
        kqsum2 += val;
        kq2[i] = val;
    }
    kqsum2 = warp_reduce_sum(kqsum2);

    // Reduce the denominator across warps, then normalize the KQ row:
    if (threadIdx.x == 0) {
        buf_iw[threadIdx.y] = __low2half(kqsum2) + __high2half(kqsum2);
    }
    __syncthreads();
    half kqsum = buf_iw[threadIdx.x];
    kqsum = warp_reduce_sum(kqsum);

    const half2 kqscale = make_half2(1.0f, 1.0f) / __half2half2(kqsum);
    for (int i0 = 0; i0 < ne11/2; i0 += WARP_SIZE*nwarps) {
        const int i = i0 + tid;
        if (need_check > 0 && i >= ne11/2) {
            break;
        }

        kq2[i] *= kqscale;
    }
    __syncthreads();
    // VKQ: consecutive groups of D/WARP_SIZE warps each cover the full head dimension
    // for a different subset of the KQ values.
    half2 sum2 = make_half2(0.0f, 0.0f);
    for (int i0 = 0; i0 < ne11/2; i0 += nwarps / (D/WARP_SIZE)) {
        const int i = i0 + threadIdx.y / (D/WARP_SIZE);
        if (need_check > 2 && i >= ne11/2) {
            break;
        }

        half2 vi;
        vi.x = ((const half *) (v + nb22*blockIdx.y + (2*i + 0)*nb21))[WARP_SIZE*(threadIdx.y % (D/WARP_SIZE)) + threadIdx.x];
        vi.y = ((const half *) (v + nb22*blockIdx.y + (2*i + 1)*nb21))[WARP_SIZE*(threadIdx.y % (D/WARP_SIZE)) + threadIdx.x];
        sum2 += vi*kq2[i];
    }

    // Combine the partial sums of the warp groups and write the result for this (query, head) pair:
    buf_iw[tid] = __low2half(sum2) + __high2half(sum2);
    __syncthreads();
    if (threadIdx.y >= (D/WARP_SIZE)) {
        return;
    }

    float sum = 0.0f;
#pragma unroll
    for (int i = 0; i < nwarps; i += (D/WARP_SIZE)) {
        sum += __half2float(buf_iw[WARP_SIZE*i + tid]);
    }
    dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = sum;
}
#if __CUDA_ARCH__ >= CC_VOLTA
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b;
@@ -11927,6 +12078,83 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
    //const size_t shmem_max = 96*1024;
    //cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
    if (Q->ne[0] == 128 && Q->ne[1] <= 2) {
        constexpr int nwarps = 32;

        const dim3 blocks_num(Q->ne[1], Q->ne[2], Q->ne[3]);
        const dim3 block_dim(WARP_SIZE, nwarps, 1);

        // Dynamic shared memory: ne11 halves for the KQ row + WARP_SIZE*nwarps halves for inter-warp reductions.
        const size_t shmem = (K->ne[1] + WARP_SIZE*nwarps)*sizeof(half);

        GGML_ASSERT(K->ne[1] % 2 == 0);

        // need_check selects how much bounds checking the kernel's loops need,
        // depending on which loop strides evenly divide the KV length K->ne[1]:
        if ((K->ne[1]/2) % (WARP_SIZE*nwarps) == 0) {
            fused_attn_vec_ext_f16<128, nwarps, 0>
                <<<blocks_num, block_dim, shmem, main_stream>>> (
                        (const char *) src0_extra->data_device[g_main_device], // Query
                        (const char *) src1_extra->data_device[g_main_device], // Key
                        (const char *) src2_extra->data_device[g_main_device], // Value
                        mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
                        (float *) dst_extra->data_device[g_main_device], // dst
                        scale,
                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
                        Q->nb[1], Q->nb[2], Q->nb[3],
                        K->nb[1], K->nb[2], K->nb[3],
                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
                );
        } else if (K->ne[1] % nwarps == 0 && (K->ne[1]/2) % (nwarps / (Q->ne[0] / WARP_SIZE)) == 0) {
            fused_attn_vec_ext_f16<128, nwarps, 1>
                <<<blocks_num, block_dim, shmem, main_stream>>> (
                        (const char *) src0_extra->data_device[g_main_device], // Query
                        (const char *) src1_extra->data_device[g_main_device], // Key
                        (const char *) src2_extra->data_device[g_main_device], // Value
                        mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
                        (float *) dst_extra->data_device[g_main_device], // dst
                        scale,
                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
                        Q->nb[1], Q->nb[2], Q->nb[3],
                        K->nb[1], K->nb[2], K->nb[3],
                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
                );
        } else if ((K->ne[1]/2) % (nwarps / (Q->ne[0] / WARP_SIZE)) == 0) {
            fused_attn_vec_ext_f16<128, nwarps, 2>
                <<<blocks_num, block_dim, shmem, main_stream>>> (
                        (const char *) src0_extra->data_device[g_main_device], // Query
                        (const char *) src1_extra->data_device[g_main_device], // Key
                        (const char *) src2_extra->data_device[g_main_device], // Value
                        mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
                        (float *) dst_extra->data_device[g_main_device], // dst
                        scale,
                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
                        Q->nb[1], Q->nb[2], Q->nb[3],
                        K->nb[1], K->nb[2], K->nb[3],
                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
                );
        } else {
            fused_attn_vec_ext_f16<128, nwarps, 3>
                <<<blocks_num, block_dim, shmem, main_stream>>> (
                        (const char *) src0_extra->data_device[g_main_device], // Query
                        (const char *) src1_extra->data_device[g_main_device], // Key
                        (const char *) src2_extra->data_device[g_main_device], // Value
                        mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
                        (float *) dst_extra->data_device[g_main_device], // dst
                        scale,
                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
                        Q->nb[1], Q->nb[2], Q->nb[3],
                        K->nb[1], K->nb[2], K->nb[3],
                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
                );
        }
        CUDA_CHECK(cudaGetLastError());
        return;
    }
    switch (Q->ne[0]) {
        case 64:
            flash_attn_ext_f16<64, NQPB, NCPW>