From d1af0a3a945fd8300718ddacea786508376831e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20Forst=C3=A9n?= Date: Wed, 27 Dec 2023 09:44:38 +0200 Subject: [PATCH] Quantization loop --- ggml-quants.c | 109 +++++++++++++++++++++++++++++++------------------- 1 file changed, 67 insertions(+), 42 deletions(-) diff --git a/ggml-quants.c b/ggml-quants.c index 06481048f..0c3b0d42d 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -498,7 +498,7 @@ static void lstsq_q_k(const float * restrict q, const float * restrict x, const qx += q[i]*x[i]; } float denom = qs*qs - q2*s2; - if (s2 == 0.0f) { + if (s2 == 0.0f && q2 != 0.0f) { // All s are zero. *min = 0.0f; *d = qx / q2; @@ -1859,8 +1859,6 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict } } -redo_pos4: - for (int j = 0; j < QK_K/32; j++) { uint8_t q[QK_K/32]; quantize_1(&x[32*j], 32, 4, q, &mins[j], &scales[j]); @@ -1883,45 +1881,57 @@ redo_pos4: } } + int ql_loop = 0; +quant_loop: ; #if QK_K == 256 float inv_scale = max_scale == 0.0f ? 0.0f : 63.f/max_scale; float inv_min = max_min == 0.0f ? 0.0f : 63.f/max_min; - int all_zero_lm = 1; for (int j = 0; j < QK_K/32; ++j) { uint8_t ls = nearest_int(inv_scale*scales[j]); uint8_t lm = nearest_int(inv_min*mins[j]); + uint8_t best_lm = lm; + uint8_t best_ls = ls; ls = MIN(63, ls); lm = MIN(63, lm); - float rms = 0.0f; - float rms2 = 0.0f; - float rms3 = 0.0f; - uint8_t lm2 = MIN(63, MAX(0, lm - 1)); - uint8_t lm3 = MIN(63, MAX(0, lm + 1)); + float best_rms = FLT_MAX; const float d1 = max_scale / 63.0f; const float dmin1 = max_min / 63.0f; - for (int ii = 0; ii < 32; ii++) { - const float d = d1 * ls; - const float dm1 = dmin1 * lm; - const float dm2 = dmin1 * lm2; - const float dm3 = dmin1 * lm3; - if (!d) continue; - int l1 = nearest_int((x[32*j + ii] + dm1)/d); - l1 = MAX(0, MIN(15, l1)); - int l2 = nearest_int((x[32*j + ii] + dm2)/d); - l2 = MAX(0, MIN(15, l2)); - int l3 = nearest_int((x[32*j + ii] + dm3)/d); - l3 = MAX(0, MIN(15, l3)); - rms += ((d*l1 - dm1) - x[32*j + ii]) * ((d*l1 - dm1) - x[32*j + ii]); - rms2 += ((d*l2 - dm2) - x[32*j + ii]) * ((d*l2 - dm2) - x[32*j + ii]); - rms3 += ((d*l3 - dm3) - x[32*j + ii]) * ((d*l3 - dm3) - x[32*j + ii]); + int limit = 1; + if (ql_loop) limit = 4; + for (int lst = MAX(0, ls-limit); lst <= MIN(63, ls+limit); lst++) { + for (int lmt = MAX(0, lm-limit); lmt <= MIN(63, lm+limit); lmt++) { + float rms = 0.0f; + for (int ii = 0; ii < 32; ii++) { + const float d = d1 * lst; + const float dm1 = dmin1 * lmt; + int l1 = 0; + if (d) { + l1 = nearest_int((x[32*j + ii] + dm1)/d); + l1 = MAX(0, MIN(15, l1)); + } + float e = ((d*l1 - dm1) - x[32*j + ii]); + rms += e*e; + } + if (rms < best_rms) { + best_lm = lmt; + best_ls = lst; + best_rms = rms; + } + } } - if (rms2 < rms && rms2 < rms3) { - lm = lm2; - } - if (rms3 < rms && rms3 < rms2) { - lm = lm3; - } - if (lm != 0) all_zero_lm = 0; + //if (lm != best_lm) { + // printf("best %d, orig %d\n", best_lm, lm); + //} + lm = best_lm; + ls = best_ls; + //if (rms2 < rms && rms2 < rms3) { + // printf("rms2 %f %f %f, lm %d %d %d\n", rms, rms2, rms3, lm, lm2, lm3); + // lm = lm2; + //} + //if (rms3 < rms && rms3 < rms2) { + // printf("rms3 %f %f %f, lm %d %d %d\n", rms, rms2, rms3, lm, lm2, lm3); + // lm = lm3; + //} if (j < 4) { y[i].scales[j] = ls; y[i].scales[j+4] = lm; @@ -1932,13 +1942,14 @@ redo_pos4: } } - if (all_zero_lm && !all_positive) { - all_positive = 1; - //printf("**********red_pos4\n"); - goto redo_pos4; - } else if (all_zero_lm) { - //printf("all_zero_lm, all_pos %d, max_scale %f, max_min %f\n", all_positive, max_scale, max_min); - } + //if (all_zero_lm && !all_positive && !ql_loop) { + // all_positive = 1; + // //printf("**********red_pos4\n"); + // goto redo_pos4; + //} + //} else if (all_zero_lm) { + // //printf("all_zero_lm, all_pos %d, max_scale %f, max_min %f\n", all_positive, max_scale, max_min); + //} y[i].d = GGML_FP32_TO_FP16(max_scale/63.f); y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f); @@ -1962,12 +1973,26 @@ redo_pos4: } } - float min; - float d; - lstsq_q_k(q_fit, x, q_m, 32, &min, &d); - y[i].d = GGML_FP32_TO_FP16(d); + //printf("%d orig: %f %f, ", ql_loop, max_min, max_scale); + float min, scale; + lstsq_q_k(q_fit, x, q_m, 32, &min, &scale); + if (min != min) { + printf("min nan\n"); + } + if (scale != scale) { + printf("scale nan\n"); + } + //printf("fit: %f %f\n", max_min, max_scale); + y[i].d = GGML_FP32_TO_FP16(scale); y[i].dmin = GGML_FP32_TO_FP16(min); + //printf("%f %f, %f %f\n", max_min, min * 63.0f, max_scale, scale * 63.0f); + max_scale = GGML_FP16_TO_FP32(y[i].d) * 63.0f; + max_min = GGML_FP16_TO_FP32(y[i].dmin) * 63.0f; + ql_loop++; + if (ql_loop == 1) { + goto quant_loop; + } #else const float s_factor = 15.f; float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f;