From 59937e45a3f2445b4d990bd26a5b31ad70999a20 Mon Sep 17 00:00:00 2001 From: slaren Date: Sat, 30 Sep 2023 14:28:27 +0200 Subject: [PATCH] rename CC_TURING to CC_VOLTA --- ggml-cuda.cu | 89 ++++++++++++++++++++++++++-------------------------- 1 file changed, 45 insertions(+), 44 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 153bd1fb9..2fd15957f 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -80,9 +80,9 @@ #include "ggml.h" #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products -#define CC_TURING 700 +#define CC_VOLTA 700 #define CC_OFFSET_AMD 1000000 -#define CC_RDNA2 CC_OFFSET_AMD + 1030 +#define CC_RDNA2 (CC_OFFSET_AMD + 1030) #if defined(GGML_USE_HIPBLAS) #define __CUDA_ARCH__ 1300 @@ -3553,7 +3553,7 @@ template static __global__ void load_tiles_q4_0, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#elif __CUDA_ARCH__ >= CC_TURING +#elif __CUDA_ARCH__ >= CC_VOLTA const int mmq_x = MMQ_X_Q4_0_AMPERE; const int mmq_y = MMQ_Y_Q4_0_AMPERE; const int nwarps = NWARPS_Q4_0_AMPERE; @@ -3573,7 +3573,7 @@ template static __global__ void #else (void) vec_dot_q4_0_q8_1_mul_mat; assert(false); -#endif // __CUDA_ARCH__ >= CC_TURING +#endif // __CUDA_ARCH__ >= CC_VOLTA } #define MMQ_X_Q4_1_RDNA2 64 @@ -3594,9 +3594,9 @@ template static __global__ void #if defined(RDNA3) || defined(RDNA2) __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_RDNA2, 2) #endif // defined(RDNA3) || defined(RDNA2) -#elif __CUDA_ARCH__ < CC_TURING +#elif __CUDA_ARCH__ < CC_VOLTA __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_PASCAL, 2) -#endif // __CUDA_ARCH__ < CC_TURING +#endif // __CUDA_ARCH__ < CC_VOLTA mul_mat_q4_1( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { @@ -3616,7 +3616,7 @@ template static __global__ void load_tiles_q4_1, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#elif __CUDA_ARCH__ >= CC_TURING +#elif __CUDA_ARCH__ >= CC_VOLTA const int mmq_x = MMQ_X_Q4_1_AMPERE; const int mmq_y = MMQ_Y_Q4_1_AMPERE; const int nwarps = NWARPS_Q4_1_AMPERE; @@ -3636,7 +3636,7 @@ template static __global__ void #else (void) vec_dot_q4_1_q8_1_mul_mat; assert(false); -#endif // __CUDA_ARCH__ >= CC_TURING +#endif // __CUDA_ARCH__ >= CC_VOLTA } #define MMQ_X_Q5_0_RDNA2 64 @@ -3677,7 +3677,7 @@ template static __global__ void load_tiles_q5_0, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#elif __CUDA_ARCH__ >= CC_TURING +#elif __CUDA_ARCH__ >= CC_VOLTA const int mmq_x = MMQ_X_Q5_0_AMPERE; const int mmq_y = MMQ_Y_Q5_0_AMPERE; const int nwarps = NWARPS_Q5_0_AMPERE; @@ -3697,7 +3697,7 @@ template static __global__ void #else (void) vec_dot_q5_0_q8_1_mul_mat; assert(false); -#endif // __CUDA_ARCH__ >= CC_TURING +#endif // __CUDA_ARCH__ >= CC_VOLTA } #define MMQ_X_Q5_1_RDNA2 64 @@ -3738,7 +3738,7 @@ mul_mat_q5_1( load_tiles_q5_1, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#elif __CUDA_ARCH__ >= CC_TURING +#elif __CUDA_ARCH__ >= CC_VOLTA const int mmq_x = MMQ_X_Q5_1_AMPERE; const int mmq_y = MMQ_Y_Q5_1_AMPERE; const int nwarps = NWARPS_Q5_1_AMPERE; @@ -3758,7 +3758,7 @@ mul_mat_q5_1( #else (void) vec_dot_q5_1_q8_1_mul_mat; assert(false); -#endif // __CUDA_ARCH__ >= CC_TURING +#endif // __CUDA_ARCH__ >= CC_VOLTA } #define MMQ_X_Q8_0_RDNA2 64 @@ -3799,7 +3799,7 @@ template static __global__ void load_tiles_q8_0, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#elif __CUDA_ARCH__ >= CC_TURING +#elif __CUDA_ARCH__ >= CC_VOLTA const int mmq_x = MMQ_X_Q8_0_AMPERE; const int mmq_y = MMQ_Y_Q8_0_AMPERE; const int nwarps = NWARPS_Q8_0_AMPERE; @@ -3819,7 +3819,7 @@ template static __global__ void #else (void) vec_dot_q8_0_q8_1_mul_mat; assert(false); -#endif // __CUDA_ARCH__ >= CC_TURING +#endif // __CUDA_ARCH__ >= CC_VOLTA } #define MMQ_X_Q2_K_RDNA2 64 @@ -3860,7 +3860,7 @@ mul_mat_q2_K( load_tiles_q2_K, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#elif __CUDA_ARCH__ >= CC_TURING +#elif __CUDA_ARCH__ >= CC_VOLTA const int mmq_x = MMQ_X_Q2_K_AMPERE; const int mmq_y = MMQ_Y_Q2_K_AMPERE; const int nwarps = NWARPS_Q2_K_AMPERE; @@ -3880,7 +3880,7 @@ mul_mat_q2_K( #else (void) vec_dot_q2_K_q8_1_mul_mat; assert(false); -#endif // __CUDA_ARCH__ >= CC_TURING +#endif // __CUDA_ARCH__ >= CC_VOLTA } #define MMQ_X_Q3_K_RDNA2 128 @@ -3901,9 +3901,9 @@ template static __global__ void #if defined(RDNA3) || defined(RDNA2) __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_RDNA2, 2) #endif // defined(RDNA3) || defined(RDNA2) -#elif __CUDA_ARCH__ < CC_TURING +#elif __CUDA_ARCH__ < CC_VOLTA __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_PASCAL, 2) -#endif // __CUDA_ARCH__ < CC_TURING +#endif // __CUDA_ARCH__ < CC_VOLTA mul_mat_q3_K( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { @@ -3923,7 +3923,7 @@ template static __global__ void load_tiles_q3_K, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#elif __CUDA_ARCH__ >= CC_TURING +#elif __CUDA_ARCH__ >= CC_VOLTA const int mmq_x = MMQ_X_Q3_K_AMPERE; const int mmq_y = MMQ_Y_Q3_K_AMPERE; const int nwarps = NWARPS_Q3_K_AMPERE; @@ -3943,7 +3943,7 @@ template static __global__ void #else (void) vec_dot_q3_K_q8_1_mul_mat; assert(false); -#endif // __CUDA_ARCH__ >= CC_TURING +#endif // __CUDA_ARCH__ >= CC_VOLTA } #define MMQ_X_Q4_K_RDNA2 64 @@ -3964,9 +3964,9 @@ template static __global__ void #if defined(RDNA3) || defined(RDNA2) __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_RDNA2, 2) #endif // defined(RDNA3) || defined(RDNA2) -#elif __CUDA_ARCH__ < CC_TURING +#elif __CUDA_ARCH__ < CC_VOLTA __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_PASCAL, 2) -#endif // __CUDA_ARCH__ < CC_TURING +#endif // __CUDA_ARCH__ < CC_VOLTA mul_mat_q4_K( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { @@ -3986,7 +3986,7 @@ template static __global__ void load_tiles_q4_K, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#elif __CUDA_ARCH__ >= CC_TURING +#elif __CUDA_ARCH__ >= CC_VOLTA const int mmq_x = MMQ_X_Q4_K_AMPERE; const int mmq_y = MMQ_Y_Q4_K_AMPERE; const int nwarps = NWARPS_Q4_K_AMPERE; @@ -4006,7 +4006,7 @@ template static __global__ void #else (void) vec_dot_q4_K_q8_1_mul_mat; assert(false); -#endif // __CUDA_ARCH__ >= CC_TURING +#endif // __CUDA_ARCH__ >= CC_VOLTA } #define MMQ_X_Q5_K_RDNA2 64 @@ -4047,7 +4047,7 @@ mul_mat_q5_K( load_tiles_q5_K, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#elif __CUDA_ARCH__ >= CC_TURING +#elif __CUDA_ARCH__ >= CC_VOLTA const int mmq_x = MMQ_X_Q5_K_AMPERE; const int mmq_y = MMQ_Y_Q5_K_AMPERE; const int nwarps = NWARPS_Q5_K_AMPERE; @@ -4067,7 +4067,7 @@ mul_mat_q5_K( #else (void) vec_dot_q5_K_q8_1_mul_mat; assert(false); -#endif // __CUDA_ARCH__ >= CC_TURING +#endif // __CUDA_ARCH__ >= CC_VOLTA } #define MMQ_X_Q6_K_RDNA2 64 @@ -4088,9 +4088,9 @@ template static __global__ void #if defined(RDNA3) || defined(RDNA2) __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_RDNA2, 2) #endif // defined(RDNA3) || defined(RDNA2) -#elif __CUDA_ARCH__ < CC_TURING +#elif __CUDA_ARCH__ < CC_VOLTA __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_PASCAL, 2) -#endif // __CUDA_ARCH__ < CC_TURING +#endif // __CUDA_ARCH__ < CC_VOLTA mul_mat_q6_K( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { @@ -4110,7 +4110,7 @@ template static __global__ void load_tiles_q6_K, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#elif __CUDA_ARCH__ >= CC_TURING +#elif __CUDA_ARCH__ >= CC_VOLTA const int mmq_x = MMQ_X_Q6_K_AMPERE; const int mmq_y = MMQ_Y_Q6_K_AMPERE; const int nwarps = NWARPS_Q6_K_AMPERE; @@ -4130,7 +4130,7 @@ template static __global__ void #else (void) vec_dot_q6_K_q8_1_mul_mat; assert(false); -#endif // __CUDA_ARCH__ >= CC_TURING +#endif // __CUDA_ARCH__ >= CC_VOLTA } template @@ -4674,6 +4674,7 @@ static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cu dequantize_block_q5_K<<>>(vx, y); #endif } + template static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; @@ -4955,7 +4956,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda( mmq_x = MMQ_X_Q4_0_RDNA1; mmq_y = MMQ_Y_Q4_0_RDNA1; nwarps = NWARPS_Q4_0_RDNA1; - } else if (compute_capability >= CC_TURING) { + } else if (compute_capability >= CC_VOLTA) { mmq_x = MMQ_X_Q4_0_AMPERE; mmq_y = MMQ_Y_Q4_0_AMPERE; nwarps = NWARPS_Q4_0_AMPERE; @@ -5000,7 +5001,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda( mmq_x = MMQ_X_Q4_1_RDNA1; mmq_y = MMQ_Y_Q4_1_RDNA1; nwarps = NWARPS_Q4_1_RDNA1; - } else if (compute_capability >= CC_TURING) { + } else if (compute_capability >= CC_VOLTA) { mmq_x = MMQ_X_Q4_1_AMPERE; mmq_y = MMQ_Y_Q4_1_AMPERE; nwarps = NWARPS_Q4_1_AMPERE; @@ -5045,7 +5046,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda( mmq_x = MMQ_X_Q5_0_RDNA1; mmq_y = MMQ_Y_Q5_0_RDNA1; nwarps = NWARPS_Q5_0_RDNA1; - } else if (compute_capability >= CC_TURING) { + } else if (compute_capability >= CC_VOLTA) { mmq_x = MMQ_X_Q5_0_AMPERE; mmq_y = MMQ_Y_Q5_0_AMPERE; nwarps = NWARPS_Q5_0_AMPERE; @@ -5090,7 +5091,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda( mmq_x = MMQ_X_Q5_1_RDNA1; mmq_y = MMQ_Y_Q5_1_RDNA1; nwarps = NWARPS_Q5_1_RDNA1; - } else if (compute_capability >= CC_TURING) { + } else if (compute_capability >= CC_VOLTA) { mmq_x = MMQ_X_Q5_1_AMPERE; mmq_y = MMQ_Y_Q5_1_AMPERE; nwarps = NWARPS_Q5_1_AMPERE; @@ -5135,7 +5136,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda( mmq_x = MMQ_X_Q8_0_RDNA1; mmq_y = MMQ_Y_Q8_0_RDNA1; nwarps = NWARPS_Q8_0_RDNA1; - } else if (compute_capability >= CC_TURING) { + } else if (compute_capability >= CC_VOLTA) { mmq_x = MMQ_X_Q8_0_AMPERE; mmq_y = MMQ_Y_Q8_0_AMPERE; nwarps = NWARPS_Q8_0_AMPERE; @@ -5180,7 +5181,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda( mmq_x = MMQ_X_Q2_K_RDNA1; mmq_y = MMQ_Y_Q2_K_RDNA1; nwarps = NWARPS_Q2_K_RDNA1; - } else if (compute_capability >= CC_TURING) { + } else if (compute_capability >= CC_VOLTA) { mmq_x = MMQ_X_Q2_K_AMPERE; mmq_y = MMQ_Y_Q2_K_AMPERE; nwarps = NWARPS_Q2_K_AMPERE; @@ -5227,7 +5228,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda( mmq_x = MMQ_X_Q3_K_RDNA1; mmq_y = MMQ_Y_Q3_K_RDNA1; nwarps = NWARPS_Q3_K_RDNA1; - } else if (compute_capability >= CC_TURING) { + } else if (compute_capability >= CC_VOLTA) { mmq_x = MMQ_X_Q3_K_AMPERE; mmq_y = MMQ_Y_Q3_K_AMPERE; nwarps = NWARPS_Q3_K_AMPERE; @@ -5273,7 +5274,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda( mmq_x = MMQ_X_Q4_K_RDNA1; mmq_y = MMQ_Y_Q4_K_RDNA1; nwarps = NWARPS_Q4_K_RDNA1; - } else if (compute_capability >= CC_TURING) { + } else if (compute_capability >= CC_VOLTA) { mmq_x = MMQ_X_Q4_K_AMPERE; mmq_y = MMQ_Y_Q4_K_AMPERE; nwarps = NWARPS_Q4_K_AMPERE; @@ -5318,7 +5319,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda( mmq_x = MMQ_X_Q5_K_RDNA1; mmq_y = MMQ_Y_Q5_K_RDNA1; nwarps = NWARPS_Q5_K_RDNA1; - } else if (compute_capability >= CC_TURING) { + } else if (compute_capability >= CC_VOLTA) { mmq_x = MMQ_X_Q5_K_AMPERE; mmq_y = MMQ_Y_Q5_K_AMPERE; nwarps = NWARPS_Q5_K_AMPERE; @@ -5363,7 +5364,7 @@ static void ggml_mul_mat_q6_K_q8_1_cuda( mmq_x = MMQ_X_Q6_K_RDNA1; mmq_y = MMQ_Y_Q6_K_RDNA1; nwarps = NWARPS_Q6_K_RDNA1; - } else if (compute_capability >= CC_TURING) { + } else if (compute_capability >= CC_VOLTA) { mmq_x = MMQ_X_Q6_K_AMPERE; mmq_y = MMQ_Y_Q6_K_AMPERE; nwarps = NWARPS_Q6_K_AMPERE; @@ -5941,7 +5942,7 @@ static int64_t get_row_rounding(ggml_type type) { switch(type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - return max_compute_capability >= CC_TURING ? 128 : 64; + return max_compute_capability >= CC_VOLTA ? 128 : 64; case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: @@ -5952,7 +5953,7 @@ static int64_t get_row_rounding(ggml_type type) { case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: - return max_compute_capability >= CC_TURING ? 128 : 64; + return max_compute_capability >= CC_VOLTA ? 128 : 64; case GGML_TYPE_Q6_K: return 64; default: @@ -6117,7 +6118,7 @@ inline void ggml_cuda_op_mul_mat_cublas( const int compute_capability = g_compute_capabilities[id]; - if (compute_capability >= CC_TURING && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && ldc == row_diff) { + if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && ldc == row_diff) { // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32 half * src0_as_f16 = nullptr; size_t src0_as = 0; @@ -6128,7 +6129,7 @@ inline void ggml_cuda_op_mul_mat_cublas( src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as); to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream); } - const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (half *) src0_dd_i : src0_as_f16; + const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16; half * src1_as_f16 = nullptr; size_t src1_as = 0;