fix compile warnings

This commit is contained in:
Johannes Gäßler 2024-04-02 10:27:34 +02:00 committed by Georgi Gerganov
parent 3f777acf06
commit e1ecd3b129

View file

@ -16,18 +16,18 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
}
return x;
#else
GGML_UNUSED(x);
NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
}
// static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
// #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
// #pragma unroll
// for (int mask = 16; mask > 0; mask >>= 1) {
// x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
// }
// return x;
// #else
// GGML_UNUSED(x);
// NO_DEVICE_CODE;
// #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
// }
#define FATTN_KQ_STRIDE 256
@ -472,9 +472,7 @@ static __global__ void flash_attn_ext_f16(
dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j;
}
}
return;
}
} else {
#pragma unroll
for (int i0 = 0; i0 < D; i0 += nwarps*WARP_SIZE) {
const int i = i0 + threadIdx.y*WARP_SIZE + threadIdx.x;
@ -489,6 +487,7 @@ static __global__ void flash_attn_ext_f16(
__low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0]));
}
}
}
template<int D, int parallel_blocks> // D == head size
__launch_bounds__(D, 1)
@ -781,7 +780,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
} else {
cols_per_block = 8;
}
const int frag_m = cols_per_block == 8 ? 32 : 16;
constexpr int nwarps = 4;
const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
const dim3 block_dim(WARP_SIZE, nwarps, 1);