fused attention kernel for batch size 1
This commit is contained in:
parent
58c7f6167c
commit
82ae7f3357
1 changed files with 228 additions and 0 deletions
228
ggml-cuda.cu
228
ggml-cuda.cu
|
@@ -7227,6 +7227,157 @@ static __global__ void flash_attn_f32(
|
|||
}
|
||||
}
|
||||
|
||||
// Fused attention for one query row per thread block, evaluated in half precision:
//
//     dst = softmax(scale * (K * q) + mask) * V
//
// Launch layout (see the indexing below):
//   - blockDim = (WARP_SIZE, nwarps, 1)
//   - blockIdx.x selects the query row (stride nb01), blockIdx.y the head (stride nb02 / nb12),
//     blockIdx.z the outermost dimension (stride nb03 / nb13).
//   - dynamic shared memory: (ne11 + WARP_SIZE*nwarps) * sizeof(half)
//       kq     [ne11]             scaled + masked K*q scores, later the softmax weights
//       buf_iw [WARP_SIZE*nwarps] scratch space for the block-wide reductions
//
// Data as accessed by this kernel:
//   q    : rows of D floats (read as float2, converted to half2)
//   k, v : rows of D halves; V reuses the K strides (nb21 = nb11, nb22 = nb12, nb23 = nb13)
//   mask : optional (may be nullptr); read as half at element ne11*blockIdx.x + i
//   dst  : floats, D contiguous values per (blockIdx.z, blockIdx.x, blockIdx.y)
//
// need_check selects which loops need a tail bounds check:
//   > 0 : the softmax loops over ne11/2 half2 pairs (stride WARP_SIZE*nwarps)
//   > 1 : the K*q loop over ne11 rows               (stride nwarps)
//   > 2 : the V loop over ne11/2 row pairs          (stride nwarps / (D/WARP_SIZE))
//
// Preconditions: ne11 must be even (the softmax and V stages work on pairs of KV rows; a trailing
// odd row would be silently ignored), D must be a multiple of WARP_SIZE and nwarps a multiple of
// D/WARP_SIZE.
template<int D, int nwarps, int need_check> // D head size
__launch_bounds__(WARP_SIZE*nwarps, 1)
static __global__ void fused_attn_vec_ext_f16(
        const char * __restrict__ q,
        const char * __restrict__ k,
        const char * __restrict__ v,
        const char * __restrict__ mask,
        float * __restrict__ dst,
        const float scale,
        const int ne00,
        const int ne01,
        const int ne02,
        const int ne03,
        const int ne10,
        const int ne11,
        const int ne12,
        const int ne13,
        const int ne31,
        const int nb31,
        const int nb01,
        const int nb02,
        const int nb03,
        const int nb11,
        const int nb12,
        const int nb13,
        const int ne0,
        const int ne1,
        const int ne2,
        const int ne3) {
    static_assert(D % WARP_SIZE == 0, "D must be a multiple of WARP_SIZE");
    static_assert(nwarps % (D/WARP_SIZE) == 0, "nwarps must be a multiple of D/WARP_SIZE");

    // Number of warps needed to cover one V row of D halves (one element per thread).
    constexpr int warps_per_row = D/WARP_SIZE;

    // V is addressed with the same strides as K.
    const int & nb21 = nb11;
    const int & nb22 = nb12;
    const int & nb23 = nb13;

    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;

    // Base pointers for this block. The blockIdx.z terms are zero for gridDim.z == 1; without them
    // every z-block would recompute (and overwrite) the results of the first one.
    const char * q_row  = q + (int64_t) nb03*blockIdx.z + (int64_t) nb02*blockIdx.y + (int64_t) nb01*blockIdx.x;
    const char * k_head = k + (int64_t) nb13*blockIdx.z + (int64_t) nb12*blockIdx.y;
    const char * v_head = v + (int64_t) nb23*blockIdx.z + (int64_t) nb22*blockIdx.y;

    extern __shared__ char data_flash_attn_vec_ext_f16[];
    half  * kq     = (half  *) data_flash_attn_vec_ext_f16;
    half2 * kq2    = (half2 *) kq;
    half  * buf_iw = kq + ne11;

    // Zero the first WARP_SIZE scratch slots so that slots >= nwarps are neutral for the sum reduction.
    if (threadIdx.y == 0) {
        buf_iw[threadIdx.x] = 0.0f;
    }
    __syncthreads();

    // Every warp keeps its own copy of the query row in registers, converted to half2.
    half2 qh2[(D/2 + WARP_SIZE - 1) / WARP_SIZE];
#pragma unroll
    for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
        const int i = i0 + threadIdx.x;
        if ((D/2) % WARP_SIZE != 0 && i >= (D/2)) {
            break;
        }

        const float2 qf2 = ((const float2 *) q_row)[i];
        qh2[i0/WARP_SIZE] = make_half2(qf2.x, qf2.y);
    }

    // Stage 1: kq[i] = scale * dot(K[i], q) + mask[i], one KV row per warp; track the running maximum.
    half kqmax = -INFINITY;
    for (int i0 = 0; i0 < ne11; i0 += nwarps) {
        const int i = i0 + threadIdx.y;
        if (need_check > 1 && i >= ne11) {
            break; // warp-uniform: i only depends on threadIdx.y
        }

        half2 sum2 = make_half2(0.0f, 0.0f);
#pragma unroll
        for (int j0 = 0; j0 < D/2; j0 += WARP_SIZE) {
            const int j = j0 + threadIdx.x;
            if ((D/2) % WARP_SIZE != 0 && j >= (D/2)) {
                break;
            }

            const half2 k2k = ((const half2 *) (k_head + (int64_t) nb11*i))[j];
            sum2 += k2k * qh2[j0/WARP_SIZE];
        }
        sum2 = warp_reduce_sum(sum2);

        // The mask must be added exactly once, after the warp reduction and the scaling. Seeding every
        // lane's accumulator with it would apply it 2*WARP_SIZE times and scale it as well, which is
        // only harmless for mask values of 0 and -inf.
        const half mask_val = mask ? ((const half *) mask)[ne11*blockIdx.x + i] : __float2half(0.0f);
        const half sum = __float2half(scale) * (__low2half(sum2) + __high2half(sum2)) + mask_val;

        kqmax = __hmax(kqmax, sum);
        if (threadIdx.x == 0) {
            kq[i] = sum;
        }
    }

    // Block-wide maximum: per-warp maxima go through buf_iw, then every warp reduces all of them.
    kqmax = warp_reduce_max(kqmax);
    if (threadIdx.x == 0) {
        buf_iw[threadIdx.y] = kqmax;
    }
    __syncthreads(); // also publishes kq[] to all threads
    // Slots >= nwarps were never written with a maximum; they must not take part in the max.
    kqmax = threadIdx.x < nwarps ? buf_iw[threadIdx.x] : __float2half(-INFINITY);
    __syncthreads(); // buf_iw is reused below
    kqmax = warp_reduce_max(kqmax);
    const half2 kqmax2 = __half2half2(kqmax);

    // Stage 2: kq[i] = exp(kq[i] - max), processed as half2 pairs, accumulating the denominator.
    half2 kqsum2 = make_half2(0.0f, 0.0f);
    for (int i0 = 0; i0 < ne11/2; i0 += WARP_SIZE*nwarps) {
        const int i = i0 + tid;
        if (need_check > 0 && i >= ne11/2) {
            break;
        }

        const half2 val = h2exp(kq2[i] - kqmax2);
        kqsum2 += val;
        kq2[i] = val;
    }

    // Block-wide sum of the exponentials, same two-level scheme as for the maximum.
    kqsum2 = warp_reduce_sum(kqsum2);
    if (threadIdx.x == 0) {
        buf_iw[threadIdx.y] = __low2half(kqsum2) + __high2half(kqsum2);
    }
    __syncthreads();
    half kqsum = buf_iw[threadIdx.x];
    kqsum = warp_reduce_sum(kqsum);
    const half2 kqscale = make_half2(1.0f, 1.0f) / __half2half2(kqsum);

    // Normalize. Each thread revisits exactly the pairs it wrote in stage 2, so no barrier is needed before.
    for (int i0 = 0; i0 < ne11/2; i0 += WARP_SIZE*nwarps) {
        const int i = i0 + tid;
        if (need_check > 0 && i >= ne11/2) {
            break;
        }

        kq2[i] *= kqscale;
    }
    __syncthreads(); // softmax weights complete; also orders the buf_iw reads above before the writes below

    // Stage 3: weighted sum of V rows. A group of warps_per_row consecutive warps covers the D elements
    // of a V row (one element per thread); each group handles two KV rows per iteration.
    const int v_col = WARP_SIZE*(threadIdx.y % warps_per_row) + threadIdx.x;
    half2 sum2 = make_half2(0.0f, 0.0f);
    for (int i0 = 0; i0 < ne11/2; i0 += nwarps / warps_per_row) {
        const int i = i0 + threadIdx.y / warps_per_row;
        if (need_check > 2 && i >= ne11/2) {
            break;
        }

        half2 vi;
        vi.x = ((const half *) (v_head + (int64_t) (2*i + 0)*nb21))[v_col];
        vi.y = ((const half *) (v_head + (int64_t) (2*i + 1)*nb21))[v_col];
        sum2 += vi*kq2[i];
    }
    buf_iw[tid] = __low2half(sum2) + __high2half(sum2);
    __syncthreads();

    // Only the first warps_per_row warps (tid < D) combine the partial sums and write the result.
    if (threadIdx.y >= warps_per_row) {
        return;
    }

    float sum = 0.0f;

    // Warp (i + threadIdx.y) holds a partial sum for the same output element as this thread.
