Weighted least squares
This commit is contained in:
parent
f478136773
commit
3fafecae9e
1 changed files with 15 additions and 11 deletions
|
@ -433,16 +433,19 @@ static void lstsq_q_1(const uint8_t * restrict q, const float * restrict x, int
|
|||
float xq = 0.0f;
|
||||
float minx = x[0];
|
||||
float maxx = x[0];
|
||||
float q1 = 0.0f;
|
||||
for (int i = 0; i < qk; i++) {
|
||||
float qf = q[i];
|
||||
qs += qf;
|
||||
qs2 += qf*qf;
|
||||
xs += x[i];
|
||||
xq += x[i]*qf;
|
||||
float w = fabsf(x[i]);
|
||||
q1 += w;
|
||||
qs += w*qf;
|
||||
qs2 += w*qf*qf;
|
||||
xs += w*x[i];
|
||||
xq += w*x[i]*qf;
|
||||
if (x[i] < minx) minx = x[i];
|
||||
if (x[i] > maxx) maxx = x[i];
|
||||
}
|
||||
float denom = qs*qs - qs2*qk;
|
||||
float denom = qs*qs - qs2*q1;
|
||||
if (minx == maxx) {
|
||||
*min = x[0];
|
||||
*d = 0.0f;
|
||||
|
@ -451,7 +454,7 @@ static void lstsq_q_1(const uint8_t * restrict q, const float * restrict x, int
|
|||
*d = 0.0f;
|
||||
} else {
|
||||
*min = (qs*xq - qs2*xs) / denom;
|
||||
*d = (qs*xs - qk*xq) / denom;
|
||||
*d = (qs*xs - q1*xq) / denom;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -491,11 +494,12 @@ static void lstsq_q_k(const float * restrict q, const float * restrict x, const
|
|||
float sx = 0.0f;
|
||||
float qx = 0.0f;
|
||||
for (int i = 0; i < QK_K; i++) {
|
||||
s2 += s[i/bs]*s[i/bs];
|
||||
qs += q[i]*s[i/bs];
|
||||
q2 += q[i]*q[i];
|
||||
sx += s[i/bs]*x[i];
|
||||
qx += q[i]*x[i];
|
||||
float w = fabsf(x[i]);
|
||||
s2 += w*s[i/bs]*s[i/bs];
|
||||
qs += w*q[i]*s[i/bs];
|
||||
q2 += w*q[i]*q[i];
|
||||
sx += w*s[i/bs]*x[i];
|
||||
qx += w*q[i]*x[i];
|
||||
}
|
||||
float denom = qs*qs - q2*s2;
|
||||
if (s2 == 0.0f && q2 != 0.0f) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue