From a1d5a12bc5ab9cde4d3db304d3882b99cca5e849 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Mon, 29 Jan 2024 13:15:33 -0500 Subject: [PATCH] fix compiler error --- ggml-cuda.cu | 48 +++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 41 insertions(+), 7 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index ecfa98c4e..8fa21c97e 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6134,6 +6134,7 @@ static __global__ void flash_attn_f32( } } +#if __CUDA_ARCH__ >= CC_VOLTA typedef nvcuda::wmma::fragment half16x16_a; typedef nvcuda::wmma::fragment half16x16_b; typedef nvcuda::wmma::fragment half16x16_acc; @@ -6185,13 +6186,13 @@ static __global__ void flash_attn_ext_f16( const half scale_h = __float2half(scale); - extern __shared__ char data_flash_attn_shmem[]; + extern __shared__ half __flash_attn_f16_shmem[]; // pq - half * pq = (half *) (data_flash_attn_shmem + 0*D); - half2 * pq2 = (half2 *) (data_flash_attn_shmem + 0*D); - half * ps = (half *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D); - half2 * ps2 = (half2 *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D); - half * ss = (half *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 2*D); + half * pq = (half *) (__flash_attn_f16_shmem + 0*D); + half2 * pq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); + half * ps = (half *) (__flash_attn_f16_shmem + warp_id*(D + 1*C) + 1*D); + half2 * ps2 = (half2 *) (__flash_attn_f16_shmem + warp_id*(D + 1*C) + 1*D); + half * ss = (half *) (__flash_attn_f16_shmem + warp_id*(D + 1*C) + 2*D); for (int i = 0; i < L2; ++i) { // load heads from Q to shared memory @@ -6217,7 +6218,7 @@ static __global__ void flash_attn_ext_f16( } __syncthreads(); - +#if 0 { half S[Q] = { 0.0 }; // could be half2 S[Q/2] = how fill this array with zeros?? half M[Q] = { -INFINITY }; // could be half2 M[Q/2] = better register utilization @@ -6400,7 +6401,40 @@ static __global__ void flash_attn_ext_f16( } } } +#endif } +#else +template // D head size, Q queries per block, C cache items per blocks +static __global__ void flash_attn_ext_f16( + const char* __restrict__ q, + const char* __restrict__ k, + const char* __restrict__ v, + const char* __restrict__ mask, + float* __restrict__ kqv, + float scale, + int ne00, + int ne01, + int ne02, + int ne03, + int ne10, + int ne11, + int ne12, + int ne13, + int ne31, + int nb31, + int nb01, + int nb02, + int nb03, + int nb11, + int nb12, + int nb13, + int ne0, + int ne1, + int ne2, + int ne3) { + bad_arch(); + } +#endif template static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,