mmq_type_traits
This commit is contained in:
parent
fe1c4bbff1
commit
fd65ff31e9
1 changed files with 89 additions and 59 deletions
|
@ -4,10 +4,10 @@
|
|||
#include <climits>
|
||||
#include <cstdint>
|
||||
|
||||
// Function-pointer type shared by the load_tiles_<type> routines (e.g. load_tiles_q4_0).
// NOTE(review): x is presumably the raw quantized input and x_ql / x_dm / x_qh / x_sc the
// tile buffers filled by the loader - confirm against the load_tiles_* definitions.
typedef void (*load_tiles_mmq_t)(
    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride);

// Earlier name of the same pointer type; kept as an alias because get_load_tiles and other
// code in this file still spell it this way.
typedef load_tiles_mmq_t load_tiles_cuda_t;
|
||||
// Function-pointer type shared by the vec_dot_<type>_q8_1_mul_mat routines
// (e.g. vec_dot_q4_0_q8_1_mul_mat).
// NOTE(review): the x_* / y_* pointers are presumably the tiles prepared by the loaders and
// `sum` the accumulator updated by the routine - confirm against the vec_dot_* definitions.
typedef void (*vec_dot_mmq_t)(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, float * __restrict__ sum, const int & k0);

// Earlier name of the same pointer type; kept as an alias because get_vec_dot_mmq and other
// code in this file still spell it this way.
typedef vec_dot_mmq_t vec_dot_q_mul_mat_cuda_t;
|
||||
|
||||
|
@ -959,57 +959,88 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
|
|||
|
||||
// -------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
// Reports whether `type` belongs to the set of quantization types flagged as "need sum":
// Q4_0, Q4_1, Q5_1, Q4_K and Q5_K. The boolean outcome is returned as an int (1 or 0).
// Kept as a single return expression so the function stays usable in constant expressions.
// NOTE(review): "sum" presumably refers to a per-block sum consumed by the matching vec_dot
// routines - confirm with the kernel that reads this flag.
static constexpr __device__ int get_need_sum(ggml_type type) {
    return (type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_1) ||
           (type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K);
}
|
||||
// Primary template of the per-type MMQ traits. It is only declared here; the members
// (need_sum, vdr, load_tiles, vec_dot) are supplied by the partial specializations written
// for individual ggml_type values.
template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
struct mmq_type_traits;
|
||||
|
||||
// Selects the load_tiles_<type> instantiation (for the given mmq_y / nwarps / need_check)
// that corresponds to `type`. Types without an entry below yield nullptr.
template <int mmq_y, int nwarps, bool need_check>
static constexpr __device__ load_tiles_cuda_t get_load_tiles(ggml_type type) {
    return
        // Q4_0 .. Q8_0
        type == GGML_TYPE_Q4_0 ? load_tiles_q4_0<mmq_y, nwarps, need_check> :
        type == GGML_TYPE_Q4_1 ? load_tiles_q4_1<mmq_y, nwarps, need_check> :
        type == GGML_TYPE_Q5_0 ? load_tiles_q5_0<mmq_y, nwarps, need_check> :
        type == GGML_TYPE_Q5_1 ? load_tiles_q5_1<mmq_y, nwarps, need_check> :
        type == GGML_TYPE_Q8_0 ? load_tiles_q8_0<mmq_y, nwarps, need_check> :
        // Q2_K .. Q6_K
        type == GGML_TYPE_Q2_K ? load_tiles_q2_K<mmq_y, nwarps, need_check> :
        type == GGML_TYPE_Q3_K ? load_tiles_q3_K<mmq_y, nwarps, need_check> :
        type == GGML_TYPE_Q4_K ? load_tiles_q4_K<mmq_y, nwarps, need_check> :
        type == GGML_TYPE_Q5_K ? load_tiles_q5_K<mmq_y, nwarps, need_check> :
        type == GGML_TYPE_Q6_K ? load_tiles_q6_K<mmq_y, nwarps, need_check> :
        // no loader registered for this type
        nullptr;
}
|
||||
// MMQ traits for GGML_TYPE_Q4_0: the need_sum flag, the VDR constant, and the tile-loader /
// dot-product routines instantiated for this mmq_x / mmq_y / nwarps / need_check combination.
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
    static constexpr bool             need_sum   = true;
    static constexpr int              vdr        = VDR_Q4_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};
|
||||
|
||||
// Looks up the VDR_<type>_Q8_1_MMQ constant that corresponds to `type`.
// Types without an entry below yield 0.
static constexpr __device__ int get_vdr_mmq(ggml_type type) {
    return
        // Q4_0 .. Q8_0
        type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMQ :
        type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMQ :
        type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMQ :
        type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMQ :
        type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMQ :
        // Q2_K .. Q6_K
        type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMQ :
        type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMQ :
        type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMQ :
        type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMQ :
        type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMQ :
        // no constant registered for this type
        0;
}
|
||||
// MMQ traits for GGML_TYPE_Q4_1: the need_sum flag, the VDR constant, and the tile-loader /
// dot-product routines instantiated for this mmq_x / mmq_y / nwarps / need_check combination.
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
    static constexpr bool             need_sum   = true;
    static constexpr int              vdr        = VDR_Q4_1_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};
