Weighted least squares

This commit is contained in:
Henrik Forstén 2023-12-28 15:55:12 +02:00
parent f478136773
commit 3fafecae9e

View file

@ -433,16 +433,19 @@ static void lstsq_q_1(const uint8_t * restrict q, const float * restrict x, int
float xq = 0.0f;
float minx = x[0];
float maxx = x[0];
float q1 = 0.0f;
for (int i = 0; i < qk; i++) {
float qf = q[i];
qs += qf;
qs2 += qf*qf;
xs += x[i];
xq += x[i]*qf;
float w = fabsf(x[i]);
q1 += w;
qs += w*qf;
qs2 += w*qf*qf;
xs += w*x[i];
xq += w*x[i]*qf;
if (x[i] < minx) minx = x[i];
if (x[i] > maxx) maxx = x[i];
}
float denom = qs*qs - qs2*qk;
float denom = qs*qs - qs2*q1;
if (minx == maxx) {
*min = x[0];
*d = 0.0f;
@ -451,7 +454,7 @@ static void lstsq_q_1(const uint8_t * restrict q, const float * restrict x, int
*d = 0.0f;
} else {
*min = (qs*xq - qs2*xs) / denom;
*d = (qs*xs - qk*xq) / denom;
*d = (qs*xs - q1*xq) / denom;
}
}
@ -491,11 +494,12 @@ static void lstsq_q_k(const float * restrict q, const float * restrict x, const
float sx = 0.0f;
float qx = 0.0f;
for (int i = 0; i < QK_K; i++) {
s2 += s[i/bs]*s[i/bs];
qs += q[i]*s[i/bs];
q2 += q[i]*q[i];
sx += s[i/bs]*x[i];
qx += q[i]*x[i];
float w = fabsf(x[i]);
s2 += w*s[i/bs]*s[i/bs];
qs += w*q[i]*s[i/bs];
q2 += w*q[i]*q[i];
sx += w*s[i/bs]*x[i];
qx += w*q[i]*x[i];
}
float denom = qs*qs - q2*s2;
if (s2 == 0.0f && q2 != 0.0f) {