From 82ae7f33575977bee4c3f45bceb0bd7580724c7e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Tue, 19 Mar 2024 21:04:28 +0100
Subject: [PATCH] fused attention kernel for batch size 1

---
 ggml-cuda.cu | 228 +++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 228 insertions(+)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 93fb7e80d..6d0a6bdad 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -7227,6 +7227,157 @@ static __global__ void flash_attn_f32(
     }
 }
 
+template <int D, int nwarps, int need_check> // D head size
+__launch_bounds__(WARP_SIZE*nwarps, 1)
+static __global__ void fused_attn_vec_ext_f16(
+        const char * __restrict__ q,
+        const char * __restrict__ k,
+        const char * __restrict__ v,
+        const char * __restrict__ mask,
+        float * __restrict__ dst,
+        const float scale,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3) {
+    const int & nb21 = nb11;
+    const int & nb22 = nb12;
+    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+
+    extern __shared__ char data_flash_attn_vec_ext_f16[];
+    half  * kq     = (half  *) data_flash_attn_vec_ext_f16;
+    half2 * kq2    = (half2 *) kq;
+    half  * buf_iw = kq + ne11;
+
+    if (threadIdx.y == 0) {
+        buf_iw[threadIdx.x] = 0.0f;
+    }
+    __syncthreads();
+
+    half2 qh2[(D/2 + WARP_SIZE - 1) / WARP_SIZE];
+#pragma unroll
+    for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+        const int i = i0 + threadIdx.x;
+        if ((D/2) % WARP_SIZE != 0 && i >= (D/2)) {
+            break;
+        }
+
+        const float2 qf2 = ((const float2 *) (q + nb02*blockIdx.y + nb01*blockIdx.x))[i];
+        qh2[i0/WARP_SIZE] = make_half2(qf2.x, qf2.y);
+    }
+
+    half kqmax = -INFINITY;
+    for (int i0 = 0; i0 < ne11; i0 += nwarps) {
+        const int i = i0 + threadIdx.y;
+        if (need_check > 1 && i >= ne11) {
+            break;
+        }
+
+        half2 sum2 = mask ? __half2half2(((half *) mask)[ne11*blockIdx.x + i]) : make_half2(0.0f, 0.0f);
+#pragma unroll
+        for (int j0 = 0; j0 < D/2; j0 += WARP_SIZE) {
+            const int j = j0 + threadIdx.x;
+            if ((D/2) % WARP_SIZE != 0 && j >= (D/2)) {
+                break;
+            }
+
+            const half2 k2k = ((const half2 *) (k + nb12*blockIdx.y + nb11*i))[j];
+            sum2 += k2k * qh2[j0/WARP_SIZE];
+        }
+        sum2 = warp_reduce_sum(sum2);
+        const half sum = __float2half(scale) * (__low2half(sum2) + __high2half(sum2));
+        kqmax = __hmax(kqmax, sum);
+        if (threadIdx.x == 0) {
+            kq[i] = sum;
+        }
+    }
+
+    kqmax = warp_reduce_max(kqmax);
+    if (threadIdx.x == 0) {
+        buf_iw[threadIdx.y] = kqmax;
+    }
+    __syncthreads();
+    kqmax = buf_iw[threadIdx.x];
+    __syncthreads();
+
+    kqmax = warp_reduce_max(kqmax);
+    const half2 kqmax2 = __half2half2(kqmax);
+
+    half2 kqsum2 = make_half2(0.0f, 0.0f);
+    for (int i0 = 0; i0 < ne11/2; i0 += WARP_SIZE*nwarps) {
+        const int i = i0 + tid;
+        if (need_check > 0 && i >= ne11/2) {
+            break;
+        }
+
+        const half2 val = h2exp(kq2[i] - kqmax2);
+        kqsum2 += val;
+        kq2[i] = val;
+    }
+    kqsum2 = warp_reduce_sum(kqsum2);
+    if (threadIdx.x == 0) {
+        buf_iw[threadIdx.y] = __low2half(kqsum2) + __high2half(kqsum2);
+    }
+    __syncthreads();
+    half kqsum = buf_iw[threadIdx.x];
+    kqsum = warp_reduce_sum(kqsum);
+    const half2 kqscale = make_half2(1.0f, 1.0f) / __half2half2(kqsum);
+
+    for (int i0 = 0; i0 < ne11/2; i0 += WARP_SIZE*nwarps) {
+        const int i = i0 + tid;
+        if (need_check > 0 && i >= ne11/2) {
+            break;
+        }
+
+        kq2[i] *= kqscale;
+    }
+    __syncthreads();
+
+    half2 sum2 = make_half2(0.0f, 0.0f);
+    for (int i0 = 0; i0 < ne11/2; i0 += nwarps / (D/WARP_SIZE)) {
+        const int i = i0 + threadIdx.y / (D/WARP_SIZE);
+        if (need_check > 2 && i >= ne11/2) {
+            break;
+        }
+
+        half2 vi;
+        vi.x = ((const half *) (v + nb22*blockIdx.y + (2*i + 0)*nb21))[WARP_SIZE*(threadIdx.y % (D/WARP_SIZE)) + threadIdx.x];
+        vi.y = ((const half *) (v + nb22*blockIdx.y + (2*i + 1)*nb21))[WARP_SIZE*(threadIdx.y % (D/WARP_SIZE)) + threadIdx.x];
+        sum2 += vi*kq2[i];
+    }
+    buf_iw[tid] = __low2half(sum2) + __high2half(sum2);
+    __syncthreads();
+
+    if (threadIdx.y >= (D/WARP_SIZE)) {
+        return;
+    }
+
+    float sum = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < nwarps; i += (D/WARP_SIZE)) {
+        sum += __half2float(buf_iw[WARP_SIZE*i + tid]);
+    }
+
+    dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = sum;
+}
+
 #if __CUDA_ARCH__ >= CC_VOLTA
 typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
 typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b;
