CUDA: MMQ support for iq4_nl, iq4_xs (#8278)
This commit is contained in:
parent
0a423800ff
commit
8e558309dc
7 changed files with 226 additions and 80 deletions
|
@ -92,15 +92,17 @@ static constexpr __device__ int get_mmq_y_device() {
|
|||
|
||||
static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
|
||||
return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
|
||||
type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
|
||||
type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 :
|
||||
type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 :
|
||||
type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
|
||||
type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
|
||||
type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
|
||||
type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
|
||||
type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
|
||||
type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
|
||||
type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
|
||||
type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 :
|
||||
type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 :
|
||||
type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
|
||||
type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
|
||||
type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
|
||||
type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
|
||||
type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
|
||||
type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
|
||||
type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q5_0 :
|
||||
type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q5_0 :
|
||||
tile_x_sizes{0, 0, 0};
|
||||
}
|
||||
|
||||
|
@ -128,15 +130,17 @@ static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
|
|||
|
||||
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
||||
return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 :
|
||||
type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 :
|
||||
type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 :
|
||||
type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 :
|
||||
type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
|
||||
type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
|
||||
type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
|
||||
type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K :
|
||||
type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K :
|
||||
type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
|
||||
type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 :
|
||||
type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 :
|
||||
type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 :
|
||||
type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
|
||||
type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
|
||||
type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
|
||||
type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K :
|
||||
type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K :
|
||||
type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
|
||||
type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q5_0 :
|
||||
type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q5_0 :
|
||||
0;
|
||||
}
|
||||
|
||||
|
@ -185,9 +189,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
|
||||
#else
|
||||
x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
|
||||
x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
|
@ -348,9 +352,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q4_1 + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q4_1 + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
|
||||
#else
|
||||
x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
|
||||
x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
|
@ -509,8 +513,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
|
||||
const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;
|
||||
|
||||
const int ql = get_int_from_uint8(bxi->qs, kqsx);
|
||||
const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));
|
||||
const int ql = get_int_b2(bxi->qs, kqsx);
|
||||
const int qh = get_int_b2(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));
|
||||
|
||||
int qs0 = (ql >> 0) & 0x0F0F0F0F;
|
||||
qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
|
||||
|
@ -674,8 +678,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
|
||||
const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;
|
||||
|
||||
const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
|
||||
const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));
|
||||
const int ql = get_int_b4(bxi->qs, kqsx);
|
||||
const int qh = get_int_b4(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));
|
||||
|
||||
int qs0 = (ql >> 0) & 0x0F0F0F0F;
|
||||
qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
|
||||
|
@ -839,9 +843,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
|
||||
#else
|
||||
x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
|
||||
x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
|
@ -984,7 +988,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
|
||||
const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx;
|
||||
|
||||
const int x_ql_0 = get_int_from_uint8(bxi->qs, kqsx);
|
||||
const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < QR2_K; ++l) {
|
||||
|
@ -1166,8 +1170,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
|
||||
const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx;
|
||||
|
||||
const int x_ql_0 = get_int_from_uint8(bxi->qs, kqsx);
|
||||
const int x_qh_0 = get_int_from_uint8(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));
|
||||
const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
|
||||
const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < QR3_K; ++l) {
|
||||
|
@ -1225,11 +1229,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
|
||||
const int ksc_low = ksc % (QI3_K/8);
|
||||
const int shift_low = 4 * (ksc / (QI3_K/8));
|
||||
const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
|
||||
const int sc_low = (get_int_b2(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
|
||||
|
||||
const int ksc_high = QI3_K/8;
|
||||
const int shift_high = 2 * ksc;
|
||||
const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
|
||||
const int sc_high = ((get_int_b2(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
|
||||
|
||||
const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
|
||||
|
||||
|
@ -1393,9 +1397,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx;
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q4_K + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q4_K + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
|
||||
#else
|
||||
x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
|
||||
x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
|
@ -1610,11 +1614,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbx;
|
||||
const int ky = QR5_K*kqsx;
|
||||
|
||||
const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
|
||||
const int ql = get_int_b4(bxi->qs, kqsx);
|
||||
const int ql0 = (ql >> 0) & 0x0F0F0F0F;
|
||||
const int ql1 = (ql >> 4) & 0x0F0F0F0F;
|
||||
|
||||
const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
|
||||
const int qh = get_int_b4(bxi->qh, kqsx % (QI5_K/4));
|
||||
const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
|
||||
const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;
|
||||
|
||||
|
@ -1832,11 +1836,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbx;
|
||||
const int ky = QR6_K*kqsx;
|
||||
|
||||
const int ql = get_int_from_uint8(bxi->ql, kqsx);
|
||||
const int ql = get_int_b2(bxi->ql, kqsx);
|
||||
const int ql0 = (ql >> 0) & 0x0F0F0F0F;
|
||||
const int ql1 = (ql >> 4) & 0x0F0F0F0F;
|
||||
|
||||
const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
|
||||
const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
|
||||
const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
|
||||
const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030;
|
||||
|
||||
|
@ -1883,9 +1887,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8));
|
||||
x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
|
||||
#else
|
||||
x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8));
|
||||
x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
@ -2018,6 +2022,124 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
|||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
|
||||
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
int * x_qs = (int *) x_tile;
|
||||
float * x_df = (float *) (x_qs + WARP_SIZE*2);
|
||||
#else
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
|
||||
int * x_qs = (int *) x_tile;
|
||||
float * x_df = (float *) (x_qs + txs.qs);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
|
||||
const int kbx = threadIdx.x / QI4_NL;
|
||||
const int kqsx = threadIdx.x % QI4_NL;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
|
||||
int i = i0 + threadIdx.y;
|
||||
|
||||
if (need_check) {
|
||||
i = min(i, i_max);
|
||||
}
|
||||
|
||||
const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;
|
||||
|
||||
const int aux_q4 = get_int_b2(bxi->qs, kqsx);
|
||||
const int2 v = get_int_from_table_16(aux_q4);
|
||||
const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 0] = v.x;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 4] = v.y;
|
||||
#else
|
||||
x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL;
|
||||
const int kbxd = threadIdx.x % blocks_per_tile_x_row;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_NL) {
|
||||
int i = i0 + threadIdx.y * QI4_NL + threadIdx.x / blocks_per_tile_x_row;
|
||||
|
||||
if (need_check) {
|
||||
i = min(i, i_max);
|
||||
}
|
||||
|
||||
const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + kbxd] = __half2float(bxi->d);
|
||||
#else
|
||||
x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
|
||||
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
int * x_qs = (int *) x_tile;
|
||||
float * x_df = (float *) (x_qs + WARP_SIZE*2);
|
||||
#else
|
||||
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
|
||||
int * x_qs = (int *) x_tile;
|
||||
float * x_df = (float *) (x_qs + txs.qs);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
|
||||
const int kbx = 0; // threadIdx.x / QI4_XS
|
||||
const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
|
||||
int i = i0 + threadIdx.y;
|
||||
|
||||
if (need_check) {
|
||||
i = min(i, i_max);
|
||||
}
|
||||
|
||||
const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride + kbx;
|
||||
|
||||
const int aux_q4 = get_int_b4(bxi->qs, kqsx);
|
||||
const int2 v = get_int_from_table_16(aux_q4);
|
||||
const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 0] = v.x;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 4] = v.y;
|
||||
#else
|
||||
x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
|
||||
int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
|
||||
|
||||
if (need_check) {
|
||||
i = min(i, i_max);
|
||||
}
|
||||
|
||||
const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
|
||||
|
||||
const float d = __half2float(bxi->d);
|
||||
|
||||
const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
|
||||
| (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + threadIdx.x % 8] = d * (ls - 32);
|
||||
#else
|
||||
x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
static __device__ __forceinline__ void mmq_write_back_dp4a(
|
||||
const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
|
||||
|
@ -2167,6 +2289,22 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
|
|||
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
||||
};
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
|
||||
static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, nwarps, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
||||
};
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
|
||||
static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, nwarps, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
||||
};
|
||||
|
||||
static bool mmq_need_sum(const ggml_type type_x) {
|
||||
switch (type_x) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
|
@ -2184,6 +2322,8 @@ static bool mmq_need_sum(const ggml_type type_x) {
|
|||
case GGML_TYPE_Q5_K:
|
||||
return true;
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
return false;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
|
@ -2608,6 +2748,8 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
|
|||
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue