From fd65ff31e9fe3ea8e12087b35b0d4beaab05107f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 5 Jun 2024 15:44:25 +0200 Subject: [PATCH] mmq_type_traits --- ggml-cuda/mmq.cuh | 148 ++++++++++++++++++++++++++++------------------ 1 file changed, 89 insertions(+), 59 deletions(-) diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh index c9a6ced71..6744cce6d 100644 --- a/ggml-cuda/mmq.cuh +++ b/ggml-cuda/mmq.cuh @@ -4,10 +4,10 @@ #include #include -typedef void (*load_tiles_cuda_t)( +typedef void (*load_tiles_mmq_t)( const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride); -typedef void (*vec_dot_q_mul_mat_cuda_t)( +typedef void (*vec_dot_mmq_t)( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, float * __restrict__ sum, const int & k0); @@ -959,57 +959,88 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat( // ------------------------------------------------------------------------------------------------------------------------------------- -static constexpr __device__ int get_need_sum(ggml_type type) { - return type == GGML_TYPE_Q4_0 || - type == GGML_TYPE_Q4_1 || - type == GGML_TYPE_Q5_1 || - type == GGML_TYPE_Q4_K || - type == GGML_TYPE_Q5_K; -} +template +struct mmq_type_traits; -template -static constexpr __device__ load_tiles_cuda_t get_load_tiles(ggml_type type) { - return type == GGML_TYPE_Q4_0 ? load_tiles_q4_0 : - type == GGML_TYPE_Q4_1 ? load_tiles_q4_1 : - type == GGML_TYPE_Q5_0 ? load_tiles_q5_0 : - type == GGML_TYPE_Q5_1 ? load_tiles_q5_1 : - type == GGML_TYPE_Q8_0 ? load_tiles_q8_0 : - type == GGML_TYPE_Q2_K ? load_tiles_q2_K : - type == GGML_TYPE_Q3_K ? load_tiles_q3_K : - type == GGML_TYPE_Q4_K ? load_tiles_q4_K : - type == GGML_TYPE_Q5_K ? load_tiles_q5_K : - type == GGML_TYPE_Q6_K ? load_tiles_q6_K : - nullptr; -} +template +struct mmq_type_traits { + static constexpr bool need_sum = true; + static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mul_mat; +}; -static constexpr __device__ int get_vdr_mmq(ggml_type type) { - return type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMQ : - type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMQ : - type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMQ : - type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMQ : - type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMQ : - type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMQ : - type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMQ : - type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMQ : - type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMQ : - type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMQ : - 0; -} +template +struct mmq_type_traits { + static constexpr bool need_sum = true; + static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mul_mat; +}; -template -static constexpr __device__ vec_dot_q_mul_mat_cuda_t get_vec_dot_mmq(ggml_type type) { - return type == GGML_TYPE_Q4_0 ? vec_dot_q4_0_q8_1_mul_mat : - type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1_mul_mat : - type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1_mul_mat : - type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1_mul_mat : - type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1_mul_mat : - type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1_mul_mat : - type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1_mul_mat : - type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1_mul_mat : - type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1_mul_mat : - type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1_mul_mat : - nullptr; -} +template +struct mmq_type_traits { + static constexpr bool need_sum = false; + static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mul_mat; +}; + +template +struct mmq_type_traits { + static constexpr bool need_sum = true; + static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mul_mat; +}; + +template +struct mmq_type_traits { + static constexpr bool need_sum = false; + static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mul_mat; +}; + +template +struct mmq_type_traits { + static constexpr bool need_sum = false; + static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q2_K_q8_1_mul_mat; +}; + +template +struct mmq_type_traits { + static constexpr bool need_sum = false; + static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q3_K_q8_1_mul_mat; +}; + +template +struct mmq_type_traits { + static constexpr bool need_sum = true; + static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mul_mat; +}; + +template +struct mmq_type_traits { + static constexpr bool need_sum = true; + static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mul_mat; +}; + +template +struct mmq_type_traits { + static constexpr bool need_sum = false; + static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mul_mat; +}; template #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) @@ -1033,12 +1064,14 @@ static __global__ void mul_mat_q( return; } - constexpr int qk = ggml_cuda_type_traits::qk; - constexpr int qr = ggml_cuda_type_traits::qr; - constexpr int qi = ggml_cuda_type_traits::qi; - constexpr int mmq_y = get_mmq_y_device(mmq_x); - constexpr bool need_sum = get_need_sum(type); - constexpr int vdr = get_vdr_mmq(type); + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int qr = ggml_cuda_type_traits::qr; + constexpr int qi = ggml_cuda_type_traits::qi; + constexpr int mmq_y = get_mmq_y_device(mmq_x); + constexpr bool need_sum = mmq_type_traits::need_sum; + constexpr int vdr = mmq_type_traits::vdr; + constexpr load_tiles_mmq_t load_tiles = mmq_type_traits::load_tiles; + constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot; constexpr tile_x_sizes txs = get_tile_x_sizes_device(type); @@ -1050,9 +1083,6 @@ static __global__ void mul_mat_q( int * tile_y_qs = (int *) (tile_x_sc + txs.sc); // [mmq_x * WARP_SIZE] half2 * tile_y_ds = (half2 *) (tile_y_qs + mmq_x*WARP_SIZE); // [mmq_x * WARP_SIZE/QI8_1]; - constexpr load_tiles_cuda_t load_tiles = get_load_tiles(type); - constexpr vec_dot_q_mul_mat_cuda_t vec_dot = get_vec_dot_mmq(type); - const block_q8_1 * y = (const block_q8_1 *) yc; const int blocks_per_row_x = ne00 / qk;