Try IQ4_NL with blocks of 64 - does not look good
This commit is contained in:
parent
a33e6a0d2a
commit
67264b3b30
8 changed files with 339 additions and 56 deletions
|
@ -36,7 +36,8 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
|
||||||
{ "Q3_K_S", LLAMA_FTYPE_MOSTLY_Q3_K_S, " 2.75G, +0.5551 ppl @ LLaMA-v1-7B", },
|
{ "Q3_K_S", LLAMA_FTYPE_MOSTLY_Q3_K_S, " 2.75G, +0.5551 ppl @ LLaMA-v1-7B", },
|
||||||
{ "Q3_K_M", LLAMA_FTYPE_MOSTLY_Q3_K_M, " 3.07G, +0.2496 ppl @ LLaMA-v1-7B", },
|
{ "Q3_K_M", LLAMA_FTYPE_MOSTLY_Q3_K_M, " 3.07G, +0.2496 ppl @ LLaMA-v1-7B", },
|
||||||
{ "Q3_K_L", LLAMA_FTYPE_MOSTLY_Q3_K_L, " 3.35G, +0.1764 ppl @ LLaMA-v1-7B", },
|
{ "Q3_K_L", LLAMA_FTYPE_MOSTLY_Q3_K_L, " 3.35G, +0.1764 ppl @ LLaMA-v1-7B", },
|
||||||
{ "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL, " 4.25 bpw non-linear quantization", },
|
{ "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL, " 4.5 bpw non-linear quantization", },
|
||||||
|
{ "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.125 bpw non-linear quantization", },
|
||||||
{ "Q4_K", LLAMA_FTYPE_MOSTLY_Q4_K_M, "alias for Q4_K_M", },
|
{ "Q4_K", LLAMA_FTYPE_MOSTLY_Q4_K_M, "alias for Q4_K_M", },
|
||||||
{ "Q4_K_S", LLAMA_FTYPE_MOSTLY_Q4_K_S, " 3.59G, +0.0992 ppl @ LLaMA-v1-7B", },
|
{ "Q4_K_S", LLAMA_FTYPE_MOSTLY_Q4_K_S, " 3.59G, +0.0992 ppl @ LLaMA-v1-7B", },
|
||||||
{ "Q4_K_M", LLAMA_FTYPE_MOSTLY_Q4_K_M, " 3.80G, +0.0532 ppl @ LLaMA-v1-7B", },
|
{ "Q4_K_M", LLAMA_FTYPE_MOSTLY_Q4_K_M, " 3.80G, +0.0532 ppl @ LLaMA-v1-7B", },
|
||||||
|
|
87
ggml-cuda.cu
87
ggml-cuda.cu
|
@ -571,6 +571,16 @@ typedef struct {
|
||||||
} block_iq4_nl;
|
} block_iq4_nl;
|
||||||
static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");
|
static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");
|
||||||
|
|
||||||
|
#define QK4_XS 64
|
||||||
|
#define QR4_XS 2
|
||||||
|
#define QI4_XS (QK4_XS / (4*QR4_XS))
|
||||||
|
typedef struct {
|
||||||
|
half d;
|
||||||
|
uint8_t qs[QK4_XS/2];
|
||||||
|
} block_iq4_xs;
|
||||||
|
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + QK4_XS/2, "wrong iq4_xs block size/padding");
|
||||||
|
|
||||||
|
|
||||||
#define WARP_SIZE 32
|
#define WARP_SIZE 32
|
||||||
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
|
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
|
||||||
|
|
||||||
|
@ -2425,6 +2435,30 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<typename dst_t>
|
||||||
|
static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||||
|
|
||||||
|
const int i = blockIdx.x;
|
||||||
|
//const block_iq4_xs * x = (const block_iq4_xs *) vx + i*(QK_K/QK4_XS);
|
||||||
|
const block_iq4_xs * x = (const block_iq4_xs *)vx;
|
||||||
|
|
||||||
|
const int tid = threadIdx.x;
|
||||||
|
const int il = tid/8; // 0...3
|
||||||
|
const int ib = tid%8; // 0...7
|
||||||
|
const int ib32 = i*(QK_K/32) + ib;
|
||||||
|
dst_t * y = yy + 32*ib32 + 4*il;
|
||||||
|
const uint8_t * q4 = x[ib32/(QK4_XS/32)].qs + 16*(ib32%(QK4_XS/32)) + 4*il;
|
||||||
|
const float d = (float)x[ib32/(QK4_XS/32)].d;
|
||||||
|
//dst_t * y = yy + i*QK_K + 32*ib + 4*il;
|
||||||
|
//const uint8_t * q4 = x->qs + 16*(ib%(QK4_XS/32)) + 4*il;
|
||||||
|
//const float d = (float)x->d;
|
||||||
|
for (int j = 0; j < 4; ++j) {
|
||||||
|
y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
|
||||||
|
y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
|
static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
|
||||||
|
|
||||||
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
|
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
|
||||||
|
@ -5302,6 +5336,41 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
|
||||||
return d * (sumi1 + sumi2);
|
return d * (sumi1 + sumi2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO
|
||||||
|
static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
|
||||||
|
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
|
||||||
|
|
||||||
|
const block_iq4_xs * bq = (const block_iq4_xs *) vbq;
|
||||||
|
|
||||||
|
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
||||||
|
const uint16_t * q4 = (const uint16_t *)bq->qs + 2*iqs;
|
||||||
|
const int32_t * q8 = (const int32_t *)bq8_1->qs + iqs;
|
||||||
|
|
||||||
|
const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
|
||||||
|
|
||||||
|
int v1, v2;
|
||||||
|
int sumi1 = 0, sumi2 = 0;
|
||||||
|
for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
|
||||||
|
const uint32_t aux = q4[2*l] | (q4[2*l+1] << 16);
|
||||||
|
get_int_from_table_16(aux, values, v1, v2);
|
||||||
|
sumi1 = __dp4a(v1, q8[l+0], sumi1);
|
||||||
|
sumi2 = __dp4a(v2, q8[l+4], sumi2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
const uint8_t * q4 = bq->qs + 4*iqs;
|
||||||
|
const int8_t * q8 = bq8_1->qs + 4*iqs;
|
||||||
|
|
||||||
|
int sumi1 = 0, sumi2 = 0;
|
||||||
|
for (int l = 0; l < 4*VDR_Q4_0_Q8_1_MMVQ; ++l) {
|
||||||
|
sumi1 += q8[l+ 0] * kvalues_iq4nl[q4[l] & 0xf];
|
||||||
|
sumi2 += q8[l+16] * kvalues_iq4nl[q4[l] >> 4];
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
const float d = (float)bq->d * __low2float(bq8_1->ds);
|
||||||
|
return d * (sumi1 + sumi2);
|
||||||
|
}
|
||||||
|
|
||||||
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
|
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
|
||||||
allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
|
allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
|
||||||
static __device__ __forceinline__ void mul_mat_q(
|
static __device__ __forceinline__ void mul_mat_q(
|
||||||
|
@ -7365,6 +7434,12 @@ static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k,
|
||||||
dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
|
dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<typename dst_t>
|
||||||
|
static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||||
|
const int nb = (k + QK_K - 1) / QK_K;
|
||||||
|
dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename src_t, typename dst_t>
|
template <typename src_t, typename dst_t>
|
||||||
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
|
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
|
||||||
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
|
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
|
||||||
|
@ -7410,6 +7485,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
||||||
return dequantize_row_iq1_s_cuda;
|
return dequantize_row_iq1_s_cuda;
|
||||||
case GGML_TYPE_IQ4_NL:
|
case GGML_TYPE_IQ4_NL:
|
||||||
return dequantize_row_iq4_nl_cuda;
|
return dequantize_row_iq4_nl_cuda;
|
||||||
|
case GGML_TYPE_IQ4_XS:
|
||||||
|
return dequantize_row_iq4_xs_cuda;
|
||||||
case GGML_TYPE_IQ3_S:
|
case GGML_TYPE_IQ3_S:
|
||||||
return dequantize_row_iq3_s_cuda;
|
return dequantize_row_iq3_s_cuda;
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
|
@ -7453,6 +7530,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
||||||
return dequantize_row_iq1_s_cuda;
|
return dequantize_row_iq1_s_cuda;
|
||||||
case GGML_TYPE_IQ4_NL:
|
case GGML_TYPE_IQ4_NL:
|
||||||
return dequantize_row_iq4_nl_cuda;
|
return dequantize_row_iq4_nl_cuda;
|
||||||
|
case GGML_TYPE_IQ4_XS:
|
||||||
|
return dequantize_row_iq4_xs_cuda;
|
||||||
case GGML_TYPE_IQ3_S:
|
case GGML_TYPE_IQ3_S:
|
||||||
return dequantize_row_iq3_s_cuda;
|
return dequantize_row_iq3_s_cuda;
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
|
@ -9201,6 +9280,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
|
||||||
case GGML_TYPE_IQ3_XXS:
|
case GGML_TYPE_IQ3_XXS:
|
||||||
case GGML_TYPE_IQ1_S:
|
case GGML_TYPE_IQ1_S:
|
||||||
case GGML_TYPE_IQ4_NL:
|
case GGML_TYPE_IQ4_NL:
|
||||||
|
case GGML_TYPE_IQ4_XS:
|
||||||
case GGML_TYPE_IQ3_S:
|
case GGML_TYPE_IQ3_S:
|
||||||
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
|
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
|
||||||
default:
|
default:
|
||||||
|
@ -9228,6 +9308,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
|
||||||
case GGML_TYPE_IQ3_XXS:
|
case GGML_TYPE_IQ3_XXS:
|
||||||
case GGML_TYPE_IQ1_S:
|
case GGML_TYPE_IQ1_S:
|
||||||
case GGML_TYPE_IQ4_NL:
|
case GGML_TYPE_IQ4_NL:
|
||||||
|
case GGML_TYPE_IQ4_XS:
|
||||||
case GGML_TYPE_IQ3_S:
|
case GGML_TYPE_IQ3_S:
|
||||||
return max_compute_capability >= CC_VOLTA ? 128 : 64;
|
return max_compute_capability >= CC_VOLTA ? 128 : 64;
|
||||||
case GGML_TYPE_Q6_K:
|
case GGML_TYPE_Q6_K:
|
||||||
|
@ -9338,6 +9419,10 @@ static void ggml_cuda_op_mul_mat_vec_q(
|
||||||
mul_mat_vec_q_cuda<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
|
mul_mat_vec_q_cuda<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
|
||||||
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||||
break;
|
break;
|
||||||
|
case GGML_TYPE_IQ4_XS:
|
||||||
|
mul_mat_vec_q_cuda<QK4_XS, QI4_XS, block_iq4_xs, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_xs_q8_1>
|
||||||
|
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||||
|
break;
|
||||||
case GGML_TYPE_IQ3_S:
|
case GGML_TYPE_IQ3_S:
|
||||||
mul_mat_vec_q_cuda<QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
|
mul_mat_vec_q_cuda<QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
|
||||||
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||||
|
@ -12066,7 +12151,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||||
ggml_type a_type = a->type;
|
ggml_type a_type = a->type;
|
||||||
if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
|
if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
|
||||||
a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S ||
|
a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S ||
|
||||||
a_type == GGML_TYPE_IQ2_S) {
|
a_type == GGML_TYPE_IQ2_S || a_type == GGML_TYPE_IQ4_XS) {
|
||||||
if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
|
if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
250
ggml-quants.c
250
ggml-quants.c
|
@ -4225,6 +4225,27 @@ void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, int k) {
|
||||||
|
assert(k % QK4_XS == 0);
|
||||||
|
const int nb = k / QK4_XS;
|
||||||
|
|
||||||
|
for (int i = 0; i < nb; i++) {
|
||||||
|
|
||||||
|
const uint8_t * qs = x[i].qs;
|
||||||
|
|
||||||
|
const float d = GGML_FP16_TO_FP32(x[i].d);
|
||||||
|
|
||||||
|
for (int iq = 0; iq < QK4_XS/32; ++iq) {
|
||||||
|
for (int j = 0; j < 16; ++j) {
|
||||||
|
y[j+ 0] = d * kvalues_iq4nl[qs[j] & 0xf];
|
||||||
|
y[j+16] = d * kvalues_iq4nl[qs[j] >> 4];
|
||||||
|
}
|
||||||
|
y += 32;
|
||||||
|
qs += 16;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//===================================== Q8_K ==============================================
|
//===================================== Q8_K ==============================================
|
||||||
|
|
||||||
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
|
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
|
||||||
|
@ -10421,6 +10442,107 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_vec_dot_iq4_xs_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
||||||
|
assert(nrc == 1);
|
||||||
|
UNUSED(nrc);
|
||||||
|
UNUSED(bx);
|
||||||
|
UNUSED(by);
|
||||||
|
UNUSED(bs);
|
||||||
|
assert(n % QK4_XS == 0);
|
||||||
|
static_assert(QK8_0 == 32, "QK8_0 must be 32");
|
||||||
|
|
||||||
|
const block_iq4_xs * restrict x = vx;
|
||||||
|
const block_q8_0 * restrict y = vy;
|
||||||
|
|
||||||
|
const int nb = n / QK4_XS;
|
||||||
|
|
||||||
|
#if defined z__ARM_NEON
|
||||||
|
const int8x16_t values = vld1q_s8(kvalues_iq4nl);
|
||||||
|
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||||
|
uint8x16x2_t q4bits;
|
||||||
|
int8x16x4_t q4b;
|
||||||
|
int8x16x4_t q8b;
|
||||||
|
int32x4_t prod_1, prod_2;
|
||||||
|
|
||||||
|
float sumf = 0;
|
||||||
|
|
||||||
|
for (int ib = 0; ib < nb; ib += 2) {
|
||||||
|
|
||||||
|
q4bits.val[0] = vld1q_u8(x[ib+0].qs);
|
||||||
|
q4bits.val[1] = vld1q_u8(x[ib+1].qs);
|
||||||
|
q8b.val[0] = vld1q_s8(y[ib+0].qs);
|
||||||
|
q8b.val[1] = vld1q_s8(y[ib+0].qs + 16);
|
||||||
|
q8b.val[2] = vld1q_s8(y[ib+1].qs);
|
||||||
|
q8b.val[3] = vld1q_s8(y[ib+1].qs + 16);
|
||||||
|
|
||||||
|
q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b));
|
||||||
|
q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
|
||||||
|
q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b));
|
||||||
|
q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
|
||||||
|
|
||||||
|
prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
|
||||||
|
prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
|
||||||
|
|
||||||
|
sumf +=
|
||||||
|
GGML_FP16_TO_FP32(x[ib+0].d) * GGML_FP16_TO_FP32(y[ib+0].d) * vaddvq_s32(prod_1) +
|
||||||
|
GGML_FP16_TO_FP32(x[ib+1].d) * GGML_FP16_TO_FP32(y[ib+1].d) * vaddvq_s32(prod_2);
|
||||||
|
}
|
||||||
|
|
||||||
|
*s = sumf;
|
||||||
|
|
||||||
|
#elif defined z__AVX2__
|
||||||
|
|
||||||
|
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
|
||||||
|
const __m128i m4b = _mm_set1_epi8(0x0f);
|
||||||
|
const __m256i mone = _mm256_set1_epi16(1);
|
||||||
|
|
||||||
|
__m256 accum1 = _mm256_setzero_ps();
|
||||||
|
__m256 accum2 = _mm256_setzero_ps();
|
||||||
|
for (int ib = 0; ib < nb; ib += 2) {
|
||||||
|
const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[0].qs);
|
||||||
|
const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[1].qs);
|
||||||
|
const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[0].qs);
|
||||||
|
const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[1].qs);
|
||||||
|
const __m256i q4b_1 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
|
||||||
|
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
|
||||||
|
const __m256i q4b_2 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
|
||||||
|
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
|
||||||
|
const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
|
||||||
|
const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
|
||||||
|
const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
|
||||||
|
const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
|
||||||
|
accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[0].d)*GGML_FP16_TO_FP32(x[0].d)),
|
||||||
|
_mm256_cvtepi32_ps(p_1), accum1);
|
||||||
|
accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[1].d)*GGML_FP16_TO_FP32(x[1].d)),
|
||||||
|
_mm256_cvtepi32_ps(p_2), accum2);
|
||||||
|
|
||||||
|
y += 2;
|
||||||
|
x += 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
*s = hsum_float_8(_mm256_add_ps(accum1, accum2));
|
||||||
|
|
||||||
|
#else
|
||||||
|
float sumf = 0;
|
||||||
|
for (int ib = 0; ib < nb; ++ib) {
|
||||||
|
const float d4 = GGML_FP16_TO_FP32(x[ib].d);
|
||||||
|
const uint8_t * qs = x[ib].qs;
|
||||||
|
for (int iq = 0; iq < QK4_XS/32; ++iq) {
|
||||||
|
const float d = GGML_FP16_TO_FP32(y[iq].d)*d4;
|
||||||
|
int sumi1 = 0, sumi2 = 0;
|
||||||
|
for (int j = 0; j < 16; ++j) {
|
||||||
|
sumi1 += y[iq].qs[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
|
||||||
|
sumi2 += y[iq].qs[j+16] * kvalues_iq4nl[qs[j] >> 4];
|
||||||
|
}
|
||||||
|
sumf += d * (sumi1 + sumi2);
|
||||||
|
qs += 16;
|
||||||
|
}
|
||||||
|
y += QK4_XS/32;
|
||||||
|
}
|
||||||
|
*s = sumf;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
// ================================ IQ2 quantization =============================================
|
// ================================ IQ2 quantization =============================================
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
|
@ -12017,7 +12139,7 @@ static inline int best_index_int8(int n, const int8_t * val, float x) {
|
||||||
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
|
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void quantize_row_iq4_nl_impl(const int block_size, const float * GGML_RESTRICT x,
|
static void quantize_row_iq4_nl_impl(const int block_size, const float * GGML_RESTRICT xb,
|
||||||
ggml_fp16_t * dh, uint8_t * q4,
|
ggml_fp16_t * dh, uint8_t * q4,
|
||||||
float * weight, uint8_t * L,
|
float * weight, uint8_t * L,
|
||||||
const int8_t * values,
|
const int8_t * values,
|
||||||
|
@ -12026,34 +12148,44 @@ static void quantize_row_iq4_nl_impl(const int block_size, const float * GGML_RE
|
||||||
const int ntry = 7;
|
const int ntry = 7;
|
||||||
|
|
||||||
float sigma2 = 0;
|
float sigma2 = 0;
|
||||||
for (int j = 0; j < QK4_NL; ++j) sigma2 += x[j]*x[j];
|
for (int j = 0; j < block_size; ++j) sigma2 += xb[j]*xb[j];
|
||||||
sigma2 *= 2.f/QK4_NL;
|
sigma2 *= 2.f/block_size;
|
||||||
|
|
||||||
const int nb = QK4_NL/block_size;
|
memset(q4, 0, block_size/2);
|
||||||
|
|
||||||
memset(q4, 0, QK4_NL/2);
|
dh[0] = GGML_FP32_TO_FP16(0.f);
|
||||||
for (int ib = 0; ib < nb; ++ib) {
|
if (quant_weights) {
|
||||||
dh[ib] = GGML_FP32_TO_FP16(0.f);
|
for (int j = 0; j < block_size; ++j) weight[j] = quant_weights[j] * sqrtf(sigma2 + xb[j]*xb[j]);
|
||||||
const float * xb = x + ib*block_size;
|
} else {
|
||||||
if (quant_weights) {
|
for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j];
|
||||||
const float * qw = quant_weights + ib*block_size;
|
}
|
||||||
for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
|
float amax = 0, max = 0;
|
||||||
} else {
|
for (int j = 0; j < block_size; ++j) {
|
||||||
for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j];
|
float ax = fabsf(xb[j]);
|
||||||
|
if (ax > amax) {
|
||||||
|
amax = ax; max = xb[j];
|
||||||
}
|
}
|
||||||
float amax = 0, max = 0;
|
}
|
||||||
for (int j = 0; j < block_size; ++j) {
|
if (!amax) {
|
||||||
float ax = fabsf(xb[j]);
|
return;
|
||||||
if (ax > amax) {
|
}
|
||||||
amax = ax; max = xb[j];
|
float d = -max/values[0];
|
||||||
}
|
float id = 1/d;
|
||||||
}
|
float sumqx = 0, sumq2 = 0;
|
||||||
if (!amax) {
|
for (int j = 0; j < block_size; ++j) {
|
||||||
continue;
|
float al = id*xb[j];
|
||||||
}
|
int l = best_index_int8(16, values, al);
|
||||||
float d = -max/values[0];
|
float q = values[l];
|
||||||
float id = 1/d;
|
float w = weight[j];
|
||||||
float sumqx = 0, sumq2 = 0;
|
sumqx += w*q*xb[j];
|
||||||
|
sumq2 += w*q*q;
|
||||||
|
}
|
||||||
|
float best_id = id;
|
||||||
|
d = sumqx/sumq2;
|
||||||
|
float best = d*sumqx;
|
||||||
|
for (int itry = -ntry; itry <= ntry; ++itry) {
|
||||||
|
id = (itry + values[0])/max;
|
||||||
|
sumqx = sumq2 = 0;
|
||||||
for (int j = 0; j < block_size; ++j) {
|
for (int j = 0; j < block_size; ++j) {
|
||||||
float al = id*xb[j];
|
float al = id*xb[j];
|
||||||
int l = best_index_int8(16, values, al);
|
int l = best_index_int8(16, values, al);
|
||||||
|
@ -12062,31 +12194,17 @@ static void quantize_row_iq4_nl_impl(const int block_size, const float * GGML_RE
|
||||||
sumqx += w*q*xb[j];
|
sumqx += w*q*xb[j];
|
||||||
sumq2 += w*q*q;
|
sumq2 += w*q*q;
|
||||||
}
|
}
|
||||||
float best_id = id;
|
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
|
||||||
d = sumqx/sumq2;
|
d = sumqx/sumq2; best = d * sumqx;
|
||||||
float best = d*sumqx;
|
best_id = id;
|
||||||
for (int itry = -ntry; itry <= ntry; ++itry) {
|
|
||||||
id = (itry + values[0])/max;
|
|
||||||
sumqx = sumq2 = 0;
|
|
||||||
for (int j = 0; j < block_size; ++j) {
|
|
||||||
float al = id*xb[j];
|
|
||||||
int l = best_index_int8(16, values, al);
|
|
||||||
float q = values[l];
|
|
||||||
float w = weight[j];
|
|
||||||
sumqx += w*q*xb[j];
|
|
||||||
sumq2 += w*q*q;
|
|
||||||
}
|
|
||||||
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
|
|
||||||
d = sumqx/sumq2; best = d * sumqx;
|
|
||||||
best_id = id;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
dh[ib] = GGML_FP32_TO_FP16(d);
|
|
||||||
for (int j = 0; j < block_size; ++j) {
|
|
||||||
L[ib*block_size + j] = best_index_int8(16, values, best_id*xb[j]);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (int i = 0; i < QK4_NL/32; ++i) {
|
dh[0] = GGML_FP32_TO_FP16(d);
|
||||||
|
for (int j = 0; j < block_size; ++j) {
|
||||||
|
L[j] = best_index_int8(16, values, best_id*xb[j]);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < block_size/32; ++i) {
|
||||||
for (int j = 0; j < 16; ++j) {
|
for (int j = 0; j < 16; ++j) {
|
||||||
q4[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4);
|
q4[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4);
|
||||||
}
|
}
|
||||||
|
@ -12099,12 +12217,12 @@ size_t quantize_iq4_nl(const float * src, void * dst, int nrow, int n_per_row, i
|
||||||
int nblock = n_per_row/QK4_NL;
|
int nblock = n_per_row/QK4_NL;
|
||||||
char * qrow = (char *)dst;
|
char * qrow = (char *)dst;
|
||||||
uint8_t L[QK4_NL];
|
uint8_t L[QK4_NL];
|
||||||
float weight[32];
|
float weight[QK4_NL];
|
||||||
for (int row = 0; row < nrow; ++row) {
|
for (int row = 0; row < nrow; ++row) {
|
||||||
block_iq4_nl * iq4 = (block_iq4_nl *)qrow;
|
block_iq4_nl * iq4 = (block_iq4_nl *)qrow;
|
||||||
for (int ibl = 0; ibl < nblock; ++ibl) {
|
for (int ibl = 0; ibl < nblock; ++ibl) {
|
||||||
const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL;
|
const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL;
|
||||||
quantize_row_iq4_nl_impl(32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, weight, L, kvalues_iq4nl, qw);
|
quantize_row_iq4_nl_impl(QK4_NL, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, weight, L, kvalues_iq4nl, qw);
|
||||||
}
|
}
|
||||||
src += n_per_row;
|
src += n_per_row;
|
||||||
qrow += nblock*sizeof(block_iq4_nl);
|
qrow += nblock*sizeof(block_iq4_nl);
|
||||||
|
@ -12123,6 +12241,36 @@ void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * rest
|
||||||
quantize_iq4_nl(x, y, 1, k, NULL, NULL);
|
quantize_iq4_nl(x, y, 1, k, NULL, NULL);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t quantize_iq4_xs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
|
||||||
|
(void)hist;
|
||||||
|
GGML_ASSERT(n_per_row%QK4_XS == 0);
|
||||||
|
int nblock = n_per_row/QK4_XS;
|
||||||
|
char * qrow = (char *)dst;
|
||||||
|
uint8_t L[QK4_XS];
|
||||||
|
float weight[QK4_XS];
|
||||||
|
for (int row = 0; row < nrow; ++row) {
|
||||||
|
block_iq4_xs * iq4 = (block_iq4_xs *)qrow;
|
||||||
|
for (int ibl = 0; ibl < nblock; ++ibl) {
|
||||||
|
const float * qw = quant_weights ? quant_weights + QK4_XS*ibl : NULL;
|
||||||
|
quantize_row_iq4_nl_impl(QK4_XS, src + QK4_XS*ibl, &iq4[ibl].d, iq4[ibl].qs, weight, L, kvalues_iq4nl, qw);
|
||||||
|
}
|
||||||
|
src += n_per_row;
|
||||||
|
qrow += nblock*sizeof(block_iq4_xs);
|
||||||
|
}
|
||||||
|
return nrow * nblock * sizeof(block_iq4_xs);
|
||||||
|
}
|
||||||
|
|
||||||
|
void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int k) {
|
||||||
|
assert(k % QK4_XS == 0);
|
||||||
|
block_iq4_xs * restrict y = vy;
|
||||||
|
quantize_row_iq4_xs_reference(x, y, k);
|
||||||
|
}
|
||||||
|
|
||||||
|
void quantize_row_iq4_xs_reference(const float * restrict x, block_iq4_xs * restrict y, int k) {
|
||||||
|
assert(k % QK4_XS == 0);
|
||||||
|
quantize_iq4_xs(x, y, 1, k, NULL, NULL);
|
||||||
|
}
|
||||||
|
|
||||||
// =============================== 2.5625 bpw
|
// =============================== 2.5625 bpw
|
||||||
|
|
||||||
static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
|
static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
|
||||||
|
|
|
@ -230,6 +230,14 @@ typedef struct {
|
||||||
} block_iq4_nl;
|
} block_iq4_nl;
|
||||||
static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");
|
static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");
|
||||||
|
|
||||||
|
#define QK4_XS 64
|
||||||
|
typedef struct {
|
||||||
|
ggml_fp16_t d;
|
||||||
|
uint8_t qs[QK4_XS/2];
|
||||||
|
} block_iq4_xs;
|
||||||
|
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + QK4_XS/2, "wrong iq4_xs block size/padding");
|
||||||
|
static_assert(QK4_XS%32 == 0, "QK4_XS must be a multiple of 32");
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
|
@ -250,6 +258,7 @@ void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGM
|
||||||
void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int k);
|
void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int k);
|
||||||
void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int k);
|
void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int k);
|
||||||
void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int k);
|
void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int k);
|
||||||
|
void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int k);
|
||||||
void quantize_row_iq3_s_reference (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int k);
|
void quantize_row_iq3_s_reference (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int k);
|
||||||
void quantize_row_iq2_s_reference (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int k);
|
void quantize_row_iq2_s_reference (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int k);
|
||||||
|
|
||||||
|
@ -268,6 +277,7 @@ void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
|
||||||
void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
||||||
void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
||||||
void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
||||||
|
void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
||||||
void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
||||||
void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
||||||
|
|
||||||
|
@ -291,6 +301,7 @@ void dequantize_row_iq2_s (const block_iq2_s * GGML_RESTRICT x, float * GGML_
|
||||||
void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
||||||
void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
||||||
void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
||||||
|
void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
||||||
void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
||||||
|
|
||||||
// Dot product
|
// Dot product
|
||||||
|
@ -311,6 +322,7 @@ void ggml_vec_dot_iq2_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
|
||||||
void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||||
void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||||
void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||||
|
void ggml_vec_dot_iq4_xs_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||||
void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||||
|
|
||||||
//
|
//
|
||||||
|
@ -322,6 +334,7 @@ size_t quantize_iq2_s (const float * src, void * dst, int nrows, int n_per_row,
|
||||||
size_t quantize_iq3_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
size_t quantize_iq3_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
||||||
size_t quantize_iq1_s (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
size_t quantize_iq1_s (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
||||||
size_t quantize_iq4_nl (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
size_t quantize_iq4_nl (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
||||||
|
size_t quantize_iq4_xs (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
||||||
size_t quantize_iq3_s (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
size_t quantize_iq3_s (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
||||||
size_t quantize_q2_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
size_t quantize_q2_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
||||||
size_t quantize_q3_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
size_t quantize_q3_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
||||||
|
|
30
ggml.c
30
ggml.c
|
@ -726,6 +726,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||||
.nrows = 1,
|
.nrows = 1,
|
||||||
},
|
},
|
||||||
|
[GGML_TYPE_IQ4_XS] = {
|
||||||
|
.type_name = "iq4_xs",
|
||||||
|
.blck_size = QK4_XS,
|
||||||
|
.type_size = sizeof(block_iq4_xs),
|
||||||
|
.is_quantized = true,
|
||||||
|
.to_float = (ggml_to_float_t) dequantize_row_iq4_xs,
|
||||||
|
.from_float = quantize_row_iq4_xs,
|
||||||
|
.from_float_reference = (ggml_from_float_t)quantize_row_iq4_xs_reference,
|
||||||
|
.vec_dot = ggml_vec_dot_iq4_xs_q8_0,
|
||||||
|
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||||
|
.nrows = 1,
|
||||||
|
},
|
||||||
[GGML_TYPE_Q8_K] = {
|
[GGML_TYPE_Q8_K] = {
|
||||||
.type_name = "q8_K",
|
.type_name = "q8_K",
|
||||||
.blck_size = QK_K,
|
.blck_size = QK_K,
|
||||||
|
@ -2328,6 +2340,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
|
||||||
case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break;
|
case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break;
|
||||||
case GGML_FTYPE_MOSTLY_IQ1_S: wtype = GGML_TYPE_IQ1_S; break;
|
case GGML_FTYPE_MOSTLY_IQ1_S: wtype = GGML_TYPE_IQ1_S; break;
|
||||||
case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break;
|
case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break;
|
||||||
|
case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break;
|
||||||
case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break;
|
case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break;
|
||||||
case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break;
|
case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break;
|
||||||
case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
|
case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
|
||||||
|
@ -7764,6 +7777,7 @@ static void ggml_compute_forward_add(
|
||||||
case GGML_TYPE_IQ3_XXS:
|
case GGML_TYPE_IQ3_XXS:
|
||||||
case GGML_TYPE_IQ1_S:
|
case GGML_TYPE_IQ1_S:
|
||||||
case GGML_TYPE_IQ4_NL:
|
case GGML_TYPE_IQ4_NL:
|
||||||
|
case GGML_TYPE_IQ4_XS:
|
||||||
case GGML_TYPE_IQ3_S:
|
case GGML_TYPE_IQ3_S:
|
||||||
case GGML_TYPE_IQ2_S:
|
case GGML_TYPE_IQ2_S:
|
||||||
{
|
{
|
||||||
|
@ -8045,6 +8059,7 @@ static void ggml_compute_forward_add1(
|
||||||
case GGML_TYPE_IQ3_XXS:
|
case GGML_TYPE_IQ3_XXS:
|
||||||
case GGML_TYPE_IQ1_S:
|
case GGML_TYPE_IQ1_S:
|
||||||
case GGML_TYPE_IQ4_NL:
|
case GGML_TYPE_IQ4_NL:
|
||||||
|
case GGML_TYPE_IQ4_XS:
|
||||||
case GGML_TYPE_IQ3_S:
|
case GGML_TYPE_IQ3_S:
|
||||||
case GGML_TYPE_IQ2_S:
|
case GGML_TYPE_IQ2_S:
|
||||||
{
|
{
|
||||||
|
@ -8171,6 +8186,7 @@ static void ggml_compute_forward_acc(
|
||||||
case GGML_TYPE_IQ3_XXS:
|
case GGML_TYPE_IQ3_XXS:
|
||||||
case GGML_TYPE_IQ1_S:
|
case GGML_TYPE_IQ1_S:
|
||||||
case GGML_TYPE_IQ4_NL:
|
case GGML_TYPE_IQ4_NL:
|
||||||
|
case GGML_TYPE_IQ4_XS:
|
||||||
case GGML_TYPE_IQ3_S:
|
case GGML_TYPE_IQ3_S:
|
||||||
case GGML_TYPE_IQ2_S:
|
case GGML_TYPE_IQ2_S:
|
||||||
default:
|
default:
|
||||||
|
@ -11071,6 +11087,7 @@ static void ggml_compute_forward_out_prod(
|
||||||
case GGML_TYPE_IQ3_XXS:
|
case GGML_TYPE_IQ3_XXS:
|
||||||
case GGML_TYPE_IQ1_S:
|
case GGML_TYPE_IQ1_S:
|
||||||
case GGML_TYPE_IQ4_NL:
|
case GGML_TYPE_IQ4_NL:
|
||||||
|
case GGML_TYPE_IQ4_XS:
|
||||||
case GGML_TYPE_IQ3_S:
|
case GGML_TYPE_IQ3_S:
|
||||||
case GGML_TYPE_IQ2_S:
|
case GGML_TYPE_IQ2_S:
|
||||||
{
|
{
|
||||||
|
@ -11261,6 +11278,7 @@ static void ggml_compute_forward_set(
|
||||||
case GGML_TYPE_IQ3_XXS:
|
case GGML_TYPE_IQ3_XXS:
|
||||||
case GGML_TYPE_IQ1_S:
|
case GGML_TYPE_IQ1_S:
|
||||||
case GGML_TYPE_IQ4_NL:
|
case GGML_TYPE_IQ4_NL:
|
||||||
|
case GGML_TYPE_IQ4_XS:
|
||||||
case GGML_TYPE_IQ3_S:
|
case GGML_TYPE_IQ3_S:
|
||||||
case GGML_TYPE_IQ2_S:
|
case GGML_TYPE_IQ2_S:
|
||||||
default:
|
default:
|
||||||
|
@ -11465,6 +11483,7 @@ static void ggml_compute_forward_get_rows(
|
||||||
case GGML_TYPE_IQ3_XXS:
|
case GGML_TYPE_IQ3_XXS:
|
||||||
case GGML_TYPE_IQ1_S:
|
case GGML_TYPE_IQ1_S:
|
||||||
case GGML_TYPE_IQ4_NL:
|
case GGML_TYPE_IQ4_NL:
|
||||||
|
case GGML_TYPE_IQ4_XS:
|
||||||
case GGML_TYPE_IQ3_S:
|
case GGML_TYPE_IQ3_S:
|
||||||
case GGML_TYPE_IQ2_S:
|
case GGML_TYPE_IQ2_S:
|
||||||
{
|
{
|
||||||
|
@ -12167,6 +12186,7 @@ static void ggml_compute_forward_alibi(
|
||||||
case GGML_TYPE_IQ3_XXS:
|
case GGML_TYPE_IQ3_XXS:
|
||||||
case GGML_TYPE_IQ1_S:
|
case GGML_TYPE_IQ1_S:
|
||||||
case GGML_TYPE_IQ4_NL:
|
case GGML_TYPE_IQ4_NL:
|
||||||
|
case GGML_TYPE_IQ4_XS:
|
||||||
case GGML_TYPE_IQ3_S:
|
case GGML_TYPE_IQ3_S:
|
||||||
case GGML_TYPE_IQ2_S:
|
case GGML_TYPE_IQ2_S:
|
||||||
case GGML_TYPE_Q8_K:
|
case GGML_TYPE_Q8_K:
|
||||||
|
@ -12252,6 +12272,7 @@ static void ggml_compute_forward_clamp(
|
||||||
case GGML_TYPE_IQ3_XXS:
|
case GGML_TYPE_IQ3_XXS:
|
||||||
case GGML_TYPE_IQ1_S:
|
case GGML_TYPE_IQ1_S:
|
||||||
case GGML_TYPE_IQ4_NL:
|
case GGML_TYPE_IQ4_NL:
|
||||||
|
case GGML_TYPE_IQ4_XS:
|
||||||
case GGML_TYPE_IQ3_S:
|
case GGML_TYPE_IQ3_S:
|
||||||
case GGML_TYPE_IQ2_S:
|
case GGML_TYPE_IQ2_S:
|
||||||
case GGML_TYPE_Q8_K:
|
case GGML_TYPE_Q8_K:
|
||||||
|
@ -19817,6 +19838,15 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
|
||||||
result = quantize_iq4_nl(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
|
result = quantize_iq4_nl(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
|
||||||
GGML_ASSERT(result == row_size * nrows);
|
GGML_ASSERT(result == row_size * nrows);
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_TYPE_IQ4_XS:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(start % QK4_NL == 0);
|
||||||
|
GGML_ASSERT(start % n_per_row == 0);
|
||||||
|
size_t start_row = start / n_per_row;
|
||||||
|
size_t row_size = ggml_row_size(type, n_per_row);
|
||||||
|
result = quantize_iq4_xs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
|
||||||
|
GGML_ASSERT(result == row_size * nrows);
|
||||||
|
} break;
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
{
|
{
|
||||||
size_t elemsize = sizeof(ggml_fp16_t);
|
size_t elemsize = sizeof(ggml_fp16_t);
|
||||||
|
|
2
ggml.h
2
ggml.h
|
@ -352,6 +352,7 @@ extern "C" {
|
||||||
GGML_TYPE_IQ4_NL = 20,
|
GGML_TYPE_IQ4_NL = 20,
|
||||||
GGML_TYPE_IQ3_S = 21,
|
GGML_TYPE_IQ3_S = 21,
|
||||||
GGML_TYPE_IQ2_S = 22,
|
GGML_TYPE_IQ2_S = 22,
|
||||||
|
GGML_TYPE_IQ4_XS = 23,
|
||||||
GGML_TYPE_I8,
|
GGML_TYPE_I8,
|
||||||
GGML_TYPE_I16,
|
GGML_TYPE_I16,
|
||||||
GGML_TYPE_I32,
|
GGML_TYPE_I32,
|
||||||
|
@ -393,6 +394,7 @@ extern "C" {
|
||||||
GGML_FTYPE_MOSTLY_IQ4_NL = 19, // except 1d tensors
|
GGML_FTYPE_MOSTLY_IQ4_NL = 19, // except 1d tensors
|
||||||
GGML_FTYPE_MOSTLY_IQ3_S = 20, // except 1d tensors
|
GGML_FTYPE_MOSTLY_IQ3_S = 20, // except 1d tensors
|
||||||
GGML_FTYPE_MOSTLY_IQ2_S = 21, // except 1d tensors
|
GGML_FTYPE_MOSTLY_IQ2_S = 21, // except 1d tensors
|
||||||
|
GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
|
||||||
};
|
};
|
||||||
|
|
||||||
// available tensor operations:
|
// available tensor operations:
|
||||||
|
|
|
@ -2583,6 +2583,7 @@ struct llama_model_loader {
|
||||||
case GGML_TYPE_IQ3_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS; break;
|
case GGML_TYPE_IQ3_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS; break;
|
||||||
case GGML_TYPE_IQ1_S: ftype = LLAMA_FTYPE_MOSTLY_IQ1_S; break;
|
case GGML_TYPE_IQ1_S: ftype = LLAMA_FTYPE_MOSTLY_IQ1_S; break;
|
||||||
case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break;
|
case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break;
|
||||||
|
case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break;
|
||||||
case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break;
|
case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break;
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
|
@ -2940,6 +2941,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
|
||||||
case LLAMA_FTYPE_MOSTLY_IQ3_XXS:return "IQ3_XXS - 3.0625 bpw";
|
case LLAMA_FTYPE_MOSTLY_IQ3_XXS:return "IQ3_XXS - 3.0625 bpw";
|
||||||
case LLAMA_FTYPE_MOSTLY_IQ1_S :return "IQ1_S - 1.5625 bpw";
|
case LLAMA_FTYPE_MOSTLY_IQ1_S :return "IQ1_S - 1.5625 bpw";
|
||||||
case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw";
|
case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw";
|
||||||
|
case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.125 bpw";
|
||||||
case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw";
|
case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw";
|
||||||
case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw";
|
case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw";
|
||||||
|
|
||||||
|
@ -10832,7 +10834,7 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
|
||||||
new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
|
new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
|
||||||
}
|
}
|
||||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
|
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
|
||||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL && qs.model.hparams.n_gqa() >= 4) {
|
else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && qs.model.hparams.n_gqa() >= 4) {
|
||||||
new_type = GGML_TYPE_Q5_K;
|
new_type = GGML_TYPE_Q5_K;
|
||||||
}
|
}
|
||||||
else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
|
else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
|
||||||
|
@ -10901,7 +10903,7 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
|
||||||
if (use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
|
if (use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL && !qs.has_imatrix) {
|
else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && !qs.has_imatrix) {
|
||||||
if (i_layer < n_layer/8) new_type = GGML_TYPE_Q5_K;
|
if (i_layer < n_layer/8) new_type = GGML_TYPE_Q5_K;
|
||||||
}
|
}
|
||||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
|
else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
|
||||||
|
@ -10922,7 +10924,7 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
|
||||||
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
|
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
|
||||||
ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL ||
|
ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL ||
|
||||||
ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S ||
|
ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S ||
|
||||||
ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
|
ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) {
|
||||||
new_type = GGML_TYPE_Q5_K;
|
new_type = GGML_TYPE_Q5_K;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -11039,6 +11041,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
||||||
case LLAMA_FTYPE_MOSTLY_IQ3_XXS: quantized_type = GGML_TYPE_IQ3_XXS; break;
|
case LLAMA_FTYPE_MOSTLY_IQ3_XXS: quantized_type = GGML_TYPE_IQ3_XXS; break;
|
||||||
case LLAMA_FTYPE_MOSTLY_IQ1_S: quantized_type = GGML_TYPE_IQ1_S; break;
|
case LLAMA_FTYPE_MOSTLY_IQ1_S: quantized_type = GGML_TYPE_IQ1_S; break;
|
||||||
case LLAMA_FTYPE_MOSTLY_IQ4_NL: quantized_type = GGML_TYPE_IQ4_NL; break;
|
case LLAMA_FTYPE_MOSTLY_IQ4_NL: quantized_type = GGML_TYPE_IQ4_NL; break;
|
||||||
|
case LLAMA_FTYPE_MOSTLY_IQ4_XS: quantized_type = GGML_TYPE_IQ4_XS; break;
|
||||||
case LLAMA_FTYPE_MOSTLY_IQ3_S: quantized_type = GGML_TYPE_IQ3_S; break;
|
case LLAMA_FTYPE_MOSTLY_IQ3_S: quantized_type = GGML_TYPE_IQ3_S; break;
|
||||||
case LLAMA_FTYPE_MOSTLY_IQ3_M: quantized_type = GGML_TYPE_IQ3_S; break;
|
case LLAMA_FTYPE_MOSTLY_IQ3_M: quantized_type = GGML_TYPE_IQ3_S; break;
|
||||||
|
|
||||||
|
|
1
llama.h
1
llama.h
|
@ -115,6 +115,7 @@ extern "C" {
|
||||||
LLAMA_FTYPE_MOSTLY_IQ3_M = 27, // except 1d tensors
|
LLAMA_FTYPE_MOSTLY_IQ3_M = 27, // except 1d tensors
|
||||||
LLAMA_FTYPE_MOSTLY_IQ2_S = 28, // except 1d tensors
|
LLAMA_FTYPE_MOSTLY_IQ2_S = 28, // except 1d tensors
|
||||||
LLAMA_FTYPE_MOSTLY_IQ2_M = 29, // except 1d tensors
|
LLAMA_FTYPE_MOSTLY_IQ2_M = 29, // except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
|
||||||
|
|
||||||
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
|
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
|
||||||
};
|
};
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue