fix compile warnings
This commit is contained in:
parent
3f777acf06
commit
e1ecd3b129
1 changed file with 23 additions and 25 deletions
|
@@ -16,18 +16,18 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
||||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
|
// static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
|
||||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
// #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
||||||
#pragma unroll
|
// #pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
// for (int mask = 16; mask > 0; mask >>= 1) {
|
||||||
x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
|
// x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
|
||||||
}
|
// }
|
||||||
return x;
|
// return x;
|
||||||
#else
|
// #else
|
||||||
GGML_UNUSED(x);
|
// GGML_UNUSED(x);
|
||||||
NO_DEVICE_CODE;
|
// NO_DEVICE_CODE;
|
||||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
// #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
||||||
}
|
// }
|
||||||
|
|
||||||
#define FATTN_KQ_STRIDE 256
|
#define FATTN_KQ_STRIDE 256
|
||||||
|
|
||||||
|
@@ -472,9 +472,7 @@ static __global__ void flash_attn_ext_f16(
|
||||||
dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j;
|
dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return;
|
} else {
|
||||||
}
|
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i0 = 0; i0 < D; i0 += nwarps*WARP_SIZE) {
|
for (int i0 = 0; i0 < D; i0 += nwarps*WARP_SIZE) {
|
||||||
const int i = i0 + threadIdx.y*WARP_SIZE + threadIdx.x;
|
const int i = i0 + threadIdx.y*WARP_SIZE + threadIdx.x;
|
||||||
|
@@ -488,6 +486,7 @@ static __global__ void flash_attn_ext_f16(
|
||||||
dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(
|
dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(
|
||||||
__low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0]));
|
__low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0]));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<int D, int parallel_blocks> // D == head size
|
template<int D, int parallel_blocks> // D == head size
|
||||||
|
@@ -781,7 +780,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||||
} else {
|
} else {
|
||||||
cols_per_block = 8;
|
cols_per_block = 8;
|
||||||
}
|
}
|
||||||
const int frag_m = cols_per_block == 8 ? 32 : 16;
|
|
||||||
constexpr int nwarps = 4;
|
constexpr int nwarps = 4;
|
||||||
const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
|
const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
|
||||||
const dim3 block_dim(WARP_SIZE, nwarps, 1);
|
const dim3 block_dim(WARP_SIZE, nwarps, 1);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue