CUDA: use tensor cores for MMQ (#7676)
* CUDA: int8 tensor cores for MMQ (legacy quants) * fix out-of-bounds writes * __builtin_assume -> GGML_CUDA_ASSUME * fix writeback returning too early
This commit is contained in:
parent
af4ae502dd
commit
1f0dabda8d
7 changed files with 550 additions and 55 deletions
|
@ -139,6 +139,7 @@
|
|||
#define CC_PASCAL 600
|
||||
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
|
||||
#define CC_VOLTA 700
|
||||
#define CC_TURING 750
|
||||
#define CC_AMPERE 800
|
||||
#define CC_OFFSET_AMD 1000000
|
||||
#define CC_RDNA1 (CC_OFFSET_AMD + 1010)
|
||||
|
@ -326,9 +327,17 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int
|
|||
#endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
|
||||
#endif // defined(GGML_USE_HIPBLAS)
|
||||
|
||||
#define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
|
||||
#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
|
||||
#define FP16_AVAILABLE
|
||||
#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
|
||||
|
||||
#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
|
||||
#define FP16_MMA_AVAILABLE
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
|
||||
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
|
||||
#define INT8_MMA_AVAILABLE
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
|
||||
|
||||
static bool fast_fp16_available(const int cc) {
|
||||
return cc >= CC_PASCAL && cc != 610;
|
||||
|
@ -338,6 +347,10 @@ static bool fp16_mma_available(const int cc) {
|
|||
return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
|
||||
}
|
||||
|
||||
static bool int8_mma_available(const int cc) {
|
||||
return cc < CC_OFFSET_AMD && cc >= CC_TURING;
|
||||
}
|
||||
|
||||
[[noreturn]]
|
||||
static __device__ void no_device_code(
|
||||
const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
|
||||
|
@ -379,7 +392,7 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
|
|||
}
|
||||
|
||||
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
|
||||
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||
#pragma unroll
|
||||
|
@ -412,7 +425,7 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
|
|||
}
|
||||
|
||||
static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
|
||||
return __float2half(fmaxf(__half2float(a), __half2float(b)));
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue