From 8e936ad0cd08b6783a65cfafaf033ca2a1195a08 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 26 Apr 2023 18:30:56 +0300 Subject: [PATCH] ggml : adding Q5_0 mode --- ggml-cuda.cu | 42 ++++++++ ggml-cuda.h | 1 + ggml.c | 291 ++++++++++++++++++++++++++++++++++++++++++++++++++- ggml.h | 8 +- llama.cpp | 4 + llama.h | 1 + 6 files changed, 340 insertions(+), 7 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 6d1cc7008..b1bd29b10 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -37,6 +37,14 @@ typedef struct { } block_q4_3; static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding"); +#define QK5_0 32 +typedef struct { + __half d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; +static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); + #define QK5_1 32 typedef struct { __half d; // delta @@ -147,6 +155,35 @@ static __global__ void dequantize_block_q4_3(const void * vx, float * y) { } } +static __global__ void dequantize_block_q5_0(const void * vx, float * y) { + const block_q5_0 * x = (const block_q5_0 *) vx; + + const int i = blockIdx.x; + + const float d = x[i].d; + + const uint8_t * pp = x[i].qs; + + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + for (int l = 0; l < QK5_0; l += 2) { + const uint8_t vi = pp[l/2]; + + const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; + + const int8_t vi0 = ((vi & 0xf) | vh0); + const int8_t vi1 = ((vi >> 4) | vh1); + + const float v0 = (vi0 - 16)*d; + const float v1 = (vi1 - 16)*d; + + y[i*QK5_0 + l + 0] = v0; + y[i*QK5_0 + l + 1] = v1; + } +} + static __global__ void dequantize_block_q5_1(const void * vx, float * y) { const block_q5_1 * x = (const block_q5_1 *) vx; @@ -212,6 +249,11 @@ void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t st dequantize_block_q4_3<<>>(vx, y); } +void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { + const int nb = k / QK5_0; + dequantize_block_q5_0<<>>(vx, y); +} + void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK5_1; dequantize_block_q5_1<<>>(vx, y); diff --git a/ggml-cuda.h b/ggml-cuda.h index 348d9e907..ed9b44184 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -35,6 +35,7 @@ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t st void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream); +void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream); diff --git a/ggml.c b/ggml.c index 91afe62ed..90eb48fd7 100644 --- a/ggml.c +++ b/ggml.c @@ -676,6 +676,14 @@ typedef struct { } block_q4_3; static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding"); +#define QK5_0 32 +typedef struct { + ggml_fp16_t d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; +static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); + #define QK5_1 32 typedef struct { ggml_fp16_t d; // delta @@ -1300,6 +1308,55 @@ static void quantize_row_q4_3(const float * restrict x, void * restrict vy, int quantize_row_q4_3_reference(x, y, k); } +static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) { + assert(k % QK5_0 == 0); + const int nb = k / QK5_0; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int l = 0; l < QK5_0; l++) { + const float v = x[i*QK5_0 + l]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + + const float d = max / -16; + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + uint32_t qh = 0; + + for (int l = 0; l < QK5_0; l += 2) { + const float v0 = x[i*QK5_0 + l + 0]*id; + const float v1 = x[i*QK5_0 + l + 1]*id; + + const uint32_t vi0 = MIN(31, (int) (v0 + 16.5f)); + const uint32_t vi1 = MIN(31, (int) (v1 + 16.5f)); + + y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4); + + // get the 5-th bit and store it in qh at the right position + qh |= ((vi0 & 0x10) >> 4) << (l + 0); + qh |= ((vi1 & 0x10) >> 4) << (l + 1); + } + + memcpy(&y[i].qh, &qh, sizeof(y[i].qh)); + } +} + +static void quantize_row_q5_0(const float * restrict x, void * restrict vy, int k) { + assert(k % QK5_0 == 0); + + block_q5_0 * restrict y = vy; + + quantize_row_q5_0_reference(x, y, k); +} + static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) { assert(k % QK5_1 == 0); const int nb = k / QK5_1; @@ -1861,6 +1918,42 @@ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, in } } +static void dequantize_row_q5_0(const void * restrict vx, float * restrict y, int k) { + assert(k % QK5_0 == 0); + const int nb = k / QK5_0; + + const block_q5_0 * restrict x = vx; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict pp = x[i].qs; + + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + for (int l = 0; l < QK5_0; l += 2) { + const uint8_t vi = pp[l/2]; + + // extract the 5-th bit from qh + const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; + + const int8_t vi0 = (vi & 0x0F) | vh0; + const int8_t vi1 = (vi >> 4) | vh1; + + const float v0 = (vi0 - 16)*d; + const float v1 = (vi1 - 16)*d; + + y[i*QK5_0 + l + 0] = v0; + y[i*QK5_0 + l + 1] = v1; + + assert(!isnan(y[i*QK5_0 + l + 0])); + assert(!isnan(y[i*QK5_0 + l + 1])); + } + } +} + static void dequantize_row_q5_1(const void * restrict vx, float * restrict y, int k) { assert(k % QK5_1 == 0); const int nb = k / QK5_1; @@ -1918,6 +2011,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); +static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); @@ -1954,6 +2048,14 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { .vec_dot_q = ggml_vec_dot_q4_3_q8_1, .vec_dot_type = GGML_TYPE_Q8_1, }, + [GGML_TYPE_Q5_0] = { + .dequantize_row_q = dequantize_row_q5_0, + .quantize_row_q = quantize_row_q5_0, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_0_reference, + .quantize_row_q_dot = quantize_row_q8_0, + .vec_dot_q = ggml_vec_dot_q5_0_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + }, [GGML_TYPE_Q5_1] = { .dequantize_row_q = dequantize_row_q5_1, .quantize_row_q = quantize_row_q5_1, @@ -3169,6 +3271,141 @@ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * #endif } +static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int nb = n / QK8_0; + + assert(n % QK8_0 == 0); + assert(nb % 2 == 0); + assert(QK8_0 == QK5_0); + + const block_q5_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv = vdupq_n_f32(0.0f); + + uint64_t tmp[4]; + + for (int i = 0; i < nb; ++i) { + const block_q5_0 * restrict x0 = &x[i]; + const block_q8_0 * restrict y0 = &y[i]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + const int8x16_t s16b = vdupq_n_s8(0x10); + + // extract the 5th bit + uint32_t qh; + memcpy(&qh, x0->qh, sizeof(qh)); + + tmp[0] = table_b2b[(qh >> 0) & 0xFF]; + tmp[1] = table_b2b[(qh >> 8) & 0xFF]; + tmp[2] = table_b2b[(qh >> 16) & 0xFF]; + tmp[3] = table_b2b[(qh >> 24) ]; + + const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0)); + const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2)); + + const uint8x16_t v0 = vld1q_u8(x0->qs); + + // 4-bit -> 8-bit + const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, m4b)); + const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4)); + + // interleave + const int8x16_t v0lz = vzip1q_s8(v0l, v0h); + const int8x16_t v0hz = vzip2q_s8(v0l, v0h); + + // add high bit and sub 16 + const int8x16_t v0lf = vsubq_s8(vorrq_s8(v0lz, qhl), s16b); + const int8x16_t v0hf = vsubq_s8(vorrq_s8(v0hz, qhh), s16b); + + // load y + const int8x16_t v1l = vld1q_s8(y0->qs); + const int8x16_t v1h = vld1q_s8(y0->qs + 16); + + const float x0d = GGML_FP16_TO_FP32(x0->d); + +#if defined(__ARM_FEATURE_DOTPROD) + sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), v0lf, v1l), + vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d); +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h)); + + const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); + const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); + + sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d); +#endif + } + + *s = vaddvq_f32(sumv); +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; i++) { + /* Compute combined scale for the block */ + const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d)); + const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d)); + const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d)); + + __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs); + __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs); + __m256i bx = _mm256_set_m128i(bx1, bx0); + + // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. + const __m256i off = _mm256_set1_epi8(8); + bx = _mm256_sub_epi8(bx, off); + + __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_i8_pairs_float(bx, by); + + /* Multiply q with scale and accumulate */ + acc = _mm256_fmadd_ps(d, q, acc); + } + + *s = hsum_float_8(acc); +#else + // scalar + float sumf = 0.0; + for (int i = 0; i < nb; i++) { + const uint8_t * restrict x0 = x[i].qs; + const int8_t * restrict y0 = y[i].qs; + + uint32_t qh; + memcpy(&qh, x0->qh, sizeof(qh)); + + const float d = GGML_FP16_TO_FP32(x[i].d); + + int sxy = 0; + + for (int j = 0; j < QK8_0/2; j++) { + const uint8_t v0 = x0[j]; + + const int x0_0h = ((qh & (1 << (2*j + 0))) >> (2*j + 0)) << 4; + const int x1_0h = ((qh & (1 << (2*j + 1))) >> (2*j + 1)) << 4; + + const int x0_0 = ((v0 & 0x0F) | x0_0h) - 16; + const int x1_0 = ((v0 >> 4) | x1_0h) - 16; + + const int y0_0 = y0[2*j + 0]; + const int y1_0 = y0[2*j + 1]; + + sxy += x0_0*y0_0 + x1_0*y1_0; + } + + sumf += (d*sxy)*y[i].d; + } + *s = sumf; +#endif +} + static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const int nb = n / QK8_1; @@ -3646,6 +3883,7 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_1] = QK4_1, [GGML_TYPE_Q4_2] = QK4_2, [GGML_TYPE_Q4_3] = QK4_3, + [GGML_TYPE_Q5_0] = QK5_0, [GGML_TYPE_Q5_1] = QK5_1, [GGML_TYPE_Q8_0] = QK8_0, [GGML_TYPE_Q8_1] = QK8_1, @@ -3653,7 +3891,7 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_I16] = 1, [GGML_TYPE_I32] = 1, }; -static_assert(GGML_TYPE_COUNT == 12, "GGML_BLCK_SIZE is outdated"); +static_assert(GGML_TYPE_COUNT == 13, "GGML_BLCK_SIZE is outdated"); static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_F32] = sizeof(float), @@ -3662,6 +3900,7 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_1] = sizeof(block_q4_1), [GGML_TYPE_Q4_2] = sizeof(block_q4_2), [GGML_TYPE_Q4_3] = sizeof(block_q4_3), + [GGML_TYPE_Q5_0] = sizeof(block_q5_0), [GGML_TYPE_Q5_1] = sizeof(block_q5_1), [GGML_TYPE_Q8_0] = sizeof(block_q8_0), [GGML_TYPE_Q8_1] = sizeof(block_q8_1), @@ -3669,7 +3908,7 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_I16] = sizeof(int16_t), [GGML_TYPE_I32] = sizeof(int32_t), }; -static_assert(GGML_TYPE_COUNT == 12, "GGML_TYPE_SIZE is outdated"); +static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_SIZE is outdated"); static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { @@ -3679,6 +3918,7 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_1] = "q4_1", [GGML_TYPE_Q4_2] = "q4_2", [GGML_TYPE_Q4_3] = "q4_3", + [GGML_TYPE_Q5_0] = "q5_0", [GGML_TYPE_Q5_1] = "q5_1", [GGML_TYPE_Q8_0] = "q8_0", [GGML_TYPE_Q8_1] = "q8_1", @@ -3686,7 +3926,7 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { [GGML_TYPE_I16] = "i16", [GGML_TYPE_I32] = "i32", }; -static_assert(GGML_TYPE_COUNT == 12, "GGML_TYPE_NAME is outdated"); +static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_NAME is outdated"); static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { [GGML_TYPE_F32] = false, @@ -3695,6 +3935,7 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_1] = true, [GGML_TYPE_Q4_2] = true, [GGML_TYPE_Q4_3] = true, + [GGML_TYPE_Q5_0] = true, [GGML_TYPE_Q5_1] = true, [GGML_TYPE_Q8_0] = true, [GGML_TYPE_Q8_1] = true, @@ -3702,7 +3943,7 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { [GGML_TYPE_I16] = false, [GGML_TYPE_I32] = false, }; -static_assert(GGML_TYPE_COUNT == 12, "GGML_IS_QUANTIZED is outdated"); +static_assert(GGML_TYPE_COUNT == 13, "GGML_IS_QUANTIZED is outdated"); static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "NONE", @@ -6923,6 +7164,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: + case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: { @@ -8412,6 +8654,9 @@ static void ggml_compute_forward_mul_mat_q_f32( else if (type == GGML_TYPE_Q4_3) { dequantize_row_q_cuda = dequantize_row_q4_3_cuda; } + else if (type == GGML_TYPE_Q5_0) { + dequantize_row_q_cuda = dequantize_row_q5_0_cuda; + } else if (type == GGML_TYPE_Q5_1) { dequantize_row_q_cuda = dequantize_row_q5_1_cuda; } @@ -8573,6 +8818,7 @@ static void ggml_compute_forward_mul_mat( case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: + case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: @@ -8804,6 +9050,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: + case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: @@ -12598,6 +12845,36 @@ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * return (n/QK4_3*sizeof(block_q4_3)); } +size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK5_0 == 0); + const int nb = k / QK5_0; + + for (int j = 0; j < n; j += k) { + block_q5_0 * restrict y = (block_q5_0 *)dst + j/QK5_0; + + quantize_row_q5_0_reference(src + j, y, k); + + for (int i = 0; i < nb; i++) { + for (int l = 0; l < QK5_0; l += 2) { + uint32_t qh; + memcpy(&qh, &y[i].qh, sizeof(qh)); + + const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; + + // cast to 16 bins + const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2; + const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2; + + hist[vi0]++; + hist[vi1]++; + } + } + } + + return (n/QK5_0*sizeof(block_q5_0)); +} + size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) { assert(k % QK5_1 == 0); const int nb = k / QK5_1; @@ -12673,6 +12950,12 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i block_q4_3 * block = (block_q4_3*)dst + start / QK4_3; result = ggml_quantize_q4_3(src + start, block, n, n, hist); } break; + case GGML_TYPE_Q5_0: + { + GGML_ASSERT(start % QK5_0 == 0); + block_q5_0 * block = (block_q5_0*)dst + start / QK5_0; + result = ggml_quantize_q5_0(src + start, block, n, n, hist); + } break; case GGML_TYPE_Q5_1: { GGML_ASSERT(start % QK5_1 == 0); diff --git a/ggml.h b/ggml.h index 2784afc3d..d9d3d214e 100644 --- a/ggml.h +++ b/ggml.h @@ -222,9 +222,10 @@ extern "C" { GGML_TYPE_Q4_1 = 3, GGML_TYPE_Q4_2 = 4, GGML_TYPE_Q4_3 = 5, - GGML_TYPE_Q5_1 = 6, - GGML_TYPE_Q8_0 = 7, - GGML_TYPE_Q8_1 = 8, + GGML_TYPE_Q5_0 = 6, + GGML_TYPE_Q5_1 = 7, + GGML_TYPE_Q8_0 = 8, + GGML_TYPE_Q8_1 = 9, GGML_TYPE_I8, GGML_TYPE_I16, GGML_TYPE_I32, @@ -834,6 +835,7 @@ extern "C" { GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist); diff --git a/llama.cpp b/llama.cpp index 9b167f971..2ae6cedb2 100644 --- a/llama.cpp +++ b/llama.cpp @@ -484,6 +484,7 @@ struct llama_file_loader { case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: + case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: break; @@ -560,6 +561,7 @@ struct llama_file_saver { case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: + case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: break; @@ -852,6 +854,7 @@ static const char *llama_ftype_name(enum llama_ftype ftype) { return "mostly Q4_1, some F16"; case LLAMA_FTYPE_MOSTLY_Q4_2: return "mostly Q4_2"; case LLAMA_FTYPE_MOSTLY_Q4_3: return "mostly Q4_3"; + case LLAMA_FTYPE_MOSTLY_Q5_0: return "mostly Q5_0"; case LLAMA_FTYPE_MOSTLY_Q5_1: return "mostly Q5_1"; case LLAMA_FTYPE_MOSTLY_Q8_0: return "mostly Q8_0"; default: return "unknown, may not work"; @@ -1591,6 +1594,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q4_1: quantized_type = GGML_TYPE_Q4_1; break; case LLAMA_FTYPE_MOSTLY_Q4_2: quantized_type = GGML_TYPE_Q4_2; break; case LLAMA_FTYPE_MOSTLY_Q4_3: quantized_type = GGML_TYPE_Q4_3; break; + case LLAMA_FTYPE_MOSTLY_Q5_0: quantized_type = GGML_TYPE_Q5_0; break; case LLAMA_FTYPE_MOSTLY_Q5_1: quantized_type = GGML_TYPE_Q5_1; break; case LLAMA_FTYPE_MOSTLY_Q8_0: quantized_type = GGML_TYPE_Q8_0; break; default: throw format("invalid output file type %d\n", ftype); diff --git a/llama.h b/llama.h index ef5e7a7f5..3b6c6cd62 100644 --- a/llama.h +++ b/llama.h @@ -75,6 +75,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors };