simplify code, make functions constexpr

This commit is contained in:
Johannes Gäßler 2024-06-22 19:00:00 +02:00
parent cab5981951
commit 5db2131250
2 changed files with 46 additions and 57 deletions

View file

@ -643,7 +643,7 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
static constexpr int qi = QI3_S; static constexpr int qi = QI3_S;
}; };
static int get_mmq_x_max_host(const int cc) { static constexpr int get_mmq_x_max_host(int cc) {
#ifdef CUDA_USE_TENSOR_CORES #ifdef CUDA_USE_TENSOR_CORES
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64; return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64;
#else #else
@ -652,7 +652,7 @@ static int get_mmq_x_max_host(const int cc) {
} }
// Round rows to this value for --split-mode row: // Round rows to this value for --split-mode row:
static int get_mmq_y_host(const int cc) { static constexpr int get_mmq_y_host(int cc) {
return cc >= CC_VOLTA ? 128 : 64; return cc >= CC_VOLTA ? 128 : 64;
} }

View file

@ -67,26 +67,18 @@ static constexpr __device__ int get_mmq_y_device() {
#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8} #define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8} #define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
#define GET_MMQ_DP4A_TXS_BODY \ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 : \ return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 : \ type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 : \ type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 :
type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 : \ type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 :
type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 : \ type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K : \ type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K : \ type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K : \ type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K : \ type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K : \ type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
tile_x_sizes{0, 0, 0} tile_x_sizes{0, 0, 0};
static tile_x_sizes mmq_get_dp4a_tile_x_sizes_host(const ggml_type type, const int mmq_y) {
GET_MMQ_DP4A_TXS_BODY;
}
template <int mmq_y>
static constexpr __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes_device(ggml_type type) {
GET_MMQ_DP4A_TXS_BODY;
} }
#define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0 + 4) #define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0 + 4)
@ -111,21 +103,18 @@ static_assert(MMQ_MMA_TILE_X_K_Q4_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q5_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q5_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
#define MMQ_MMA_GET_TILE_X_K_BODY \
return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 : \
type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 : \
type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 : \
type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 : \
type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 : \
type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K : \
type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K : \
type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K : \
type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K : \
type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : \
0
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
MMQ_MMA_GET_TILE_X_K_BODY; return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 :
type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 :
type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 :
type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 :
type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K :
type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K :
type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
0;
} }
#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1) #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
@ -154,7 +143,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
int * x_qs = (int *) x_tile; int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + WARP_SIZE); float * x_df = (float *) (x_qs + WARP_SIZE);
#else #else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_0); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
int * x_qs = (int *) x_tile; int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs); float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE #endif // INT8_MMA_AVAILABLE
@ -204,7 +193,7 @@ template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_0); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
const int * x_qs = (const int *) x; const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + txs.qs; const float * x_df = (const float *) x_qs + txs.qs;
const int * y_qs = (const int *) y + 4; const int * y_qs = (const int *) y + 4;
@ -317,7 +306,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
int * x_qs = (int *) x_tile; int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + WARP_SIZE); half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
#else #else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_1); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
int * x_qs = (int *) x_tile; int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + txs.qs); half2 * x_dm = (half2 *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE #endif // INT8_MMA_AVAILABLE
@ -367,7 +356,7 @@ template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_1); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
const int * x_qs = (const int *) x; const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + txs.qs; const half2 * x_dm = (const half2 *) x_qs + txs.qs;
const int * y_qs = (const int *) y + 4; const int * y_qs = (const int *) y + 4;
@ -479,7 +468,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
int * x_qs = (int *) x_tile; int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + WARP_SIZE*2); float * x_df = (float *) (x_qs + WARP_SIZE*2);
#else #else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_0); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
int * x_qs = (int *) x_tile; int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs); float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE #endif // INT8_MMA_AVAILABLE
@ -548,7 +537,7 @@ template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a( static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_0); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
const int * x_qs = (const int *) x; const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + txs.qs; const float * x_df = (const float *) x_qs + txs.qs;
const int * y_qs = (const int *) y + 4; const int * y_qs = (const int *) y + 4;
@ -644,7 +633,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
int * x_qs = (int *) x_tile; int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
#else #else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_1); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
int * x_qs = (int *) x_tile; int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + txs.qs); half2 * x_dm = (half2 *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE #endif // INT8_MMA_AVAILABLE
@ -711,7 +700,7 @@ template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a( static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_1); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
const int * x_qs = (const int *) x; const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + txs.qs; const half2 * x_dm = (const half2 *) x_qs + txs.qs;
const int * y_qs = (const int *) y + 4; const int * y_qs = (const int *) y + 4;
@ -808,7 +797,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
int * x_qs = (int *) x_tile; int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_tile + WARP_SIZE); float * x_df = (float *) (x_tile + WARP_SIZE);
#else #else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q8_0); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
int * x_qs = (int *) x_tile; int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs); float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE #endif // INT8_MMA_AVAILABLE
@ -858,7 +847,7 @@ template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q8_0); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
const int * x_qs = (const int *) x; const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + txs.qs; const float * x_df = (const float *) x_qs + txs.qs;
const int * y_qs = (const int *) y + 4; const int * y_qs = (const int *) y + 4;
@ -954,7 +943,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
int * x_qs = (int *) x_tile; int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + WARP_SIZE); half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
#else #else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q2_K); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
int * x_qs = (int *) x_tile; int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + txs.qs); half2 * x_dm = (half2 *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE #endif // INT8_MMA_AVAILABLE
@ -1013,7 +1002,7 @@ template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q2_K); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
const int * x_qs = (const int *) x; const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + txs.qs; const half2 * x_dm = (const half2 *) x_qs + txs.qs;
const int * y_qs = (const int *) y + 4; const int * y_qs = (const int *) y + 4;
@ -1135,7 +1124,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
float * x_df = (float *) (x_qs + WARP_SIZE*2); float * x_df = (float *) (x_qs + WARP_SIZE*2);
int * x_sc = (int *) (x_df + WARP_SIZE/QI3_K); int * x_sc = (int *) (x_df + WARP_SIZE/QI3_K);
#else #else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q3_K); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
int * x_qs = (int *) x_tile; int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs); float * x_df = (float *) (x_qs + txs.qs);
int * x_sc = (int *) (x_df + txs.dm); int * x_sc = (int *) (x_df + txs.dm);
@ -1233,7 +1222,7 @@ template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q3_K); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
const int * x_qs = (const int *) x; const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + txs.qs; const float * x_df = (const float *) x_qs + txs.qs;
const int * x_sc = (const int *) x_df + txs.dm; const int * x_sc = (const int *) x_df + txs.dm;
@ -1361,7 +1350,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
half2 * x_dm = (half2 *) (x_qs + WARP_SIZE); half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
int * x_sc = (int *) (x_dm + WARP_SIZE/QI4_K); int * x_sc = (int *) (x_dm + WARP_SIZE/QI4_K);
#else #else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_K); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
int * x_qs = (int *) x_tile; int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + txs.qs); half2 * x_dm = (half2 *) (x_qs + txs.qs);
int * x_sc = (int *) (x_dm + txs.dm); int * x_sc = (int *) (x_dm + txs.dm);
@ -1437,7 +1426,7 @@ template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_K); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
const int * x_qs = (const int *) x; const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + txs.qs; const half2 * x_dm = (const half2 *) x_qs + txs.qs;
const int * x_sc = (const int *) x_dm + txs.dm; const int * x_sc = (const int *) x_dm + txs.dm;
@ -1578,7 +1567,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2); half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
int * x_sc = (int *) (x_dm + WARP_SIZE/QI5_K); int * x_sc = (int *) (x_dm + WARP_SIZE/QI5_K);
#else #else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_K); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
int * x_qs = (int *) x_tile; int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + txs.qs); half2 * x_dm = (half2 *) (x_qs + txs.qs);
int * x_sc = (int *) (x_dm + txs.dm); int * x_sc = (int *) (x_dm + txs.dm);
@ -1668,7 +1657,7 @@ template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_K); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
const int * x_qs = (const int *) x; const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + txs.qs; const half2 * x_dm = (const half2 *) x_qs + txs.qs;
const int * x_sc = (const int *) x_dm + txs.dm; const int * x_sc = (const int *) x_dm + txs.dm;
@ -1800,7 +1789,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
float * x_df = (float *) (x_qs + WARP_SIZE*2); float * x_df = (float *) (x_qs + WARP_SIZE*2);
int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K); int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K);
#else #else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q6_K); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
int * x_qs = (int *) x_tile; int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs); float * x_df = (float *) (x_qs + txs.qs);
int * x_sc = (int *) (x_df + txs.dm); int * x_sc = (int *) (x_df + txs.dm);
@ -1882,7 +1871,7 @@ template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q6_K); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
const int * x_qs = (const int *) x; const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + txs.qs; const float * x_df = (const float *) x_qs + txs.qs;
const int * x_sc = (const int *) x_df + txs.dm; const int * x_sc = (const int *) x_df + txs.dm;
@ -2422,7 +2411,7 @@ struct mmq_args {
template<ggml_type type> template<ggml_type type>
static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) { static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) {
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_host(type, mmq_y); const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
const int shmem_y = mmq_x*sizeof(block_q8_1_mmq); const int shmem_y = mmq_x*sizeof(block_q8_1_mmq);