@@ -11927,6 +12078,83 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     //const size_t shmem_max = 96*1024;
     //cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
 
+    if (Q->ne[0] == 128 && Q->ne[1] <= 2) {
+        constexpr int nwarps = 32;
+        const dim3 blocks_num(Q->ne[1], Q->ne[2], Q->ne[3]);
+        const dim3 block_dim(WARP_SIZE, nwarps, 1);
+        const size_t shmem = (K->ne[1] + WARP_SIZE*nwarps)*sizeof(half);
+
+        GGML_ASSERT(K->ne[1] % 2 == 0);
+        if ((K->ne[1]/2) % (WARP_SIZE*nwarps) == 0) {
+            fused_attn_vec_ext_f16<128, nwarps, 0>
+                <<<blocks_num, block_dim, shmem, main_stream>>> (
+                    (const char *) src0_extra->data_device[g_main_device], // Query
+                    (const char *) src1_extra->data_device[g_main_device], // Key
+                    (const char *) src2_extra->data_device[g_main_device], // Value
+                    mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
+                    (float *) dst_extra->data_device[g_main_device], // dst
+                    scale,
+                    Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                    K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                    mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+                    Q->nb[1], Q->nb[2], Q->nb[3],
+                    K->nb[1], K->nb[2], K->nb[3],
+                    KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                );
+        } else if (K->ne[1] % nwarps == 0 && (K->ne[1]/2) % (nwarps / (Q->ne[0] / WARP_SIZE)) == 0) {
+            fused_attn_vec_ext_f16<128, nwarps, 1>
+                <<<blocks_num, block_dim, shmem, main_stream>>> (
+                    (const char *) src0_extra->data_device[g_main_device], // Query
+                    (const char *) src1_extra->data_device[g_main_device], // Key
+                    (const char *) src2_extra->data_device[g_main_device], // Value
+                    mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
+                    (float *) dst_extra->data_device[g_main_device], // dst
+                    scale,
+                    Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                    K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                    mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+                    Q->nb[1], Q->nb[2], Q->nb[3],
+                    K->nb[1], K->nb[2], K->nb[3],
+                    KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                );
+        } else if ((K->ne[1]/2) % (nwarps / (Q->ne[0] / WARP_SIZE)) == 0) {
+            fused_attn_vec_ext_f16<128, nwarps, 2>
+                <<<blocks_num, block_dim, shmem, main_stream>>> (
+                    (const char *) src0_extra->data_device[g_main_device], // Query
+                    (const char *) src1_extra->data_device[g_main_device], // Key
+                    (const char *) src2_extra->data_device[g_main_device], // Value
+                    mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
+                    (float *) dst_extra->data_device[g_main_device], // dst
+                    scale,
+                    Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                    K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                    mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+                    Q->nb[1], Q->nb[2], Q->nb[3],
+                    K->nb[1], K->nb[2], K->nb[3],
+                    KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                );
+        } else {
+            fused_attn_vec_ext_f16<128, nwarps, 3>
+                <<<blocks_num, block_dim, shmem, main_stream>>> (
+                    (const char *) src0_extra->data_device[g_main_device], // Query
+                    (const char *) src1_extra->data_device[g_main_device], // Key
+                    (const char *) src2_extra->data_device[g_main_device], // Value
+                    mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
+                    (float *) dst_extra->data_device[g_main_device], // dst
+                    scale,
+                    Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                    K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                    mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+                    Q->nb[1], Q->nb[2], Q->nb[3],
+                    K->nb[1], K->nb[2], K->nb[3],
+                    KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                );
+
+        }
+        CUDA_CHECK(cudaGetLastError());
+        return;
+    }
+
     switch (Q->ne[0]) {
         case 64:
             flash_attn_ext_f16<64, NQPB, NCPW>