|
||||
|
||||
// Selects the vec_dot_<type>_q8_1_mul_mat instantiation (for the given mmq_x / mmq_y / nwarps)
// that corresponds to `type`. Types without an entry below yield nullptr.
template <int mmq_x, int mmq_y, int nwarps>
static constexpr __device__ vec_dot_q_mul_mat_cuda_t get_vec_dot_mmq(ggml_type type) {
    return
        // Q4_0 .. Q8_0
        type == GGML_TYPE_Q4_0 ? vec_dot_q4_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps> :
        type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps> :
        type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps> :
        type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps> :
        type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps> :
        // Q2_K .. Q6_K
        type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps> :
        type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps> :
        type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps> :
        type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps> :
        type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps> :
        // no routine registered for this type
        nullptr;
}
|
||||
// MMQ traits for GGML_TYPE_Q5_0: the need_sum flag, the VDR constant, and the tile-loader /
// dot-product routines instantiated for this mmq_x / mmq_y / nwarps / need_check combination.
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
    static constexpr bool             need_sum   = false;
    static constexpr int              vdr        = VDR_Q5_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};
|
||||
|
||||
// MMQ traits for GGML_TYPE_Q5_1: the need_sum flag, the VDR constant, and the tile-loader /
// dot-product routines instantiated for this mmq_x / mmq_y / nwarps / need_check combination.
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
    static constexpr bool             need_sum   = true;
    static constexpr int              vdr        = VDR_Q5_1_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};
|
||||
|
||||
// MMQ traits for GGML_TYPE_Q8_0: the need_sum flag, the VDR constant, and the tile-loader /
// dot-product routines instantiated for this mmq_x / mmq_y / nwarps / need_check combination.
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
    static constexpr bool             need_sum   = false;
    static constexpr int              vdr        = VDR_Q8_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q8_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};
|
||||
|
||||
// MMQ traits for GGML_TYPE_Q2_K: the need_sum flag, the VDR constant, and the tile-loader /
// dot-product routines instantiated for this mmq_x / mmq_y / nwarps / need_check combination.
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
    static constexpr bool             need_sum   = false;
    static constexpr int              vdr        = VDR_Q2_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};
|
||||
|
||||
// MMQ traits for GGML_TYPE_Q3_K: the need_sum flag, the VDR constant, and the tile-loader /
// dot-product routines instantiated for this mmq_x / mmq_y / nwarps / need_check combination.
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
    static constexpr bool             need_sum   = false;
    static constexpr int              vdr        = VDR_Q3_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};
|
||||
|
||||
// MMQ traits for GGML_TYPE_Q4_K: the need_sum flag, the VDR constant, and the tile-loader /
// dot-product routines instantiated for this mmq_x / mmq_y / nwarps / need_check combination.
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
    static constexpr bool             need_sum   = true;
    static constexpr int              vdr        = VDR_Q4_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};
|
||||
|
||||
// MMQ traits for GGML_TYPE_Q5_K: the need_sum flag, the VDR constant, and the tile-loader /
// dot-product routines instantiated for this mmq_x / mmq_y / nwarps / need_check combination.
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
    static constexpr bool             need_sum   = true;
    static constexpr int              vdr        = VDR_Q5_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};
|
||||
|
||||
// MMQ traits for GGML_TYPE_Q6_K: the need_sum flag, the VDR constant, and the tile-loader /
// dot-product routines instantiated for this mmq_x / mmq_y / nwarps / need_check combination.
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
    static constexpr bool             need_sum   = false;
    static constexpr int              vdr        = VDR_Q6_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};
|
||||
|
||||
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
|
||||
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||
|
@ -1033,12 +1064,14 @@ static __global__ void mul_mat_q(
|
|||
return;
|
||||
}
|
||||
|
||||
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
||||
constexpr int qr = ggml_cuda_type_traits<type>::qr;
|
||||
constexpr int qi = ggml_cuda_type_traits<type>::qi;
|
||||
constexpr int mmq_y = get_mmq_y_device(mmq_x);
|
||||
constexpr bool need_sum = get_need_sum(type);
|
||||
constexpr int vdr = get_vdr_mmq(type);
|
||||
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
||||
constexpr int qr = ggml_cuda_type_traits<type>::qr;
|
||||
constexpr int qi = ggml_cuda_type_traits<type>::qi;
|
||||
constexpr int mmq_y = get_mmq_y_device(mmq_x);
|
||||
constexpr bool need_sum = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::need_sum;
|
||||
constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
|
||||
constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
|
||||
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot;
|
||||
|
||||
constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type);
|
||||
|
||||
|
@ -1050,9 +1083,6 @@ static __global__ void mul_mat_q(
|
|||
int * tile_y_qs = (int *) (tile_x_sc + txs.sc); // [mmq_x * WARP_SIZE]
|
||||
half2 * tile_y_ds = (half2 *) (tile_y_qs + mmq_x*WARP_SIZE); // [mmq_x * WARP_SIZE/QI8_1];
|
||||
|
||||
constexpr load_tiles_cuda_t load_tiles = get_load_tiles<mmq_y, nwarps, need_check>(type);
|
||||
constexpr vec_dot_q_mul_mat_cuda_t vec_dot = get_vec_dot_mmq<mmq_x, mmq_y, nwarps>(type);
|
||||
|
||||
const block_q8_1 * y = (const block_q8_1 *) yc;
|
||||
|
||||
const int blocks_per_row_x = ne00 / qk;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue