From 5691fecd06f35de3dc1cbbcfb7d83233a1030aa9 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 20 Feb 2024 12:26:17 +0200 Subject: [PATCH] Resurrecting iq3_xs After all the experimentation, nothing was better than this. --- examples/quantize/quantize.cpp | 1 + ggml-cuda.cu | 159 ++++++++++++- ggml-quants.c | 398 ++++++++++++++++++++++++++++++--- ggml-quants.h | 13 ++ ggml.c | 31 +++ ggml.h | 2 + llama.cpp | 11 +- llama.h | 1 + 8 files changed, 580 insertions(+), 36 deletions(-) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 37520857f..8ac31d141 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -27,6 +27,7 @@ static const std::vector QUANT_OPTIONS = { { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", }, { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", }, { "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", }, + { "IQ3_XS", LLAMA_FTYPE_MOSTLY_IQ3_XS, " 3.31 bpw quantization", }, { "Q3_K", LLAMA_FTYPE_MOSTLY_Q3_K_M, "alias for Q3_K_M" }, { "Q3_K_XS",LLAMA_FTYPE_MOSTLY_Q3_K_XS,"3-bit extra small quantization" , }, { "Q3_K_S", LLAMA_FTYPE_MOSTLY_Q3_K_S, " 2.75G, +0.5551 ppl @ LLaMA-v1-7B", }, diff --git a/ggml-cuda.cu b/ggml-cuda.cu index b0e454e02..dbf35c8df 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -518,6 +518,15 @@ typedef struct { } block_iq3_xxs; static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding"); +#define QR3_XS 8 +#define QI3_XS (QK_K / (4*QR3_XS)) +typedef struct { + half d; + uint8_t qs[3*(QK_K/8)]; + uint8_t qh[QK_K/32]; +} block_iq3_xs; +static_assert(sizeof(block_iq3_xs) == sizeof(ggml_fp16_t) + 3*(QK_K/8) + QK_K/32, "wrong iq3_xs block size/padding"); + #define QR1_S 8 #define QI1_S (QK_K / (4*QR1_S)) typedef struct { @@ -1700,6 +1709,74 @@ static const __device__ uint32_t iq3xxs_grid[256] = { 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, }; +static const __device__ uint32_t iq3xs_grid[512] = { + 0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14, + 0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414, + 0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24, + 0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c, + 0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c, + 0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34, + 0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c, + 0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414, + 0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c, + 0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404, + 0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434, + 0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c, + 0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404, + 0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414, + 0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414, + 0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404, + 0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c, + 0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c, + 0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404, + 0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e, + 0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14, + 0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c, + 0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424, + 0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c, + 0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c, + 0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e, + 0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e, + 0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e, + 0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424, + 0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e, + 0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424, + 0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404, + 0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c, + 0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e, + 0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c, + 0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c, + 0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c, + 0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404, + 0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04, + 0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c, + 0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414, + 0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c, + 0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c, + 0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424, + 0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c, + 0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c, + 0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414, + 0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c, + 0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e, + 0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04, + 0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424, + 0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14, + 0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34, + 0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c, + 0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434, + 0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c, + 0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424, + 0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24, + 0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24, + 0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e, + 0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c, + 0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c, + 0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c, + 0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404, +}; + + static const __device__ uint64_t iq1s_grid[512] = { 0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000, 0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01, @@ -1973,6 +2050,34 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds } +template +static __global__ void dequantize_block_iq3_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_iq3_xs * x = (const block_iq3_xs *) vx; + + const int tid = threadIdx.x; +#if QK_K == 256 + const int il = tid/8; // 0...3 + const int ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const uint8_t * q3 = x[i].qs + 8*ib; + const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib; + const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (q3[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (q3[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256))); + const uint32_t aux32 = gas[0] | (gas[1] << 16); + const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f; + const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127]; + for (int j = 0; j < 4; ++j) { + y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); + y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); + } +#else + assert(false); +#endif + +} + template static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { @@ -4717,6 +4822,42 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1( #endif } +static __device__ __forceinline__ float vec_dot_iq3_xs_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics +#if QK_K == 256 + const block_iq3_xs * bq2 = (const block_iq3_xs *) vbq; + + const int ib32 = iqs; + const uint8_t * q3 = bq2->qs + 8*ib32; + const uint16_t * gas = (const uint16_t *)(bq2->qs + QK_K/4) + 2*ib32; + const int8_t * q8 = bq8_1[ib32].qs; + uint32_t aux32 = gas[0] | (gas[1] << 16); + int sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint32_t * grid1 = iq3xs_grid + (q3[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)); + const uint32_t * grid2 = iq3xs_grid + (q3[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)); + const uint32_t * signs = (const uint32_t *)(ksigns64 + (aux32 & 127)); + const int grid_l = __vsub4(grid1[0] ^ signs[0], signs[0]); + const int grid_h = __vsub4(grid2[0] ^ signs[1], signs[1]); + sumi = __dp4a(grid_l, *((int *)q8+0), sumi); + sumi = __dp4a(grid_h, *((int *)q8+1), sumi); + q8 += 8; + aux32 >>= 7; + } + const float d = (float)bq2->d * (0.5f + aux32) * __low2float(bq8_1[ib32].ds) * 0.5f; + return d * sumi; +#else + assert(false); + return 0.f; +#endif +#else + assert(false); + return 0.f; +#endif +} + + static __device__ __forceinline__ float vec_dot_iq1_s_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { #if QK_K == 256 @@ -6849,6 +6990,12 @@ static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k, dequantize_block_iq3_xxs<<>>(vx, y); } +template +static void dequantize_row_iq3_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_iq3_xs<<>>(vx, y); +} + template static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; @@ -6904,6 +7051,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq1_s_cuda; case GGML_TYPE_IQ4_NL: return dequantize_row_iq4_nl_cuda; + case GGML_TYPE_IQ3_XS: + return dequantize_row_iq3_xs_cuda; case GGML_TYPE_F32: return convert_unary_cuda; default: @@ -6943,6 +7092,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq1_s_cuda; case GGML_TYPE_IQ4_NL: return dequantize_row_iq4_nl_cuda; + case GGML_TYPE_IQ3_XS: + return dequantize_row_iq3_xs_cuda; case GGML_TYPE_F16: return convert_unary_cuda; default: @@ -8688,6 +8839,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array= CC_RDNA2 ? 128 : 64; default: GGML_ASSERT(false); @@ -8713,6 +8865,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array= CC_VOLTA ? 128 : 64; case GGML_TYPE_Q6_K: return 64; @@ -8818,6 +8971,10 @@ static void ggml_cuda_op_mul_mat_vec_q( mul_mat_vec_q_cuda (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; + case GGML_TYPE_IQ3_XS: + mul_mat_vec_q_cuda + (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + break; default: GGML_ASSERT(false); break; @@ -11541,7 +11698,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons } ggml_type a_type = a->type; if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS || - a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL) { + a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_XS) { if (b->ne[1] == 1 && ggml_nrows(b) > 1) { return false; } diff --git a/ggml-quants.c b/ggml-quants.c index e3380b049..e10471a60 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -3505,6 +3505,73 @@ static const uint32_t iq3xxs_grid[256] = { 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, }; +static const uint32_t iq3xs_grid[512] = { + 0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14, + 0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414, + 0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24, + 0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c, + 0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c, + 0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34, + 0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c, + 0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414, + 0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c, + 0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404, + 0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434, + 0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c, + 0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404, + 0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414, + 0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414, + 0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404, + 0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c, + 0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c, + 0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404, + 0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e, + 0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14, + 0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c, + 0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424, + 0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c, + 0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c, + 0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e, + 0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e, + 0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e, + 0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424, + 0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e, + 0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424, + 0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404, + 0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c, + 0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e, + 0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c, + 0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c, + 0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c, + 0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404, + 0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04, + 0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c, + 0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414, + 0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c, + 0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c, + 0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424, + 0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c, + 0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c, + 0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414, + 0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c, + 0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e, + 0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04, + 0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424, + 0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14, + 0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34, + 0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c, + 0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434, + 0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c, + 0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424, + 0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24, + 0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24, + 0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e, + 0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c, + 0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c, + 0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c, + 0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404, +}; + #define NGRID_IQ2XXS 512 static const uint64_t iq1s_grid[NGRID_IQ2XXS] = { 0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000, @@ -3736,6 +3803,39 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y } } +// ====================== 3.3125 bpw (de)-quantization + +void dequantize_row_iq3_xs(const block_iq3_xs * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + uint32_t aux32; + + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d); + const uint8_t * qs = x[i].qs; + const uint8_t * scales_and_signs = qs + QK_K/4; + const uint8_t * qh = x[i].qh; + + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(&aux32, scales_and_signs + 4*ib32, sizeof(uint32_t)); + const float db = d * (0.5f + (aux32 >> 28)) * 0.5f; + for (int l = 0; l < 4; ++l) { + const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127]; + const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[ib32] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[ib32] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + y[j+0] = db * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); + y[j+4] = db * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); + } + y += 8; + } + qs += 8; + } + } +} + // ====================== 1.5625 bpw (de)-quantization void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int k) { @@ -9327,6 +9427,142 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void #endif } +// TODO +void ggml_vec_dot_iq3_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq3_xs * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[2]; + + ggml_int8x16x4_t q3s; + ggml_int8x16x4_t q8b; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict gas = x[i].qs + QK_K/4; + const int8_t * restrict q8 = y[i].qs; + float sumf1 = 0, sumf2 = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t); + const uint32x4_t aux32x4_0 = {iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]}; + const uint32x4_t aux32x4_1 = {iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]}; + const uint32x4_t aux32x4_2 = {iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]}; + const uint32x4_t aux32x4_3 = {iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]}; + q3 += 16; + q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 7) & 127)))); + q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127)))); + q3s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127)))); + q3s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127)))); + q3s.val[0] = vmulq_s8(q3s.val[0], vreinterpretq_s8_u32(aux32x4_0)); + q3s.val[1] = vmulq_s8(q3s.val[1], vreinterpretq_s8_u32(aux32x4_1)); + q3s.val[2] = vmulq_s8(q3s.val[2], vreinterpretq_s8_u32(aux32x4_2)); + q3s.val[3] = vmulq_s8(q3s.val[3], vreinterpretq_s8_u32(aux32x4_3)); + const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]); + const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]); + sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[0] >> 28)); + sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[1] >> 28)); + } + sumf += d*(sumf1 + sumf2); + } + *s = 0.5f * sumf; + +#elif defined(__AVX2__) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[2]; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict gas = x[i].qs + QK_K/4; + const int8_t * restrict q8 = y[i].qs; + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], + iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + q3 += 8; + const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], + iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + q3 += 8; + memcpy(aux32, gas, 8); gas += 8; + const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127], + signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]); + const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], + signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1); + const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2); + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); + const uint16_t ls1 = aux32[0] >> 28; + const uint16_t ls2 = aux32[1] >> 28; + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); + sumi1 = _mm256_add_epi32(sumi1, p1); + sumi2 = _mm256_add_epi32(sumi2, p2); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = 0.25f * hsum_float_8(accumf); + +#else + + uint32_t aux32; + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict gas = x[i].qs + QK_K/4; + const int8_t * restrict q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t); + const uint32_t ls = 2*(aux32 >> 28) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]); + const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]); + const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127]; + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1); + } + q8 += 8; + } + q3 += 8; + bsum += sumi * ls; + } + sumf += d * bsum; + } + *s = 0.25f * sumf; +#endif +} + + #ifdef __AVX2__ static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { const __m256i ax = _mm256_sign_epi8(x, x); @@ -10240,14 +10476,15 @@ typedef struct { uint16_t * neighbours; } iq3_entry_t; -static iq3_entry_t iq3_data[1] = { +static iq3_entry_t iq3_data[2] = { + {NULL, NULL, NULL}, {NULL, NULL, NULL}, }; static inline int iq3_data_index(int grid_size) { (void)grid_size; - GGML_ASSERT(grid_size == 256); - return 0; + GGML_ASSERT(grid_size == 256 || grid_size == 512); + return grid_size == 256 ? 0 : 1; } static int iq3_compare_func(const void * left, const void * right) { @@ -10279,9 +10516,44 @@ void iq3xs_init_impl(int grid_size) { 3185, 3215, 3252, 3288, 3294, 3364, 3397, 3434, 3483, 3523, 3537, 3587, 3589, 3591, 3592, 3610, 3626, 3670, 3680, 3722, 3749, 3754, 3776, 3789, 3803, 3824, 3857, 3873, 3904, 3906, 3924, 3992, }; + static const uint16_t kgrid_512[512] = { + 0, 1, 2, 5, 7, 8, 9, 10, 12, 14, 16, 17, 21, 27, 32, 34, + 37, 39, 41, 43, 48, 50, 57, 60, 63, 64, 65, 66, 68, 72, 73, 77, + 80, 83, 87, 89, 93, 100, 113, 117, 122, 128, 129, 133, 135, 136, 139, 142, + 145, 149, 152, 156, 162, 165, 167, 169, 171, 184, 187, 195, 201, 205, 208, 210, + 217, 219, 222, 228, 232, 234, 247, 249, 253, 256, 267, 271, 273, 276, 282, 288, + 291, 297, 312, 322, 324, 336, 338, 342, 347, 353, 357, 359, 374, 379, 390, 393, + 395, 409, 426, 441, 448, 450, 452, 464, 466, 470, 475, 488, 492, 512, 513, 514, + 516, 520, 521, 523, 525, 527, 528, 530, 537, 540, 542, 556, 558, 561, 570, 576, + 577, 579, 582, 584, 588, 593, 600, 603, 609, 616, 618, 632, 638, 640, 650, 653, + 655, 656, 660, 666, 672, 675, 685, 688, 698, 705, 708, 711, 712, 715, 721, 727, + 728, 732, 737, 754, 760, 771, 773, 778, 780, 793, 795, 802, 806, 808, 812, 833, + 840, 843, 849, 856, 858, 873, 912, 916, 919, 932, 934, 961, 963, 968, 970, 977, + 989, 993, 1010, 1016, 1024, 1025, 1027, 1029, 1031, 1032, 1034, 1036, 1038, 1041, 1043, 1047, + 1048, 1050, 1057, 1059, 1061, 1064, 1066, 1079, 1080, 1083, 1085, 1088, 1090, 1096, 1099, 1103, + 1106, 1109, 1113, 1116, 1122, 1129, 1153, 1156, 1159, 1169, 1171, 1176, 1183, 1185, 1195, 1199, + 1209, 1212, 1216, 1218, 1221, 1225, 1234, 1236, 1241, 1243, 1250, 1256, 1270, 1281, 1287, 1296, + 1299, 1306, 1309, 1313, 1338, 1341, 1348, 1353, 1362, 1375, 1376, 1387, 1400, 1408, 1410, 1415, + 1425, 1453, 1457, 1477, 1481, 1494, 1496, 1507, 1512, 1538, 1545, 1547, 1549, 1551, 1554, 1561, + 1563, 1565, 1570, 1572, 1575, 1577, 1587, 1593, 1601, 1603, 1605, 1612, 1617, 1619, 1632, 1648, + 1658, 1662, 1664, 1674, 1680, 1690, 1692, 1704, 1729, 1736, 1740, 1745, 1747, 1751, 1752, 1761, + 1763, 1767, 1773, 1787, 1795, 1801, 1806, 1810, 1817, 1834, 1840, 1844, 1857, 1864, 1866, 1877, + 1882, 1892, 1902, 1915, 1934, 1953, 1985, 1987, 2000, 2002, 2013, 2048, 2052, 2058, 2064, 2068, + 2071, 2074, 2081, 2088, 2104, 2114, 2119, 2121, 2123, 2130, 2136, 2141, 2147, 2153, 2157, 2177, + 2179, 2184, 2189, 2193, 2203, 2208, 2223, 2226, 2232, 2244, 2249, 2251, 2256, 2258, 2265, 2269, + 2304, 2306, 2324, 2335, 2336, 2361, 2373, 2375, 2385, 2418, 2443, 2460, 2480, 2504, 2509, 2520, + 2531, 2537, 2562, 2568, 2572, 2578, 2592, 2596, 2599, 2602, 2614, 2620, 2625, 2627, 2629, 2634, + 2641, 2650, 2682, 2688, 2697, 2707, 2712, 2718, 2731, 2754, 2759, 2760, 2775, 2788, 2793, 2805, + 2811, 2817, 2820, 2832, 2842, 2854, 2890, 2902, 2921, 2923, 2978, 3010, 3012, 3026, 3081, 3083, + 3085, 3097, 3099, 3120, 3136, 3152, 3159, 3188, 3210, 3228, 3234, 3245, 3250, 3256, 3264, 3276, + 3281, 3296, 3349, 3363, 3378, 3392, 3395, 3420, 3440, 3461, 3488, 3529, 3531, 3584, 3588, 3591, + 3600, 3602, 3614, 3616, 3628, 3634, 3650, 3657, 3668, 3683, 3685, 3713, 3716, 3720, 3726, 3729, + 3736, 3753, 3778, 3802, 3805, 3819, 3841, 3845, 3851, 3856, 3880, 3922, 3938, 3970, 3993, 4032, + }; + const int kmap_size = 4096; const int nwant = 2; - const uint16_t * kgrid = kgrid_256; + const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512; uint32_t * kgrid_q3xs; int * kmap_q3xs; uint16_t * kneighbors_q3xs; @@ -10378,7 +10650,7 @@ void iq3xs_init_impl(int grid_size) { } void iq3xs_free_impl(int grid_size) { - GGML_ASSERT(grid_size == 256); + GGML_ASSERT(grid_size == 256 || grid_size == 512); const int gindex = iq3_data_index(grid_size); if (iq3_data[gindex].grid) { free(iq3_data[gindex].grid); iq3_data[gindex].grid = NULL; @@ -10411,9 +10683,10 @@ static int iq3_find_best_neighbour(const uint16_t * restrict neighbours, const u return grid_index; } -static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) { +static void quantize_row_iq3_xs_impl(int grid_size, const float * restrict x, void * restrict vy, int n, + const float * restrict quant_weights) { - const int gindex = iq3_data_index(256); + const int gindex = iq3_data_index(grid_size); const uint32_t * kgrid_q3xs = iq3_data[gindex].grid; const int * kmap_q3xs = iq3_data[gindex].map; @@ -10429,7 +10702,21 @@ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict const int nbl = n/256; - block_iq3_xxs * y = vy; + ggml_fp16_t * dh; + uint8_t * qs; + int block_size; + if (grid_size == 256) { + block_iq3_xxs * y = vy; + dh = &y->d; + qs = y->qs; + block_size = sizeof(block_iq3_xxs); + } else { + block_iq3_xs * y = vy; + dh = &y->d; + qs = y->qs; + block_size = sizeof(block_iq3_xs); + } + int quant_size = block_size - sizeof(ggml_fp16_t); float scales[QK_K/32]; float weight[32]; @@ -10440,20 +10727,21 @@ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict bool is_on_grid[8]; bool is_on_grid_aux[8]; uint8_t block_signs[8]; - uint8_t q3[3*(QK_K/8)]; + uint8_t q3[3*(QK_K/8)+QK_K/32]; uint32_t * scales_and_signs = (uint32_t *)(q3 + QK_K/4); + uint8_t * qh = q3 + 3*(QK_K/8); for (int ibl = 0; ibl < nbl; ++ibl) { - y[ibl].d = GGML_FP32_TO_FP16(0.f); - memset(q3, 0, 3*QK_K/8); + dh[0] = GGML_FP32_TO_FP16(0.f); + memset(q3, 0, 3*QK_K/8+QK_K/32); float max_scale = 0; const float * xbl = x + QK_K*ibl; float sumx2 = 0; for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; - float sigma2 = sumx2/QK_K; + float sigma2 = 2*sumx2/QK_K; for (int ib = 0; ib < QK_K/32; ++ib) { const float * xb = xbl + 32*ib; @@ -10571,7 +10859,13 @@ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict printf("\n"); GGML_ASSERT(false); } - q3[8*ib+k] = grid_index; + if (grid_size == 256) { + q3[8*ib+k] = grid_index; + } else { + q3[8*ib+k] = grid_index & 255; + qh[ib] |= ((grid_index >> 8) << k); + } + } scales_and_signs[ib] = block_signs[0] | (block_signs[1] << 7) | (block_signs[2] << 14) | (block_signs[3] << 21); GGML_ASSERT(scale >= 0); @@ -10580,14 +10874,16 @@ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict } if (!max_scale) { - memset(y[ibl].qs, 0, 3*QK_K/8); + memset(qs, 0, quant_size); + dh += block_size/sizeof(ggml_fp16_t); + qs += block_size; continue; } float d = max_scale/31; - y[ibl].d = GGML_FP32_TO_FP16(d); + dh[0] = GGML_FP32_TO_FP16(d); float id = 1/d; - float sumqx = 0, sumq2 = 0; + //float sumqx = 0, sumq2 = 0; for (int ib = 0; ib < QK_K/32; ++ib) { int l = nearest_int(0.5f*(id*scales[ib]-1)); l = MAX(0, MIN(15, l)); @@ -10600,21 +10896,24 @@ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict } else { for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i]; } - const float db = 0.25f * d * (1 + 2*l); + uint8_t h = 0; + if (grid_size == 512) { h = qh[ib]; qh[ib] = 0; } + const float db = d * (1 + 2*l); for (int k = 0; k < 8; ++k) { const int8_t * signs = keven_signs_q2xs + 8*((scales_and_signs[ib] >> 7*(k/2)) & 127) + 4*(k%2); const float * xk = xb + 4*k; const float * wk = weight + 4*k; //const uint8_t * grid = (const uint8_t *)(kgrid_q3xs + q3[8*ib+k]); - const uint8_t * grid = (const uint8_t *)(iq3xxs_grid + q3[8*ib+k]); - float best_mse = 0; int best_index = q3[8*ib+k]; + int idx = q3[8*ib+k]; + if (grid_size == 512) idx |= ((h << (8-k)) & 256); + const uint8_t * grid = (const uint8_t *)(kgrid_q3xs + idx); + float best_mse = 0; int best_index = idx; for (int j = 0; j < 4; ++j) { float diff = db * grid[j] * signs[j] - xk[j]; best_mse += wk[j] * diff * diff; } - for (int idx = 0; idx < 256; ++idx) { - //grid = (const uint8_t *)(kgrid_q3xs + idx); - grid = (const uint8_t *)(iq3xxs_grid + idx); + for (idx = 0; idx < grid_size; ++idx) { + grid = (const uint8_t *)(kgrid_q3xs + idx); float mse = 0; for (int j = 0; j < 4; ++j) { float diff = db * grid[j] * signs[j] - xk[j]; @@ -10624,19 +10923,27 @@ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict best_mse = mse; best_index = idx; } } - q3[8*ib+k] = best_index; - //grid = (const uint8_t *)(kgrid_q3xs + best_index); - grid = (const uint8_t *)(iq3xxs_grid + best_index); - for (int j = 0; j < 4; ++j) { - float q = db * grid[j] * signs[j]; - sumqx += wk[j] * q * xk[j]; - sumq2 += wk[j] * q * q; + if (grid_size == 256) { + q3[8*ib+k] = best_index; + } else { + q3[8*ib+k] = best_index & 255; + qh[ib] |= ((best_index >> 8) << k); } + //grid = (const uint8_t *)(kgrid_q3xs + best_index); + //for (int j = 0; j < 4; ++j) { + // float q = db * grid[j] * signs[j]; + // sumqx += wk[j] * q * xk[j]; + // sumq2 += wk[j] * q * q; + //} } - if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2); + //if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2); } } - memcpy(y[ibl].qs, q3, 3*QK_K/8); + memcpy(qs, q3, quant_size); + + dh += block_size/sizeof(ggml_fp16_t); + qs += block_size; + } } @@ -10646,7 +10953,7 @@ size_t quantize_iq3_xxs(const float * src, void * dst, int nrow, int n_per_row, int nblock = n_per_row/QK_K; char * qrow = (char *)dst; for (int row = 0; row < nrow; ++row) { - quantize_row_iq3_xxs_impl(src, qrow, n_per_row, quant_weights); + quantize_row_iq3_xs_impl(256, src, qrow, n_per_row, quant_weights); src += n_per_row; qrow += nblock*sizeof(block_iq3_xxs); } @@ -10661,9 +10968,34 @@ void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int k) { void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int k) { assert(k % QK_K == 0); - quantize_row_iq3_xxs_impl(x, y, k, NULL); + quantize_row_iq3_xs_impl(256, x, y, k, NULL); } +size_t quantize_iq3_xs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) { + (void)hist; + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int row = 0; row < nrow; ++row) { + quantize_row_iq3_xs_impl(512, src, qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += nblock*sizeof(block_iq3_xs); + } + return nrow * nblock * sizeof(block_iq3_xs); +} + +void quantize_row_iq3_xs(const float * restrict x, void * restrict vy, int k) { + assert(k % QK_K == 0); + block_iq3_xs * restrict y = vy; + quantize_row_iq3_xs_reference(x, y, k); +} + +void quantize_row_iq3_xs_reference(const float * restrict x, block_iq3_xs * restrict y, int k) { + assert(k % QK_K == 0); + quantize_row_iq3_xs_impl(512, x, y, k, NULL); +} + + // =================================== 1.5 bpw =================================================== static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid, diff --git a/ggml-quants.h b/ggml-quants.h index 113623b62..9437cd69e 100644 --- a/ggml-quants.h +++ b/ggml-quants.h @@ -191,6 +191,14 @@ typedef struct { } block_iq3_xxs; static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding"); +// 3.3125 bpw +typedef struct { + ggml_fp16_t d; + uint8_t qs[3*QK_K/8]; + uint8_t qh[QK_K/32]; +} block_iq3_xs; +static_assert(sizeof(block_iq3_xs) == sizeof(ggml_fp16_t) + 3*(QK_K/8) + QK_K/32, "wrong iq3_xs block size/padding"); + typedef struct { ggml_fp16_t d; uint8_t qs[QK_K/8]; @@ -226,6 +234,7 @@ void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGM void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int k); void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int k); void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int k); +void quantize_row_iq3_xs_reference (const float * GGML_RESTRICT x, block_iq3_xs * GGML_RESTRICT y, int k); void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); @@ -242,6 +251,7 @@ void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); +void quantize_row_iq3_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); // Dequantization void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); @@ -262,6 +272,7 @@ void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_ void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); +void dequantize_row_iq3_xs (const block_iq3_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); // Dot product void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -280,6 +291,7 @@ void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq3_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); // // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") @@ -289,6 +301,7 @@ size_t quantize_iq2_xs (const float * src, void * dst, int nrows, int n_per_row, size_t quantize_iq3_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_iq1_s (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_iq4_nl (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); +size_t quantize_iq3_xs (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_q2_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_q3_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_q4_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); diff --git a/ggml.c b/ggml.c index d710fe702..acd2eec1d 100644 --- a/ggml.c +++ b/ggml.c @@ -678,6 +678,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, }, + [GGML_TYPE_IQ3_XS] = { + .type_name = "iq3_xs", + .blck_size = QK_K, + .type_size = sizeof(block_iq3_xs), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq3_xs, + .from_float = quantize_row_iq3_xs, + .from_float_reference = (ggml_from_float_t)quantize_row_iq3_xs_reference, + .vec_dot = ggml_vec_dot_iq3_xs_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, [GGML_TYPE_IQ1_S] = { .type_name = "iq1_s", .blck_size = QK_K, @@ -2304,6 +2316,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break; case GGML_FTYPE_MOSTLY_IQ1_S: wtype = GGML_TYPE_IQ1_S; break; case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break; + case GGML_FTYPE_MOSTLY_IQ3_XS: wtype = GGML_TYPE_IQ3_XS; break; case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break; case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break; } @@ -7738,6 +7751,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ3_XS: { ggml_compute_forward_add_q_f32(params, dst); } break; @@ -8017,6 +8031,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ3_XS: { ggml_compute_forward_add1_q_f32(params, dst); } break; @@ -8141,6 +8156,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ3_XS: default: { GGML_ASSERT(false); @@ -11039,6 +11055,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ3_XS: { ggml_compute_forward_out_prod_q_f32(params, dst); } break; @@ -11227,6 +11244,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ3_XS: default: { GGML_ASSERT(false); @@ -11429,6 +11447,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ3_XS: { ggml_compute_forward_get_rows_q(params, dst); } break; @@ -12129,6 +12148,7 @@ static void ggml_compute_forward_alibi( case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ3_XS: case GGML_TYPE_Q8_K: case GGML_TYPE_I8: case GGML_TYPE_I16: @@ -12212,6 +12232,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ3_XS: case GGML_TYPE_Q8_K: case GGML_TYPE_I8: case GGML_TYPE_I16: @@ -19463,6 +19484,7 @@ void ggml_quantize_init(enum ggml_type type) { case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ1_S: iq2xs_init_impl(type); break; case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break; + case GGML_TYPE_IQ3_XS: iq3xs_init_impl(512); break; default: // nothing break; } @@ -19737,6 +19759,15 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i result = quantize_iq3_xxs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix); GGML_ASSERT(result == row_size * nrows); } break; + case GGML_TYPE_IQ3_XS: + { + GGML_ASSERT(start % QK_K == 0); + GGML_ASSERT(start % n_per_row == 0); + size_t start_row = start / n_per_row; + size_t row_size = ggml_row_size(type, n_per_row); + result = quantize_iq3_xs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix); + GGML_ASSERT(result == row_size * nrows); + } break; case GGML_TYPE_IQ1_S: { GGML_ASSERT(start % QK_K == 0); diff --git a/ggml.h b/ggml.h index 37eff6279..6e8240650 100644 --- a/ggml.h +++ b/ggml.h @@ -350,6 +350,7 @@ extern "C" { GGML_TYPE_IQ3_XXS = 18, GGML_TYPE_IQ1_S = 19, GGML_TYPE_IQ4_NL = 20, + GGML_TYPE_IQ3_XS = 21, GGML_TYPE_I8, GGML_TYPE_I16, GGML_TYPE_I32, @@ -389,6 +390,7 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors GGML_FTYPE_MOSTLY_IQ1_S = 18, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_NL = 19, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ3_XS = 20, // except 1d tensors }; // available tensor operations: diff --git a/llama.cpp b/llama.cpp index 37477e6ef..eaefbf724 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2545,6 +2545,7 @@ struct llama_model_loader { case GGML_TYPE_IQ3_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS; break; case GGML_TYPE_IQ1_S: ftype = LLAMA_FTYPE_MOSTLY_IQ1_S; break; case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; + case GGML_TYPE_IQ3_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XS; break; default: { LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); @@ -2890,6 +2891,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ3_XXS:return "IQ3_XXS - 3.0625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ1_S :return "IQ1_S - 1.5625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_XS: return "IQ3_XS - 3.3125 bpw"; default: return "unknown, may not work"; } @@ -10544,6 +10546,9 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : !qs.has_imatrix ? GGML_TYPE_Q3_K : GGML_TYPE_IQ3_XXS; } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && qs.model.hparams.n_gqa() >= 4) { + new_type = GGML_TYPE_Q4_K; + } else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) { new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K; } @@ -10623,7 +10628,7 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty if (qs.model.hparams.n_expert == 8) { if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || - ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) { + ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) { new_type = GGML_TYPE_Q5_K; } } else { @@ -10673,7 +10678,7 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K || new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K || new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS || - new_type == GGML_TYPE_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) { + new_type == GGML_TYPE_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || new_type == GGML_TYPE_IQ3_XS) { int nx = tensor->ne[0]; int ny = tensor->ne[1]; if (nx % QK_K != 0) { @@ -10688,6 +10693,7 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_XS: case GGML_TYPE_IQ1_S: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: new_type = GGML_TYPE_IQ4_NL; break; @@ -10733,6 +10739,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_IQ3_XXS: quantized_type = GGML_TYPE_IQ3_XXS; break; case LLAMA_FTYPE_MOSTLY_IQ1_S: quantized_type = GGML_TYPE_IQ1_S; break; case LLAMA_FTYPE_MOSTLY_IQ4_NL: quantized_type = GGML_TYPE_IQ4_NL; break; + case LLAMA_FTYPE_MOSTLY_IQ3_XS: quantized_type = GGML_TYPE_IQ3_XS; break; default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); } diff --git a/llama.h b/llama.h index 84f196b3b..7ad107bba 100644 --- a/llama.h +++ b/llama.h @@ -102,6 +102,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ3_XXS = 23, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ1_S = 24, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_NL = 25, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ3_XS = 26, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file };