fix compile warnings

This commit is contained in:
Johannes Gäßler 2024-04-02 10:27:34 +02:00 committed by Georgi Gerganov
parent 3f777acf06
commit e1ecd3b129

View file

@ -16,18 +16,18 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
} }
// Warp-level maximum of a half2 value.
//
// When the preprocessor guard below holds (not HIP on AMD, __CUDA_ARCH__ >= CC_PASCAL and
// CUDART_VERSION >= CUDART_HMAX), the loop exchanges x between lanes with __shfl_xor_sync
// at lane distances 16, 8, 4, 2, 1 (participation mask 0xffffffff, width 32) and folds each
// received value into x with __hmax2, i.e. the two packed halves are reduced independently.
// The reduced x is returned; presumably every lane ends up holding the same result, given
// that all 32 lanes are named in the mask.
//
// When the guard does not hold, the parameter is only marked as unused and NO_DEVICE_CODE
// is emitted instead of a reduction.
//
// NOTE(review): the trailing `// ...` text on every line repeats the statement verbatim. It
// looks like the second column of a side-by-side diff in which this function is commented
// out -- confirm whether the function is still meant to be active.
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { // static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX // #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
#pragma unroll // #pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) { // for (int mask = 16; mask > 0; mask >>= 1) {
x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); // x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
} // }
return x; // return x;
#else // #else
GGML_UNUSED(x); // GGML_UNUSED(x);
NO_DEVICE_CODE; // NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX // #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
} // }
#define FATTN_KQ_STRIDE 256 #define FATTN_KQ_STRIDE 256
@ -472,21 +472,20 @@ static __global__ void flash_attn_ext_f16(
dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j; dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j;
} }
} }
return; } else {
}
#pragma unroll #pragma unroll
for (int i0 = 0; i0 < D; i0 += nwarps*WARP_SIZE) { for (int i0 = 0; i0 < D; i0 += nwarps*WARP_SIZE) {
const int i = i0 + threadIdx.y*WARP_SIZE + threadIdx.x; const int i = i0 + threadIdx.y*WARP_SIZE + threadIdx.x;
if (i0 + nwarps*WARP_SIZE > D && i >= D) { if (i0 + nwarps*WARP_SIZE > D && i >= D) {
return; return;
}
dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + i] = VKQ[i];
} }
dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + i] = VKQ[i];
}
if (threadIdx.y == 0 && threadIdx.x == 0) { if (threadIdx.y == 0 && threadIdx.x == 0) {
dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2( dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(
__low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0])); __low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0]));
}
} }
} }
@ -781,7 +780,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
} else { } else {
cols_per_block = 8; cols_per_block = 8;
} }
const int frag_m = cols_per_block == 8 ? 32 : 16;
constexpr int nwarps = 4; constexpr int nwarps = 4;
const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
const dim3 block_dim(WARP_SIZE, nwarps, 1); const dim3 block_dim(WARP_SIZE, nwarps, 1);