fix compile warnings
parent 3f777acf06
commit e1ecd3b129
1 changed file with 23 additions and 25 deletions
@@ -16,18 +16,18 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}

static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
    }
    return x;
#else
    GGML_UNUSED(x);
    NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
}

// static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
// #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
// #pragma unroll
//    for (int mask = 16; mask > 0; mask >>= 1) {
//        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
//    }
//    return x;
// #else
//    GGML_UNUSED(x);
//    NO_DEVICE_CODE;
// #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
// }

#define FATTN_KQ_STRIDE 256
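Note on the reduction above: warp_reduce_max is a standard butterfly reduction over the 32 lanes of a warp; each __shfl_xor_sync step halves the lane distance, so after five steps every lane holds the element-wise maximum of the warp. The following standalone sketch illustrates the same pattern for plain float so it compiles on any architecture; the test kernel and launch shape are assumptions for illustration only, not part of this commit.

#include <cstdio>
#include <cuda_runtime.h>

// Illustrative butterfly max-reduction over one warp, written for float
// (the committed version uses __hmax2 on half2 behind the CC_PASCAL/CUDART_HMAX guards).
static __device__ __forceinline__ float warp_reduce_max_f32(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
    }
    return x; // every lane now holds the warp-wide maximum
}

// Hypothetical test kernel: one warp reduces its lane indices, lane 0 prints 31.
static __global__ void test_warp_reduce_max() {
    const float v = warp_reduce_max_f32((float) threadIdx.x);
    if (threadIdx.x == 0) {
        printf("warp max = %.0f\n", v);
    }
}

int main() {
    test_warp_reduce_max<<<1, 32>>>();
    cudaDeviceSynchronize();
    return 0;
}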
@@ -472,9 +472,7 @@ static __global__ void flash_attn_ext_f16(
                dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j;
            }
        }
        return;
    }

    } else {
#pragma unroll
        for (int i0 = 0; i0 < D; i0 += nwarps*WARP_SIZE) {
            const int i = i0 + threadIdx.y*WARP_SIZE + threadIdx.x;
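Note on the dst write in this hunk: the index D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i flattens (query column, head, element) into a 1D offset, with gridDim.y being the number of heads, ncols*blockIdx.x + j the global query column, and i running over the head size D; i.e. dst is laid out as [n_queries][n_heads][D]. A small host-side sketch of that layout, with hypothetical names and sizes chosen for illustration only:

// Hypothetical helper mirroring the dst offset used above:
// dst laid out as [n_queries][n_heads][D], flattened row-major.
static int dst_index(int D, int n_heads, int query, int head, int i) {
    return D*n_heads*query + D*head + i;
}
// Example: with D = 128 and n_heads = 32, element 5 of head 2 for query 7
// lands at 128*32*7 + 128*2 + 5 = 28933.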
@@ -489,6 +487,7 @@ static __global__ void flash_attn_ext_f16(
            __low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0]));
        }
    }
}

template<int D, int parallel_blocks> // D == head size
__launch_bounds__(D, 1)
@@ -781,7 +780,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
    } else {
        cols_per_block = 8;
    }
    const int frag_m = cols_per_block == 8 ? 32 : 16;
    constexpr int nwarps = 4;
    const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
    const dim3 block_dim(WARP_SIZE, nwarps, 1);
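Note on the launch shape in this last hunk: the grid covers Q->ne[1] query columns with cols_per_block columns per block via the usual round-up division (n + d - 1) / d, and launches one block per head (Q->ne[2]) and per sequence in the batch (Q->ne[3]); each block runs nwarps warps of WARP_SIZE threads. A minimal host-side sketch of the same arithmetic, using made-up tensor sizes rather than anything from this commit:

#include <cstdio>
#include <cuda_runtime.h>

int main() {
    const int ne1 = 35, ne2 = 32, ne3 = 1;   // assumed Q dims: columns, heads, batch
    const int cols_per_block = 16;
    const int nwarps    = 4;
    const int warp_size = 32;

    // 35 columns at 16 per block -> 3 blocks in x; one block per head and per sequence.
    const dim3 blocks_num((ne1 + cols_per_block - 1) / cols_per_block, ne2, ne3);
    const dim3 block_dim(warp_size, nwarps, 1);

    printf("grid = (%u, %u, %u), block = (%u, %u, %u)\n",
           blocks_num.x, blocks_num.y, blocks_num.z,
           block_dim.x,  block_dim.y,  block_dim.z);
    return 0;
}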