fix compiler error
This commit is contained in:
parent
2455a8d6c3
commit
a1d5a12bc5
1 changed files with 41 additions and 7 deletions
48
ggml-cuda.cu
48
ggml-cuda.cu
|
@ -6134,6 +6134,7 @@ static __global__ void flash_attn_f32(
|
|||
}
|
||||
}
|
||||
|
||||
#if __CUDA_ARCH__ >= CC_VOLTA
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_a;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_b;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> half16x16_acc;
|
||||
|
@ -6185,13 +6186,13 @@ static __global__ void flash_attn_ext_f16(
|
|||
|
||||
const half scale_h = __float2half(scale);
|
||||
|
||||
extern __shared__ char data_flash_attn_shmem[];
|
||||
extern __shared__ half __flash_attn_f16_shmem[];
|
||||
// pq
|
||||
half * pq = (half *) (data_flash_attn_shmem + 0*D);
|
||||
half2 * pq2 = (half2 *) (data_flash_attn_shmem + 0*D);
|
||||
half * ps = (half *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D);
|
||||
half2 * ps2 = (half2 *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D);
|
||||
half * ss = (half *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 2*D);
|
||||
half * pq = (half *) (__flash_attn_f16_shmem + 0*D);
|
||||
half2 * pq2 = (half2 *) (__flash_attn_f16_shmem + 0*D);
|
||||
half * ps = (half *) (__flash_attn_f16_shmem + warp_id*(D + 1*C) + 1*D);
|
||||
half2 * ps2 = (half2 *) (__flash_attn_f16_shmem + warp_id*(D + 1*C) + 1*D);
|
||||
half * ss = (half *) (__flash_attn_f16_shmem + warp_id*(D + 1*C) + 2*D);
|
||||
|
||||
for (int i = 0; i < L2; ++i) {
|
||||
// load heads from Q to shared memory
|
||||
|
@ -6217,7 +6218,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#if 0
|
||||
{
|
||||
half S[Q] = { 0.0 }; // could be half2 S[Q/2] = how fill this array with zeros??
|
||||
half M[Q] = { -INFINITY }; // could be half2 M[Q/2] = better register utilization
|
||||
|
@ -6400,7 +6401,40 @@ static __global__ void flash_attn_ext_f16(
|
|||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#else
|
||||
template<int D, int Q, int C> // D head size, Q queries per block, C cache items per blocks
|
||||
static __global__ void flash_attn_ext_f16(
|
||||
const char* __restrict__ q,
|
||||
const char* __restrict__ k,
|
||||
const char* __restrict__ v,
|
||||
const char* __restrict__ mask,
|
||||
float* __restrict__ kqv,
|
||||
float scale,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
int ne10,
|
||||
int ne11,
|
||||
int ne12,
|
||||
int ne13,
|
||||
int ne31,
|
||||
int nb31,
|
||||
int nb01,
|
||||
int nb02,
|
||||
int nb03,
|
||||
int nb11,
|
||||
int nb12,
|
||||
int nb13,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3) {
|
||||
bad_arch();
|
||||
}
|
||||
#endif
|
||||
|
||||
template<int qk, int qr, dequantize_kernel_t dq>
|
||||
static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue