fused attention kernel for batch size 1
This commit is contained in:
parent 58c7f6167c
commit 82ae7f3357
1 changed file with 228 additions and 0 deletions
ggml-cuda.cu (228 additions)
@@ -7227,6 +7227,157 @@ static __global__ void flash_attn_f32(
    }
}

template<int D, int nwarps, int need_check> // D head size
__launch_bounds__(WARP_SIZE*nwarps, 1)
static __global__ void fused_attn_vec_ext_f16(
        const char * __restrict__ q,
        const char * __restrict__ k,
        const char * __restrict__ v,
        const char * __restrict__ mask,
        float * __restrict__ dst,
        const float scale,
        const int ne00,
        const int ne01,
        const int ne02,
        const int ne03,
        const int ne10,
        const int ne11,
        const int ne12,
        const int ne13,
        const int ne31,
        const int nb31,
        const int nb01,
        const int nb02,
        const int nb03,
        const int nb11,
        const int nb12,
        const int nb13,
        const int ne0,
        const int ne1,
        const int ne2,
        const int ne3) {
    const int & nb21 = nb11;
    const int & nb22 = nb12;
    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
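    // Shared memory layout: kq holds one fp16 KQ score per K column (ne11 values),
    // kq2 is the same buffer viewed as half2, and buf_iw is a WARP_SIZE*nwarps
    // scratch area used for the cross-warp max/sum reductions below.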
    extern __shared__ char data_flash_attn_vec_ext_f16[];
    half * kq = (half *) data_flash_attn_vec_ext_f16;
    half2 * kq2 = (half2 *) kq;
    half * buf_iw = kq + ne11;

    if (threadIdx.y == 0) {
        buf_iw[threadIdx.x] = 0.0f;
    }
    __syncthreads();

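    // Load this block's Q row into registers, converting fp32 to half2 pairs.
    // Each thread keeps D/(2*WARP_SIZE) half2 values of the query.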
    half2 qh2[(D/2 + WARP_SIZE - 1) / WARP_SIZE];
#pragma unroll
    for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
        const int i = i0 + threadIdx.x;
        if ((D/2) % WARP_SIZE != 0 && i >= (D/2)) {
            break;
        }

        const float2 qf2 = ((const float2 *) (q + nb02*blockIdx.y + nb01*blockIdx.x))[i];
        qh2[i0/WARP_SIZE] = make_half2(qf2.x, qf2.y);
    }

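    // KQ pass: each warp handles one K column per iteration. It initializes the
    // accumulator from the mask (if present), adds the Q.K dot product, applies the
    // scale, tracks a running maximum for the softmax, and lane 0 stores the score
    // to shared memory.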
    half kqmax = -INFINITY;
    for (int i0 = 0; i0 < ne11; i0 += nwarps) {
        const int i = i0 + threadIdx.y;
        if (need_check > 1 && i >= ne11) {
            break;
        }

        half2 sum2 = mask ? __half2half2(((half *) mask)[ne11*blockIdx.x + i]) : make_half2(0.0f, 0.0f);
#pragma unroll
        for (int j0 = 0; j0 < D/2; j0 += WARP_SIZE) {
            const int j = j0 + threadIdx.x;
            if ((D/2) % WARP_SIZE != 0 && j >= (D/2)) {
                break;
            }

            const half2 k2k = ((const half2 *) (k + nb12*blockIdx.y + nb11*i))[j];
            sum2 += k2k * qh2[j0/WARP_SIZE];
        }
        sum2 = warp_reduce_sum(sum2);
        const half sum = __float2half(scale) * (__low2half(sum2) + __high2half(sum2));
        kqmax = __hmax(kqmax, sum);
        if (threadIdx.x == 0) {
            kq[i] = sum;
        }
    }

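    // Reduce the per-warp KQ maxima to a single block-wide maximum via buf_iw.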
    kqmax = warp_reduce_max(kqmax);
    if (threadIdx.x == 0) {
        buf_iw[threadIdx.y] = kqmax;
    }
    __syncthreads();
    kqmax = buf_iw[threadIdx.x];
    __syncthreads();

    kqmax = warp_reduce_max(kqmax);
    const half2 kqmax2 = __half2half2(kqmax);

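    // Softmax numerator: subtract the maximum, exponentiate the KQ scores in place
    // (two at a time as half2), and accumulate the sum of exponentials.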
    half2 kqsum2 = make_half2(0.0f, 0.0f);
    for (int i0 = 0; i0 < ne11/2; i0 += WARP_SIZE*nwarps) {
        const int i = i0 + tid;
        if (need_check > 0 && i >= ne11/2) {
            break;
        }

        const half2 val = h2exp(kq2[i] - kqmax2);
        kqsum2 += val;
        kq2[i] = val;
    }
    kqsum2 = warp_reduce_sum(kqsum2);
    if (threadIdx.x == 0) {
        buf_iw[threadIdx.y] = __low2half(kqsum2) + __high2half(kqsum2);
    }
    __syncthreads();
    half kqsum = buf_iw[threadIdx.x];
    kqsum = warp_reduce_sum(kqsum);
    const half2 kqscale = make_half2(1.0f, 1.0f) / __half2half2(kqsum);

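    // Normalize in place: kq2 now holds the softmax weights.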
    for (int i0 = 0; i0 < ne11/2; i0 += WARP_SIZE*nwarps) {
        const int i = i0 + tid;
        if (need_check > 0 && i >= ne11/2) {
            break;
        }

        kq2[i] *= kqscale;
    }
    __syncthreads();

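    // Accumulate softmax(KQ) * V. Warps are grouped in sets of D/WARP_SIZE; each set
    // handles one half2 of weights (two V rows) per iteration, with the lane index and
    // the position within the set covering the D output elements.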
    half2 sum2 = make_half2(0.0f, 0.0f);
    for (int i0 = 0; i0 < ne11/2; i0 += nwarps / (D/WARP_SIZE)) {
        const int i = i0 + threadIdx.y / (D/WARP_SIZE);
        if (need_check > 2 && i >= ne11/2) {
            break;
        }

        half2 vi;
        vi.x = ((const half *) (v + nb22*blockIdx.y + (2*i + 0)*nb21))[WARP_SIZE*(threadIdx.y % (D/WARP_SIZE)) + threadIdx.x];
        vi.y = ((const half *) (v + nb22*blockIdx.y + (2*i + 1)*nb21))[WARP_SIZE*(threadIdx.y % (D/WARP_SIZE)) + threadIdx.x];
        sum2 += vi*kq2[i];
    }
    buf_iw[tid] = __low2half(sum2) + __high2half(sum2);
    __syncthreads();

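    // Final reduction: only the first D/WARP_SIZE warps survive; they combine the
    // per-group partial sums from buf_iw and write the fp32 result for this head.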
    if (threadIdx.y >= (D/WARP_SIZE)) {
        return;
    }

    float sum = 0.0f;

#pragma unroll
    for (int i = 0; i < nwarps; i += (D/WARP_SIZE)) {
        sum += __half2float(buf_iw[WARP_SIZE*i + tid]);
    }

    dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = sum;
}

#if __CUDA_ARCH__ >= CC_VOLTA
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b;

@@ -11927,6 +12078,83 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
    //const size_t shmem_max = 96*1024;
    //cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);

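    // For head size 128 with at most two query columns (the batch-size-1 case this
    // commit targets), dispatch to the fused vector kernel instead of the
    // flash_attn_ext_f16 path below.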
    if (Q->ne[0] == 128 && Q->ne[1] <= 2) {
        constexpr int nwarps = 32;
        const dim3 blocks_num(Q->ne[1], Q->ne[2], Q->ne[3]);
        const dim3 block_dim(WARP_SIZE, nwarps, 1);
        const size_t shmem = (K->ne[1] + WARP_SIZE*nwarps)*sizeof(half);

        GGML_ASSERT(K->ne[1] % 2 == 0);
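        // Pick the need_check specialization from the divisibility of K->ne[1]:
        // 0 skips all bounds checks, 1 checks only the softmax loops, 2 additionally
        // checks the KQ loop, and 3 also checks the V accumulation loop.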
        if ((K->ne[1]/2) % (WARP_SIZE*nwarps) == 0) {
            fused_attn_vec_ext_f16<128, nwarps, 0>
                <<<blocks_num, block_dim, shmem, main_stream>>> (
                    (const char *) src0_extra->data_device[g_main_device], // Query
                    (const char *) src1_extra->data_device[g_main_device], // Key
                    (const char *) src2_extra->data_device[g_main_device], // Value
                    mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
                    (float *) dst_extra->data_device[g_main_device], // dst
                    scale,
                    Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                    K->ne[0], K->ne[1], K->ne[2], K->ne[3],
                    mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
                    Q->nb[1], Q->nb[2], Q->nb[3],
                    K->nb[1], K->nb[2], K->nb[3],
                    KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
                );
        } else if (K->ne[1] % nwarps == 0 && (K->ne[1]/2) % (nwarps / (Q->ne[0] / WARP_SIZE)) == 0) {
            fused_attn_vec_ext_f16<128, nwarps, 1>
                <<<blocks_num, block_dim, shmem, main_stream>>> (
                    (const char *) src0_extra->data_device[g_main_device], // Query
                    (const char *) src1_extra->data_device[g_main_device], // Key
                    (const char *) src2_extra->data_device[g_main_device], // Value
                    mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
                    (float *) dst_extra->data_device[g_main_device], // dst
                    scale,
                    Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                    K->ne[0], K->ne[1], K->ne[2], K->ne[3],
                    mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
                    Q->nb[1], Q->nb[2], Q->nb[3],
                    K->nb[1], K->nb[2], K->nb[3],
                    KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
                );
        } else if ((K->ne[1]/2) % (nwarps / (Q->ne[0] / WARP_SIZE)) == 0) {
            fused_attn_vec_ext_f16<128, nwarps, 2>
                <<<blocks_num, block_dim, shmem, main_stream>>> (
                    (const char *) src0_extra->data_device[g_main_device], // Query
                    (const char *) src1_extra->data_device[g_main_device], // Key
                    (const char *) src2_extra->data_device[g_main_device], // Value
                    mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
                    (float *) dst_extra->data_device[g_main_device], // dst
                    scale,
                    Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                    K->ne[0], K->ne[1], K->ne[2], K->ne[3],
                    mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
                    Q->nb[1], Q->nb[2], Q->nb[3],
                    K->nb[1], K->nb[2], K->nb[3],
                    KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
                );
        } else {
            fused_attn_vec_ext_f16<128, nwarps, 3>
                <<<blocks_num, block_dim, shmem, main_stream>>> (
                    (const char *) src0_extra->data_device[g_main_device], // Query
                    (const char *) src1_extra->data_device[g_main_device], // Key
                    (const char *) src2_extra->data_device[g_main_device], // Value
                    mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
                    (float *) dst_extra->data_device[g_main_device], // dst
                    scale,
                    Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                    K->ne[0], K->ne[1], K->ne[2], K->ne[3],
                    mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
                    Q->nb[1], Q->nb[2], Q->nb[3],
                    K->nb[1], K->nb[2], K->nb[3],
                    KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
                );
        }
        CUDA_CHECK(cudaGetLastError());
        return;
    }

    switch (Q->ne[0]) {
        case 64:
            flash_attn_ext_f16<64, NQPB, NCPW>