fix compiler error

This commit is contained in:
FSSRepo 2024-01-29 13:15:33 -05:00
parent 2455a8d6c3
commit a1d5a12bc5

View file

@ -6134,6 +6134,7 @@ static __global__ void flash_attn_f32(
} }
} }
#if __CUDA_ARCH__ >= CC_VOLTA
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_a; typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_a;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_b; typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_b;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> half16x16_acc; typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> half16x16_acc;
@ -6185,13 +6186,13 @@ static __global__ void flash_attn_ext_f16(
const half scale_h = __float2half(scale); const half scale_h = __float2half(scale);
extern __shared__ char data_flash_attn_shmem[]; extern __shared__ half __flash_attn_f16_shmem[];
// pq // pq
half * pq = (half *) (data_flash_attn_shmem + 0*D); half * pq = (half *) (__flash_attn_f16_shmem + 0*D);
half2 * pq2 = (half2 *) (data_flash_attn_shmem + 0*D); half2 * pq2 = (half2 *) (__flash_attn_f16_shmem + 0*D);
half * ps = (half *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D); half * ps = (half *) (__flash_attn_f16_shmem + warp_id*(D + 1*C) + 1*D);
half2 * ps2 = (half2 *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D); half2 * ps2 = (half2 *) (__flash_attn_f16_shmem + warp_id*(D + 1*C) + 1*D);
half * ss = (half *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 2*D); half * ss = (half *) (__flash_attn_f16_shmem + warp_id*(D + 1*C) + 2*D);
for (int i = 0; i < L2; ++i) { for (int i = 0; i < L2; ++i) {
// load heads from Q to shared memory // load heads from Q to shared memory
@ -6217,7 +6218,7 @@ static __global__ void flash_attn_ext_f16(
} }
__syncthreads(); __syncthreads();
#if 0
{ {
half S[Q] = { 0.0 }; // could be half2 S[Q/2] = how fill this array with zeros?? half S[Q] = { 0.0 }; // could be half2 S[Q/2] = how fill this array with zeros??
half M[Q] = { -INFINITY }; // could be half2 M[Q/2] = better register utilization half M[Q] = { -INFINITY }; // could be half2 M[Q/2] = better register utilization
@ -6400,7 +6401,40 @@ static __global__ void flash_attn_ext_f16(
} }
} }
} }
#endif
} }
#else
// Fallback definition of flash_attn_ext_f16, located in the #else branch of the
// enclosing preprocessor conditional. It does no attention work: the body
// consists solely of a call to bad_arch().
// NOTE(review): presumably this parameter list has to stay in sync with the
// primary definition in the #if branch so launch sites compile for either
// branch -- confirm against that definition when changing the signature.
//
// Template parameters:
//   D - head size
//   Q - queries per block
//   C - cache items per block
template<int D, int Q, int C>
static __global__ void flash_attn_ext_f16(
        const char* __restrict__ q,
        const char* __restrict__ k,
        const char* __restrict__ v,
        const char* __restrict__ mask,
        float* __restrict__ kqv,
        float scale,
        int ne00, int ne01, int ne02, int ne03,
        int ne10, int ne11, int ne12, int ne13,
        int ne31, int nb31,
        int nb01, int nb02, int nb03,
        int nb11, int nb12, int nb13,
        int ne0, int ne1, int ne2, int ne3) {
    // None of the arguments are read in this variant.
    bad_arch();
}
#endif
template<int qk, int qr, dequantize_kernel_t dq> template<int qk, int qr, dequantize_kernel_t dq>
static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,