From 5db21312502a1c05fd78cca4396f30c77a911806 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 22 Jun 2024 19:00:00 +0200 Subject: [PATCH] simplify code, make functions constexpr --- ggml-cuda/common.cuh | 4 +- ggml-cuda/mmq.cuh | 99 ++++++++++++++++++++------------------------ 2 files changed, 46 insertions(+), 57 deletions(-) diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 5bd24ebe5..5c8662535 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -643,7 +643,7 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI3_S; }; -static int get_mmq_x_max_host(const int cc) { +static constexpr int get_mmq_x_max_host(int cc) { #ifdef CUDA_USE_TENSOR_CORES return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64; #else @@ -652,7 +652,7 @@ static int get_mmq_x_max_host(const int cc) { } // Round rows to this value for --split-mode row: -static int get_mmq_y_host(const int cc) { +static constexpr int get_mmq_y_host(int cc) { return cc >= CC_VOLTA ? 128 : 64; } diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh index fed9b6fb0..0f7f8ae51 100644 --- a/ggml-cuda/mmq.cuh +++ b/ggml-cuda/mmq.cuh @@ -67,26 +67,18 @@ static constexpr __device__ int get_mmq_y_device() { #define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8} #define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define GET_MMQ_DP4A_TXS_BODY \ - return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 : \ - type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 : \ - type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 : \ - type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 : \ - type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 : \ - type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K : \ - type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K : \ - type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K : \ - type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K : \ - type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K : \ - tile_x_sizes{0, 0, 0} - -static tile_x_sizes mmq_get_dp4a_tile_x_sizes_host(const ggml_type type, const int mmq_y) { - GET_MMQ_DP4A_TXS_BODY; -} - -template -static constexpr __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes_device(ggml_type type) { - GET_MMQ_DP4A_TXS_BODY; +static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) { + return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 : + type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 : + type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 : + type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 : + type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 : + type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K : + type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K : + type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K : + type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K : + type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K : + tile_x_sizes{0, 0, 0}; } #define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0 + 4) @@ -111,21 +103,18 @@ static_assert(MMQ_MMA_TILE_X_K_Q4_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q5_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); -#define MMQ_MMA_GET_TILE_X_K_BODY \ - return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 : \ - type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 : \ - type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 : \ - type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 : \ - type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 : \ - type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K : \ - type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K : \ - type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K : \ - type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K : \ - type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : \ - 0 - static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { - MMQ_MMA_GET_TILE_X_K_BODY; + return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 : + type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 : + type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 : + type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 : + type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 : + type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K : + type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K : + type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K : + type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K : + type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : + 0; } #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1) @@ -154,7 +143,7 @@ template static __device__ __forceinlin int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE); #else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device(GGML_TYPE_Q4_0); + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); #endif // INT8_MMA_AVAILABLE @@ -204,7 +193,7 @@ template static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device(GGML_TYPE_Q4_0); + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); const int * x_qs = (const int *) x; const float * x_df = (const float *) x_qs + txs.qs; const int * y_qs = (const int *) y + 4; @@ -317,7 +306,7 @@ template static __device__ __forceinlin int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + WARP_SIZE); #else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device(GGML_TYPE_Q4_1); + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); #endif // INT8_MMA_AVAILABLE @@ -367,7 +356,7 @@ template static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device(GGML_TYPE_Q4_1); + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); const int * x_qs = (const int *) x; const half2 * x_dm = (const half2 *) x_qs + txs.qs; const int * y_qs = (const int *) y + 4; @@ -479,7 +468,7 @@ template static __device__ __forceinlin int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device(GGML_TYPE_Q5_0); + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); #endif // INT8_MMA_AVAILABLE @@ -548,7 +537,7 @@ template static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device(GGML_TYPE_Q5_0); + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); const int * x_qs = (const int *) x; const float * x_df = (const float *) x_qs + txs.qs; const int * y_qs = (const int *) y + 4; @@ -644,7 +633,7 @@ template static __device__ __forceinlin int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); #else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device(GGML_TYPE_Q5_1); + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); #endif // INT8_MMA_AVAILABLE @@ -711,7 +700,7 @@ template static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device(GGML_TYPE_Q5_1); + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); const int * x_qs = (const int *) x; const half2 * x_dm = (const half2 *) x_qs + txs.qs; const int * y_qs = (const int *) y + 4; @@ -808,7 +797,7 @@ template static __device__ __forceinlin int * x_qs = (int *) x_tile; float * x_df = (float *) (x_tile + WARP_SIZE); #else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device(GGML_TYPE_Q8_0); + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); #endif // INT8_MMA_AVAILABLE @@ -858,7 +847,7 @@ template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device(GGML_TYPE_Q8_0); + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); const int * x_qs = (const int *) x; const float * x_df = (const float *) x_qs + txs.qs; const int * y_qs = (const int *) y + 4; @@ -954,7 +943,7 @@ template static __device__ __forceinlin int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + WARP_SIZE); #else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device(GGML_TYPE_Q2_K); + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); #endif // INT8_MMA_AVAILABLE @@ -1013,7 +1002,7 @@ template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device(GGML_TYPE_Q2_K); + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); const int * x_qs = (const int *) x; const half2 * x_dm = (const half2 *) x_qs + txs.qs; const int * y_qs = (const int *) y + 4; @@ -1135,7 +1124,7 @@ template static __device__ __forceinlin float * x_df = (float *) (x_qs + WARP_SIZE*2); int * x_sc = (int *) (x_df + WARP_SIZE/QI3_K); #else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device(GGML_TYPE_Q3_K); + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); int * x_sc = (int *) (x_df + txs.dm); @@ -1233,7 +1222,7 @@ template static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device(GGML_TYPE_Q3_K); + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); const int * x_qs = (const int *) x; const float * x_df = (const float *) x_qs + txs.qs; const int * x_sc = (const int *) x_df + txs.dm; @@ -1361,7 +1350,7 @@ template static __device__ __forceinlin half2 * x_dm = (half2 *) (x_qs + WARP_SIZE); int * x_sc = (int *) (x_dm + WARP_SIZE/QI4_K); #else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device(GGML_TYPE_Q4_K); + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); int * x_sc = (int *) (x_dm + txs.dm); @@ -1437,7 +1426,7 @@ template static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device(GGML_TYPE_Q4_K); + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); const int * x_qs = (const int *) x; const half2 * x_dm = (const half2 *) x_qs + txs.qs; const int * x_sc = (const int *) x_dm + txs.dm; @@ -1578,7 +1567,7 @@ template static __device__ __forceinlin half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2); int * x_sc = (int *) (x_dm + WARP_SIZE/QI5_K); #else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device(GGML_TYPE_Q5_K); + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); int * x_sc = (int *) (x_dm + txs.dm); @@ -1668,7 +1657,7 @@ template static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device(GGML_TYPE_Q5_K); + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); const int * x_qs = (const int *) x; const half2 * x_dm = (const half2 *) x_qs + txs.qs; const int * x_sc = (const int *) x_dm + txs.dm; @@ -1800,7 +1789,7 @@ template static __device__ __forceinlin float * x_df = (float *) (x_qs + WARP_SIZE*2); int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K); #else - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device(GGML_TYPE_Q6_K); + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); int * x_sc = (int *) (x_df + txs.dm); @@ -1882,7 +1871,7 @@ template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device(GGML_TYPE_Q6_K); + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); const int * x_qs = (const int *) x; const float * x_df = (const float *) x_qs + txs.qs; const int * x_sc = (const int *) x_df + txs.dm; @@ -2422,7 +2411,7 @@ struct mmq_args { template static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) { - const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_host(type, mmq_y); + const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y); const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); const int shmem_y = mmq_x*sizeof(block_q8_1_mmq);