simplify code, make functions constexpr
This commit is contained in:
parent
cab5981951
commit
5db2131250
2 changed files with 46 additions and 57 deletions
|
@ -643,7 +643,7 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
|
|||
static constexpr int qi = QI3_S;
|
||||
};
|
||||
|
||||
static int get_mmq_x_max_host(const int cc) {
|
||||
static constexpr int get_mmq_x_max_host(int cc) {
|
||||
#ifdef CUDA_USE_TENSOR_CORES
|
||||
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64;
|
||||
#else
|
||||
|
@ -652,7 +652,7 @@ static int get_mmq_x_max_host(const int cc) {
|
|||
}
|
||||
|
||||
// Round rows to this value for --split-mode row:
|
||||
static int get_mmq_y_host(const int cc) {
|
||||
static constexpr int get_mmq_y_host(int cc) {
|
||||
return cc >= CC_VOLTA ? 128 : 64;
|
||||
}
|
||||
|
||||
|
|
|
@ -67,26 +67,18 @@ static constexpr __device__ int get_mmq_y_device() {
|
|||
#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
|
||||
#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
|
||||
|
||||
#define GET_MMQ_DP4A_TXS_BODY \
|
||||
return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 : \
|
||||
type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 : \
|
||||
type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 : \
|
||||
type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 : \
|
||||
type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 : \
|
||||
type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K : \
|
||||
type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K : \
|
||||
type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K : \
|
||||
type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K : \
|
||||
type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K : \
|
||||
tile_x_sizes{0, 0, 0}
|
||||
|
||||
static tile_x_sizes mmq_get_dp4a_tile_x_sizes_host(const ggml_type type, const int mmq_y) {
|
||||
GET_MMQ_DP4A_TXS_BODY;
|
||||
}
|
||||
|
||||
template <int mmq_y>
|
||||
static constexpr __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes_device(ggml_type type) {
|
||||
GET_MMQ_DP4A_TXS_BODY;
|
||||
static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
|
||||
return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
|
||||
type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
|
||||
type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 :
|
||||
type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 :
|
||||
type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
|
||||
type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
|
||||
type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
|
||||
type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
|
||||
type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
|
||||
type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
|
||||
tile_x_sizes{0, 0, 0};
|
||||
}
|
||||
|
||||
#define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0 + 4)
|
||||
|
@ -111,21 +103,18 @@ static_assert(MMQ_MMA_TILE_X_K_Q4_K % 8 == 4, "Wrong padding.");
|
|||
static_assert(MMQ_MMA_TILE_X_K_Q5_K % 8 == 4, "Wrong padding.");
|
||||
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
|
||||
|
||||
#define MMQ_MMA_GET_TILE_X_K_BODY \
|
||||
return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 : \
|
||||
type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 : \
|
||||
type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 : \
|
||||
type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 : \
|
||||
type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 : \
|
||||
type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K : \
|
||||
type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K : \
|
||||
type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K : \
|
||||
type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K : \
|
||||
type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : \
|
||||
0
|
||||
|
||||
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
||||
MMQ_MMA_GET_TILE_X_K_BODY;
|
||||
return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 :
|
||||
type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 :
|
||||
type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 :
|
||||
type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 :
|
||||
type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
|
||||
type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
|
||||
type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
|
||||
type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K :
|
||||
type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K :
|
||||
type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
|
||||
0;
|
||||
}
|
||||
|
||||
#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
|
||||
|
@ -154,7 +143,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
int * x_qs = (int *) x_tile;
|
||||
float * x_df = (float *) (x_qs + WARP_SIZE);
|
||||
#else
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_0);
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
|
||||
int * x_qs = (int *) x_tile;
|
||||
float * x_df = (float *) (x_qs + txs.qs);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
|
@ -204,7 +193,7 @@ template <int mmq_x, int mmq_y, int nwarps>
|
|||
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
|
||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_0);
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
|
||||
const int * x_qs = (const int *) x;
|
||||
const float * x_df = (const float *) x_qs + txs.qs;
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
|
@ -317,7 +306,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
int * x_qs = (int *) x_tile;
|
||||
half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
|
||||
#else
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_1);
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
|
||||
int * x_qs = (int *) x_tile;
|
||||
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
|
@ -367,7 +356,7 @@ template <int mmq_x, int mmq_y, int nwarps>
|
|||
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
|
||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_1);
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
|
||||
const int * x_qs = (const int *) x;
|
||||
const half2 * x_dm = (const half2 *) x_qs + txs.qs;
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
|
@ -479,7 +468,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
int * x_qs = (int *) x_tile;
|
||||
float * x_df = (float *) (x_qs + WARP_SIZE*2);
|
||||
#else
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_0);
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
|
||||
int * x_qs = (int *) x_tile;
|
||||
float * x_df = (float *) (x_qs + txs.qs);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
|
@ -548,7 +537,7 @@ template <int mmq_x, int mmq_y, int nwarps>
|
|||
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
|
||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_0);
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
|
||||
const int * x_qs = (const int *) x;
|
||||
const float * x_df = (const float *) x_qs + txs.qs;
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
|
@ -644,7 +633,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
int * x_qs = (int *) x_tile;
|
||||
half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
|
||||
#else
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_1);
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
|
||||
int * x_qs = (int *) x_tile;
|
||||
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
|
@ -711,7 +700,7 @@ template <int mmq_x, int mmq_y, int nwarps>
|
|||
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
|
||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_1);
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
|
||||
const int * x_qs = (const int *) x;
|
||||
const half2 * x_dm = (const half2 *) x_qs + txs.qs;
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
|
@ -808,7 +797,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
int * x_qs = (int *) x_tile;
|
||||
float * x_df = (float *) (x_tile + WARP_SIZE);
|
||||
#else
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q8_0);
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
|
||||
int * x_qs = (int *) x_tile;
|
||||
float * x_df = (float *) (x_qs + txs.qs);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
|
@ -858,7 +847,7 @@ template <int mmq_x, int mmq_y, int nwarps>
|
|||
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
|
||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q8_0);
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
|
||||
const int * x_qs = (const int *) x;
|
||||
const float * x_df = (const float *) x_qs + txs.qs;
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
|
@ -954,7 +943,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
int * x_qs = (int *) x_tile;
|
||||
half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
|
||||
#else
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q2_K);
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
|
||||
int * x_qs = (int *) x_tile;
|
||||
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
|
@ -1013,7 +1002,7 @@ template <int mmq_x, int mmq_y, int nwarps>
|
|||
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
|
||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q2_K);
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
|
||||
const int * x_qs = (const int *) x;
|
||||
const half2 * x_dm = (const half2 *) x_qs + txs.qs;
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
|
@ -1135,7 +1124,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
float * x_df = (float *) (x_qs + WARP_SIZE*2);
|
||||
int * x_sc = (int *) (x_df + WARP_SIZE/QI3_K);
|
||||
#else
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q3_K);
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
|
||||
int * x_qs = (int *) x_tile;
|
||||
float * x_df = (float *) (x_qs + txs.qs);
|
||||
int * x_sc = (int *) (x_df + txs.dm);
|
||||
|
@ -1233,7 +1222,7 @@ template <int mmq_x, int mmq_y, int nwarps>
|
|||
static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
|
||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q3_K);
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
|
||||
const int * x_qs = (const int *) x;
|
||||
const float * x_df = (const float *) x_qs + txs.qs;
|
||||
const int * x_sc = (const int *) x_df + txs.dm;
|
||||
|
@ -1361,7 +1350,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
|
||||
int * x_sc = (int *) (x_dm + WARP_SIZE/QI4_K);
|
||||
#else
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_K);
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
|
||||
int * x_qs = (int *) x_tile;
|
||||
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
||||
int * x_sc = (int *) (x_dm + txs.dm);
|
||||
|
@ -1437,7 +1426,7 @@ template <int mmq_x, int mmq_y, int nwarps>
|
|||
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
|
||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_K);
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
|
||||
const int * x_qs = (const int *) x;
|
||||
const half2 * x_dm = (const half2 *) x_qs + txs.qs;
|
||||
const int * x_sc = (const int *) x_dm + txs.dm;
|
||||
|
@ -1578,7 +1567,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
|
||||
int * x_sc = (int *) (x_dm + WARP_SIZE/QI5_K);
|
||||
#else
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_K);
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
|
||||
int * x_qs = (int *) x_tile;
|
||||
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
||||
int * x_sc = (int *) (x_dm + txs.dm);
|
||||
|
@ -1668,7 +1657,7 @@ template <int mmq_x, int mmq_y, int nwarps>
|
|||
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
|
||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_K);
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
|
||||
const int * x_qs = (const int *) x;
|
||||
const half2 * x_dm = (const half2 *) x_qs + txs.qs;
|
||||
const int * x_sc = (const int *) x_dm + txs.dm;
|
||||
|
@ -1800,7 +1789,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
float * x_df = (float *) (x_qs + WARP_SIZE*2);
|
||||
int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K);
|
||||
#else
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q6_K);
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
|
||||
int * x_qs = (int *) x_tile;
|
||||
float * x_df = (float *) (x_qs + txs.qs);
|
||||
int * x_sc = (int *) (x_df + txs.dm);
|
||||
|
@ -1882,7 +1871,7 @@ template <int mmq_x, int mmq_y, int nwarps>
|
|||
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
|
||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q6_K);
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
|
||||
const int * x_qs = (const int *) x;
|
||||
const float * x_df = (const float *) x_qs + txs.qs;
|
||||
const int * x_sc = (const int *) x_df + txs.dm;
|
||||
|
@ -2422,7 +2411,7 @@ struct mmq_args {
|
|||
|
||||
template<ggml_type type>
|
||||
static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) {
|
||||
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_host(type, mmq_y);
|
||||
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
|
||||
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
|
||||
const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
|
||||
const int shmem_y = mmq_x*sizeof(block_q8_1_mmq);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue