mmq_type_traits

This commit is contained in:
Johannes Gäßler 2024-06-05 15:44:25 +02:00
parent fe1c4bbff1
commit fd65ff31e9

View file

@ -4,10 +4,10 @@
#include <climits> #include <climits>
#include <cstdint> #include <cstdint>
typedef void (*load_tiles_cuda_t)( typedef void (*load_tiles_mmq_t)(
const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride); int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride);
typedef void (*vec_dot_q_mul_mat_cuda_t)( typedef void (*vec_dot_mmq_t)(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, float * __restrict__ sum, const int & k0); const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, float * __restrict__ sum, const int & k0);
@ -959,57 +959,88 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
// ------------------------------------------------------------------------------------------------------------------------------------- // -------------------------------------------------------------------------------------------------------------------------------------
static constexpr __device__ int get_need_sum(ggml_type type) { template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
return type == GGML_TYPE_Q4_0 || struct mmq_type_traits;
type == GGML_TYPE_Q4_1 ||
type == GGML_TYPE_Q5_1 ||
type == GGML_TYPE_Q4_K ||
type == GGML_TYPE_Q5_K;
}
template <int mmq_y, int nwarps, bool need_check> template <int mmq_x, int mmq_y, int nwarps, bool need_check>
static constexpr __device__ load_tiles_cuda_t get_load_tiles(ggml_type type) { struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
return type == GGML_TYPE_Q4_0 ? load_tiles_q4_0<mmq_y, nwarps, need_check> : static constexpr bool need_sum = true;
type == GGML_TYPE_Q4_1 ? load_tiles_q4_1<mmq_y, nwarps, need_check> : static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
type == GGML_TYPE_Q5_0 ? load_tiles_q5_0<mmq_y, nwarps, need_check> : static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
type == GGML_TYPE_Q5_1 ? load_tiles_q5_1<mmq_y, nwarps, need_check> : static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
type == GGML_TYPE_Q8_0 ? load_tiles_q8_0<mmq_y, nwarps, need_check> : };
type == GGML_TYPE_Q2_K ? load_tiles_q2_K<mmq_y, nwarps, need_check> :
type == GGML_TYPE_Q3_K ? load_tiles_q3_K<mmq_y, nwarps, need_check> :
type == GGML_TYPE_Q4_K ? load_tiles_q4_K<mmq_y, nwarps, need_check> :
type == GGML_TYPE_Q5_K ? load_tiles_q5_K<mmq_y, nwarps, need_check> :
type == GGML_TYPE_Q6_K ? load_tiles_q6_K<mmq_y, nwarps, need_check> :
nullptr;
}
static constexpr __device__ int get_vdr_mmq(ggml_type type) { template <int mmq_x, int mmq_y, int nwarps, bool need_check>
return type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMQ : struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMQ : static constexpr bool need_sum = true;
type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMQ : static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMQ : static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMQ : static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMQ : };
type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMQ :
type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMQ :
type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMQ :
type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMQ :
0;
}
template <int mmq_x, int mmq_y, int nwarps> template <int mmq_x, int mmq_y, int nwarps, bool need_check>
static constexpr __device__ vec_dot_q_mul_mat_cuda_t get_vec_dot_mmq(ggml_type type) { struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
return type == GGML_TYPE_Q4_0 ? vec_dot_q4_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps> : static constexpr bool need_sum = false;
type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps> : static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps> : static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps> : static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps> : };
type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps> :
type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps> : template <int mmq_x, int mmq_y, int nwarps, bool need_check>
type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps> : struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps> : static constexpr bool need_sum = true;
type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps> : static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
nullptr; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
} static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
static constexpr bool need_sum = false;
static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
static constexpr bool need_sum = false;
static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
static constexpr bool need_sum = false;
static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
static constexpr bool need_sum = true;
static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
static constexpr bool need_sum = true;
static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
static constexpr bool need_sum = false;
static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};
template <ggml_type type, int mmq_x, int nwarps, bool need_check> template <ggml_type type, int mmq_x, int nwarps, bool need_check>
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
@ -1037,8 +1068,10 @@ static __global__ void mul_mat_q(
constexpr int qr = ggml_cuda_type_traits<type>::qr; constexpr int qr = ggml_cuda_type_traits<type>::qr;
constexpr int qi = ggml_cuda_type_traits<type>::qi; constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int mmq_y = get_mmq_y_device(mmq_x); constexpr int mmq_y = get_mmq_y_device(mmq_x);
constexpr bool need_sum = get_need_sum(type); constexpr bool need_sum = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::need_sum;
constexpr int vdr = get_vdr_mmq(type); constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot;
constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type); constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type);
@ -1050,9 +1083,6 @@ static __global__ void mul_mat_q(
int * tile_y_qs = (int *) (tile_x_sc + txs.sc); // [mmq_x * WARP_SIZE] int * tile_y_qs = (int *) (tile_x_sc + txs.sc); // [mmq_x * WARP_SIZE]
half2 * tile_y_ds = (half2 *) (tile_y_qs + mmq_x*WARP_SIZE); // [mmq_x * WARP_SIZE/QI8_1]; half2 * tile_y_ds = (half2 *) (tile_y_qs + mmq_x*WARP_SIZE); // [mmq_x * WARP_SIZE/QI8_1];
constexpr load_tiles_cuda_t load_tiles = get_load_tiles<mmq_y, nwarps, need_check>(type);
constexpr vec_dot_q_mul_mat_cuda_t vec_dot = get_vec_dot_mmq<mmq_x, mmq_y, nwarps>(type);
const block_q8_1 * y = (const block_q8_1 *) yc; const block_q8_1 * y = (const block_q8_1 *) yc;
const int blocks_per_row_x = ne00 / qk; const int blocks_per_row_x = ne00 / qk;