From 3fafecae9eb642ab872ea95c63b53e0a240d9223 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20Forst=C3=A9n?= Date: Thu, 28 Dec 2023 15:55:12 +0200 Subject: [PATCH] Weighted least squares --- ggml-quants.c | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/ggml-quants.c b/ggml-quants.c index 9c1e3e0e4..a2054acd9 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -433,16 +433,19 @@ static void lstsq_q_1(const uint8_t * restrict q, const float * restrict x, int float xq = 0.0f; float minx = x[0]; float maxx = x[0]; + float q1 = 0.0f; for (int i = 0; i < qk; i++) { float qf = q[i]; - qs += qf; - qs2 += qf*qf; - xs += x[i]; - xq += x[i]*qf; + float w = fabsf(x[i]); + q1 += w; + qs += w*qf; + qs2 += w*qf*qf; + xs += w*x[i]; + xq += w*x[i]*qf; if (x[i] < minx) minx = x[i]; if (x[i] > maxx) maxx = x[i]; } - float denom = qs*qs - qs2*qk; + float denom = qs*qs - qs2*q1; if (minx == maxx) { *min = x[0]; *d = 0.0f; @@ -451,7 +454,7 @@ static void lstsq_q_1(const uint8_t * restrict q, const float * restrict x, int *d = 0.0f; } else { *min = (qs*xq - qs2*xs) / denom; - *d = (qs*xs - qk*xq) / denom; + *d = (qs*xs - q1*xq) / denom; } } @@ -491,11 +494,12 @@ static void lstsq_q_k(const float * restrict q, const float * restrict x, const float sx = 0.0f; float qx = 0.0f; for (int i = 0; i < QK_K; i++) { - s2 += s[i/bs]*s[i/bs]; - qs += q[i]*s[i/bs]; - q2 += q[i]*q[i]; - sx += s[i/bs]*x[i]; - qx += q[i]*x[i]; + float w = fabsf(x[i]); + s2 += w*s[i/bs]*s[i/bs]; + qs += w*q[i]*s[i/bs]; + q2 += w*q[i]*q[i]; + sx += w*s[i/bs]*x[i]; + qx += w*q[i]*x[i]; } float denom = qs*qs - q2*s2; if (s2 == 0.0f && q2 != 0.0f) {