From 4f8dcb16533db2686c716e651156f8d90c4edd5f Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Mon, 14 Aug 2023 16:06:00 +0300
Subject: [PATCH] Adding make_qkx2_quants

With it, we get PPL = 5.8828 for L2-7B Q4_K_S.

Note: that figure predates the fix that includes x[0] in the sums used by
make_qkx2_quants and has not been re-measured.
---
 k_quants.c | 102 ++++++++++++++++++++++++++++++++++++++++++++++++++++--
 llama.cpp  |   2 +-
 2 files changed, 101 insertions(+), 3 deletions(-)

diff --git a/k_quants.c b/k_quants.c
index 4e5563a33..b5faefd6c 100644
--- a/k_quants.c
+++ b/k_quants.c
@@ -225,9 +225,13 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
         int ntry, float alpha) {
     float min = x[0];
     float max = x[0];
+    float sum_x = 0;
+    float sum_x2 = 0;
     for (int i = 1; i < n; ++i) {
         if (x[i] < min) min = x[i];
         if (x[i] > max) max = x[i];
+        sum_x += x[i];
+        sum_x2 += x[i]*x[i];
     }
     if (max == min) {
         for (int i = 0; i < n; ++i) L[i] = 0;
@@ -264,6 +268,94 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
     return scale;
 }
 
+// Quantizes x[0..n-1] to levels L[i] such that x[i] is approximated by scale*L[i] + min,
+// with min <= 0. Levels are clamped via MAX(0, MIN(nmax, l)).
+//
+// The range-based fit (scale = (max - min)/nmax) serves as the baseline. Sixteen perturbed
+// inverse scales are then tried. For each resulting set of levels (held in Laux), scale and
+// min are chosen so that scale*l + min has the same mean and variance as x; if that min
+// would be positive, min is set to 0 and the scale is refit from the second moments alone.
+// A candidate replaces the current best when it lowers the x[i]^2-weighted squared error.
+//
+// L and Laux are written at indices 0..n-1; Laux is scratch space only.
+// Returns the scale and stores -min in *the_min. When max == min (after clamping min to
+// at most 0), L is zeroed and 0 is returned.
+static float make_qkx2_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min,
+        uint8_t * restrict Laux) {
+    float min = x[0];
+    float max = x[0];
+    // Seed the sums with x[0]; the loop below starts at i = 1.
+    float sum_x = x[0], sum_x2 = x[0] * x[0];
+    for (int i = 1; i < n; ++i) {
+        if (x[i] < min) min = x[i];
+        if (x[i] > max) max = x[i];
+        sum_x += x[i];
+        sum_x2 += x[i] * x[i];
+    }
+    if (min > 0) min = 0;
+    if (max == min) {
+        for (int i = 0; i < n; ++i) L[i] = 0;
+        *the_min = -min;
+        return 0.f;
+    }
+    // n^2 times the variance of x; den below is the same quantity for the levels.
+    float num = sum_x2 * n - sum_x * sum_x;
+    float iscale = nmax/(max - min);
+    float scale = 1/iscale;
+    float best_mse = 0;
+    // Baseline: plain range-based quantization and its weighted squared error.
+    for (int i = 0; i < n; ++i) {
+        int l = nearest_int(iscale*(x[i] - min));
+        L[i] = MAX(0, MIN(nmax, l));
+        float diff = scale * L[i] + min - x[i];
+        float w = x[i] * x[i];
+        best_mse += w * diff * diff;
+    }
+    // x has no spread to match (e.g. all values equal and positive): keep the baseline.
+    if (num <= 0) {
+        *the_min = -min;
+        return scale;
+    }
+    for (int is = -5; is <= 10; ++is) {
+        // min holds the offset of the best fit accepted so far.
+        iscale = (0.1f*is + nmax)/(max - min);
+        int sum_l = 0, sum_l2 = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale*(x[i] - min));
+            l = MAX(0, MIN(nmax, l));
+            Laux[i] = l;
+            sum_l += l;
+            sum_l2 += l*l;
+        }
+        int den = sum_l2 * n - sum_l * sum_l;
+        if (den > 0) {
+            float this_scale = sqrtf(num / den);
+            float this_min = (sum_x - this_scale * sum_l)/n;
+            if (this_min > 0) {
+                // A positive offset is not allowed; refit the scale alone with min = 0.
+                this_min = 0;
+                this_scale = sqrtf(sum_x2 / sum_l2);
+            }
+            float mse = 0;
+            for (int i = 0; i < n; ++i) {
+                float diff = this_scale * Laux[i] + this_min - x[i];
+                float w = x[i] * x[i];
+                mse += w * diff * diff;
+            }
+            if (mse < best_mse) {
+                for (int i = 0; i < n; ++i) {
+                    L[i] = Laux[i];
+                }
+                best_mse = mse;
+                scale = this_scale;
+                min = this_min;
+            }
+        }
+    }
+    *the_min = -min;
+    return scale;
+}
+
 #if QK_K == 256
 static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
     if (j < 4) {
@@ -282,6 +374,7 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
     const int nb = k / QK_K;
 
     uint8_t L[QK_K];
+    uint8_t Laux[16];
     float mins[QK_K/16];
     float scales[QK_K/16];
 
@@ -292,7 +385,8 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
         float max_scale = 0; // as we are deducting the min, scales are always positive
         float max_min = 0;
         for (int j = 0; j < QK_K/16; ++j) {
-            scales[j] = make_qkx1_quants(16, 3, x + 16*j, L + 16*j, &mins[j], 5, 0.f);
+            //scales[j] = make_qkx1_quants(16, 3, x + 16*j, L + 16*j, &mins[j], 5, 0.f);
+            scales[j] = make_qkx2_quants(16, 3, x + 16*j, L + 16*j, &mins[j], Laux);
             float scale = scales[j];
             if (scale > max_scale) {
                 max_scale = scale;
@@ -638,6 +732,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
     const int nb = k / QK_K;
 
     uint8_t L[QK_K];
+    uint8_t Laux[32];
     float mins[QK_K/32];
     float scales[QK_K/32];
 
@@ -646,7 +741,8 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
         float max_scale = 0; // as we are deducting the min, scales are always positive
         float max_min = 0;
         for (int j = 0; j < QK_K/32; ++j) {
-            scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
+            //scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
+            scales[j] = make_qkx2_quants(32, 15, x + 32*j, L + 32*j, &mins[j], Laux);
             float scale = scales[j];
             if (scale > max_scale) {
                 max_scale = scale;
@@ -797,6 +893,7 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
 
 #if QK_K == 256
     uint8_t L[QK_K];
+    //uint8_t Laux[32];
     float mins[QK_K/32];
     float scales[QK_K/32];
 #else
@@ -812,6 +909,7 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
         float max_min = 0;
         for (int j = 0; j < QK_K/32; ++j) {
             scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
+            //scales[j] = make_qkx2_quants(32, 31, x + 32*j, L + 32*j, &mins[j], Laux);
             float scale = scales[j];
             if (scale > max_scale) {
                 max_scale = scale;
diff --git a/llama.cpp b/llama.cpp
index 07d8f343f..49e4503c0 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -3740,7 +3740,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
                 else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
                          use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K;
                 else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_feed_forward_w2 < 4) new_type = GGML_TYPE_Q5_K;
-                else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S && i_feed_forward_w2 < 4) new_type = GGML_TYPE_Q5_K;
+                //else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S && i_feed_forward_w2 < 4) new_type = GGML_TYPE_Q5_K;
                 ++i_feed_forward_w2;
             } else if (name.find("attn_output.weight") != std::string::npos) {
                 if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;