From 2b097682e07f4c7f830f197d1b803e7155a632a1 Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Thu, 20 Jun 2024 20:06:13 +0800 Subject: [PATCH] remove q2_2 --- examples/quantize/quantize.cpp | 1 - ggml-common.h | 6 -- ggml-quants.c | 127 --------------------------------- ggml-quants.h | 5 -- ggml.c | 14 ---- ggml.h | 1 - gguf-py/gguf/constants.py | 3 - llama.cpp | 2 - llama.h | 1 - 9 files changed, 160 deletions(-) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 05df330c0..28584e14b 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -26,7 +26,6 @@ static const std::vector QUANT_OPTIONS = { { "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", }, { "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", }, { "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", }, - { "Q2_2", LLAMA_FTYPE_MOSTLY_Q2_2, " 2 bpw quantization", }, { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", }, { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", }, { "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", }, diff --git a/ggml-common.h b/ggml-common.h index a1a824665..d3fa51235 100644 --- a/ggml-common.h +++ b/ggml-common.h @@ -137,12 +137,6 @@ typedef sycl::half2 ggml_half2; #endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP -#define QK2_2 32 -typedef struct { - uint8_t qs[QK2_2 / 4]; // nibbles / quants -} block_q2_2; -static_assert(sizeof(block_q2_2) == QK2_2 / 4, "wrong q2_2 block size/padding"); - #define QK4_0 32 typedef struct { ggml_half d; // delta diff --git a/ggml-quants.c b/ggml-quants.c index f45ece1f2..ee86fd6b9 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -659,40 +659,6 @@ static inline __m128i packNibbles( __m256i bytes ) { } #endif //__loongarch_asx -void quantize_row_q2_2_reference(const float * restrict x, block_q2_2 * restrict y, int64_t k) { - static const int qk = QK2_2; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - - for (int j = 0; j < qk/4; ++j) { - int8_t x0 = (int8_t)x[i*qk + j*4 + 0]; - int8_t x1 = (int8_t)x[i*qk + j*4 + 1]; - int8_t x2 = (int8_t)x[i*qk + j*4 + 2]; - int8_t x3 = (int8_t)x[i*qk + j*4 + 3]; - - const uint8_t xi0 = x0 >= 0 ? x0 : 3; - const uint8_t xi1 = x1 >= 0 ? x1 : 3; - const uint8_t xi2 = x2 >= 0 ? x2 : 3; - const uint8_t xi3 = x3 >= 0 ? x3 : 3; - - y[i].qs[j] = 0; - y[i].qs[j] |= (xi0 << 6); - y[i].qs[j] |= (xi1 << 4); - y[i].qs[j] |= (xi2 << 2); - y[i].qs[j] |= (xi3 << 0); - } - } -} - -// reference implementation for deterministic creation of model files -void quantize_row_q2_2(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q2_2_reference(x, y, k); -} - void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) { static const int qk = QK4_0; @@ -1545,26 +1511,6 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) #endif } -void dequantize_row_q2_2(const block_q2_2 * restrict x, float * restrict y, int64_t k) { - static const int qk = QK2_2; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - - for (int j = 0; j < qk/4; ++j) { - const int8_t * q = (const int8_t *) (q22_grid + x[i].qs[j]); - - *y++ = (float) q[0]; - *y++ = (float) q[1]; - *y++ = (float) q[2]; - *y++ = (float) q[3]; - } - } -} - void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int64_t k) { static const int qk = QK4_0; @@ -3359,13 +3305,6 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr return nrow * row_size; } -size_t quantize_q2_2(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - (void)quant_weights; // not used - const size_t row_size = ggml_row_size(GGML_TYPE_Q2_2, n_per_row); - quantize_row_q2_2_reference(src, dst, (int64_t)nrow*n_per_row); - return nrow * row_size; -} - // ====================== "True" 2-bit (de)-quantization void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) { @@ -3786,71 +3725,6 @@ static inline __m128i get_scale_shuffle(int i) { } #endif -void ggml_vec_dot_q2_2_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - const int qk = QK8_0; - const int nb = n / qk; - - assert(n % qk == 0); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q2_2 * restrict x = vx; - const block_q8_0 * restrict y = vy; - -#if defined(__AVX2__) - __m256 acc = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(y[i].d) ); - - __m128i xq8b = _mm_loadu_si64(x[i].qs); - __m256i xq8 = MM256_SET_M128I(xq8b, xq8b); - __m256i xq8l = _mm256_shuffle_epi8(xq8, _mm256_set_epi8(5, -1, 5, -1, 5, -1, 5, -1, - 4, -1, 4, -1, 4, -1, 4, -1, - 1, -1, 1, -1, 1, -1, 1, -1, - 0, -1, 0, -1, 0, -1, 0, -1)); - __m256i xq8h = _mm256_shuffle_epi8(xq8, _mm256_set_epi8(7, -1, 7, -1, 7, -1, 7, -1, - 6, -1, 6, -1, 6, -1, 6, -1, - 3, -1, 3, -1, 3, -1, 3, -1, - 2, -1, 2, -1, 2, -1, 2, -1)); - __m256i shift = _mm256_set_epi16(64, 16, 4, 1, - 64, 16, 4, 1, - 64, 16, 4, 1, - 64, 16, 4, 1); - xq8l = _mm256_mullo_epi16(xq8l, shift); - xq8h = _mm256_mullo_epi16(xq8h, shift); - xq8l = _mm256_srai_epi16(xq8l, 14); - xq8h = _mm256_srai_epi16(xq8h, 14); - xq8 = _mm256_packs_epi16(xq8l, xq8h); - - __m256i yq8 = _mm256_lddqu_si256((const __m256i*)(y[i].qs)); - const __m256 q = mul_sum_i8_pairs_float(xq8, yq8); - - acc = _mm256_fmadd_ps( d, q, acc ); - } - - *s = hsum_float_8(acc); -#else - - float sumf = 0.0; - for (int i = 0; i < nb; i++) { - int sumi = 0; - for (int j = 0; j < qk / 4; j++) { - const int8_t* weight = (const int8_t *)(q22_grid + x[i].qs[j]); - sumi += (int)y[i].qs[4*j+0] * weight[0]; - sumi += (int)y[i].qs[4*j+1] * weight[1]; - sumi += (int)y[i].qs[4*j+2] * weight[2]; - sumi += (int)y[i].qs[4*j+3] * weight[3]; - } - sumf += (float)(sumi)*(GGML_FP16_TO_FP32(y[i].d)); - } - *s = sumf; -#endif -} - void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; @@ -14488,7 +14362,6 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); } break; - case GGML_TYPE_Q2_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: diff --git a/ggml-quants.h b/ggml-quants.h index e159cef5f..4d436a8f0 100644 --- a/ggml-quants.h +++ b/ggml-quants.h @@ -12,7 +12,6 @@ extern "C" { #endif // Quantization -void quantize_row_q2_2_reference(const float * GGML_RESTRICT x, block_q2_2 * GGML_RESTRICT y, int64_t k); void quantize_row_q4_0_reference(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k); void quantize_row_q4_1_reference(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k); void quantize_row_q5_0_reference(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k); @@ -33,7 +32,6 @@ void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs void quantize_row_iq3_s_reference (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k); void quantize_row_iq2_s_reference (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); -void quantize_row_q2_2(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -55,7 +53,6 @@ void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); // Dequantization -void dequantize_row_q2_2(const block_q2_2 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -81,7 +78,6 @@ void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_ void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); // Dot product -void ggml_vec_dot_q2_2_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -120,7 +116,6 @@ size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q2_2(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); diff --git a/ggml.c b/ggml.c index 303b2f563..1fc77743b 100644 --- a/ggml.c +++ b/ggml.c @@ -616,18 +616,6 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_F16, .nrows = 1, }, - [GGML_TYPE_Q2_2] = { - .type_name = "q2_2", - .blck_size = QK2_2, - .type_size = sizeof(block_q2_2), - .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_q2_2, - .from_float = quantize_row_q2_2, - .from_float_reference = (ggml_from_float_t) quantize_row_q2_2_reference, - .vec_dot = ggml_vec_dot_q2_2_q8_0, - .vec_dot_type = GGML_TYPE_Q8_0, - .nrows = 1, - }, [GGML_TYPE_Q4_0] = { .type_name = "q4_0", .blck_size = QK4_0, @@ -14169,7 +14157,6 @@ static void ggml_compute_forward_clamp( } break; case GGML_TYPE_F16: case GGML_TYPE_BF16: - case GGML_TYPE_Q2_2: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -21319,7 +21306,6 @@ size_t ggml_quantize_chunk( size_t result = 0; switch (type) { - case GGML_TYPE_Q2_2: result = quantize_q2_2(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_0: result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_1: result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/ggml.h b/ggml.h index 4ec555ccb..13502a362 100644 --- a/ggml.h +++ b/ggml.h @@ -377,7 +377,6 @@ extern "C" { GGML_TYPE_F64 = 28, GGML_TYPE_IQ1_M = 29, GGML_TYPE_BF16 = 30, - GGML_TYPE_Q2_2 = 31, GGML_TYPE_COUNT, }; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 301200869..1fc8fcde5 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -923,7 +923,6 @@ class GGMLQuantizationType(IntEnum): F64 = 28 IQ1_M = 29 BF16 = 30 - Q2_2 = 31 # TODO: add GGMLFileType from ggml_ftype in ggml.h @@ -965,7 +964,6 @@ class LlamaFileType(IntEnum): MOSTLY_IQ4_XS = 30 # except 1d tensors MOSTLY_IQ1_M = 31 # except 1d tensors MOSTLY_BF16 = 32 # except 1d tensors - MOSTLY_Q2_2 = 33 # except 1d tensors GUESSED = 1024 # not specified in the model file @@ -1012,7 +1010,6 @@ QK_K = 256 GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = { GGMLQuantizationType.F32: (1, 4), GGMLQuantizationType.F16: (1, 2), - GGMLQuantizationType.Q2_2: (32, 8), GGMLQuantizationType.Q4_0: (32, 2 + 16), GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16), GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16), diff --git a/llama.cpp b/llama.cpp index 85182f4bb..1622bc8d3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3885,7 +3885,6 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_ALL_F32: return "all F32"; case LLAMA_FTYPE_MOSTLY_F16: return "F16"; case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; - case LLAMA_FTYPE_MOSTLY_Q2_2: return "Q2_2"; case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16: @@ -15462,7 +15461,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s llama_ftype ftype = params->ftype; switch (params->ftype) { - case LLAMA_FTYPE_MOSTLY_Q2_2: default_type = GGML_TYPE_Q2_2; break; case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break; case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break; case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break; diff --git a/llama.h b/llama.h index 7a2e0e31c..62908261f 100644 --- a/llama.h +++ b/llama.h @@ -156,7 +156,6 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q2_2 = 33, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file };