#pragma unroll
    for (int i = 0; i < nwarps; i += warps_per_row) {
        sum += __half2float(buf_iw[WARP_SIZE*i + tid]);
    }

    dst[(int64_t) D*gridDim.y*(gridDim.x*blockIdx.z + blockIdx.x) + D*blockIdx.y + tid] = sum;
}
|
||||
|
||||
#if __CUDA_ARCH__ >= CC_VOLTA
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b;
|
||||
|
@@ -11927,6 +12078,83 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
|
|||
//const size_t shmem_max = 96*1024;
|
||||
//cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
|
||||
|
||||
if (Q->ne[0] == 128 && Q->ne[1] <= 2) {
|
||||
constexpr int nwarps = 32;
|
||||
const dim3 blocks_num(Q->ne[1], Q->ne[2], Q->ne[3]);
|
||||
const dim3 block_dim(WARP_SIZE, nwarps, 1);
|
||||
const size_t shmem = (K->ne[1] + WARP_SIZE*nwarps)*sizeof(half);
|
||||
|
||||
GGML_ASSERT(K->ne[1] % 2 == 0);
|
||||
if ((K->ne[1]/2) % (WARP_SIZE*nwarps) == 0) {
|
||||
fused_attn_vec_ext_f16<128, nwarps, 0>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||
(const char *) src2_extra->data_device[g_main_device], // Value
|
||||
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
|
||||
(float *) dst_extra->data_device[g_main_device], // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
} else if (K->ne[1] % nwarps == 0 && (K->ne[1]/2) % (nwarps / (Q->ne[0] / WARP_SIZE)) == 0) {
|
||||
fused_attn_vec_ext_f16<128, nwarps, 1>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||
(const char *) src2_extra->data_device[g_main_device], // Value
|
||||
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
|
||||
(float *) dst_extra->data_device[g_main_device], // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
} else if ((K->ne[1]/2) % (nwarps / (Q->ne[0] / WARP_SIZE)) == 0) {
|
||||
fused_attn_vec_ext_f16<128, nwarps, 2>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||
(const char *) src2_extra->data_device[g_main_device], // Value
|
||||
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
|
||||
(float *) dst_extra->data_device[g_main_device], // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
} else {
|
||||
fused_attn_vec_ext_f16<128, nwarps, 3>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||
(const char *) src2_extra->data_device[g_main_device], // Value
|
||||
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
|
||||
(float *) dst_extra->data_device[g_main_device], // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
|
||||
}
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
return;
|
||||
}
|
||||
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
flash_attn_ext_f16<64, NQPB, NCPW>
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue