Quantization loop

This commit is contained in:
Henrik Forstén 2023-12-27 09:44:38 +02:00
parent 0b6207ef61
commit d1af0a3a94

View file

@ -498,7 +498,7 @@ static void lstsq_q_k(const float * restrict q, const float * restrict x, const
qx += q[i]*x[i];
}
float denom = qs*qs - q2*s2;
if (s2 == 0.0f) {
if (s2 == 0.0f && q2 != 0.0f) {
// All s are zero.
*min = 0.0f;
*d = qx / q2;
@ -1859,8 +1859,6 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
}
}
redo_pos4:
for (int j = 0; j < QK_K/32; j++) {
uint8_t q[QK_K/32];
quantize_1(&x[32*j], 32, 4, q, &mins[j], &scales[j]);
@ -1883,45 +1881,57 @@ redo_pos4:
}
}
int ql_loop = 0;
quant_loop: ;
#if QK_K == 256
float inv_scale = max_scale == 0.0f ? 0.0f : 63.f/max_scale;
float inv_min = max_min == 0.0f ? 0.0f : 63.f/max_min;
int all_zero_lm = 1;
for (int j = 0; j < QK_K/32; ++j) {
uint8_t ls = nearest_int(inv_scale*scales[j]);
uint8_t lm = nearest_int(inv_min*mins[j]);
uint8_t best_lm = lm;
uint8_t best_ls = ls;
ls = MIN(63, ls);
lm = MIN(63, lm);
float rms = 0.0f;
float rms2 = 0.0f;
float rms3 = 0.0f;
uint8_t lm2 = MIN(63, MAX(0, lm - 1));
uint8_t lm3 = MIN(63, MAX(0, lm + 1));
float best_rms = FLT_MAX;
const float d1 = max_scale / 63.0f;
const float dmin1 = max_min / 63.0f;
for (int ii = 0; ii < 32; ii++) {
const float d = d1 * ls;
const float dm1 = dmin1 * lm;
const float dm2 = dmin1 * lm2;
const float dm3 = dmin1 * lm3;
if (!d) continue;
int l1 = nearest_int((x[32*j + ii] + dm1)/d);
l1 = MAX(0, MIN(15, l1));
int l2 = nearest_int((x[32*j + ii] + dm2)/d);
l2 = MAX(0, MIN(15, l2));
int l3 = nearest_int((x[32*j + ii] + dm3)/d);
l3 = MAX(0, MIN(15, l3));
rms += ((d*l1 - dm1) - x[32*j + ii]) * ((d*l1 - dm1) - x[32*j + ii]);
rms2 += ((d*l2 - dm2) - x[32*j + ii]) * ((d*l2 - dm2) - x[32*j + ii]);
rms3 += ((d*l3 - dm3) - x[32*j + ii]) * ((d*l3 - dm3) - x[32*j + ii]);
int limit = 1;
if (ql_loop) limit = 4;
for (int lst = MAX(0, ls-limit); lst <= MIN(63, ls+limit); lst++) {
for (int lmt = MAX(0, lm-limit); lmt <= MIN(63, lm+limit); lmt++) {
float rms = 0.0f;
for (int ii = 0; ii < 32; ii++) {
const float d = d1 * lst;
const float dm1 = dmin1 * lmt;
int l1 = 0;
if (d) {
l1 = nearest_int((x[32*j + ii] + dm1)/d);
l1 = MAX(0, MIN(15, l1));
}
float e = ((d*l1 - dm1) - x[32*j + ii]);
rms += e*e;
}
if (rms < best_rms) {
best_lm = lmt;
best_ls = lst;
best_rms = rms;
}
}
}
if (rms2 < rms && rms2 < rms3) {
lm = lm2;
}
if (rms3 < rms && rms3 < rms2) {
lm = lm3;
}
if (lm != 0) all_zero_lm = 0;
//if (lm != best_lm) {
// printf("best %d, orig %d\n", best_lm, lm);
//}
lm = best_lm;
ls = best_ls;
//if (rms2 < rms && rms2 < rms3) {
// printf("rms2 %f %f %f, lm %d %d %d\n", rms, rms2, rms3, lm, lm2, lm3);
// lm = lm2;
//}
//if (rms3 < rms && rms3 < rms2) {
// printf("rms3 %f %f %f, lm %d %d %d\n", rms, rms2, rms3, lm, lm2, lm3);
// lm = lm3;
//}
if (j < 4) {
y[i].scales[j] = ls;
y[i].scales[j+4] = lm;
@ -1932,13 +1942,14 @@ redo_pos4:
}
}
if (all_zero_lm && !all_positive) {
all_positive = 1;
//printf("**********red_pos4\n");
goto redo_pos4;
} else if (all_zero_lm) {
//printf("all_zero_lm, all_pos %d, max_scale %f, max_min %f\n", all_positive, max_scale, max_min);
}
//if (all_zero_lm && !all_positive && !ql_loop) {
// all_positive = 1;
// //printf("**********red_pos4\n");
// goto redo_pos4;
//}
//} else if (all_zero_lm) {
// //printf("all_zero_lm, all_pos %d, max_scale %f, max_min %f\n", all_positive, max_scale, max_min);
//}
y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);
y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);
@ -1962,12 +1973,26 @@ redo_pos4:
}
}
float min;
float d;
lstsq_q_k(q_fit, x, q_m, 32, &min, &d);
y[i].d = GGML_FP32_TO_FP16(d);
//printf("%d orig: %f %f, ", ql_loop, max_min, max_scale);
float min, scale;
lstsq_q_k(q_fit, x, q_m, 32, &min, &scale);
if (min != min) {
printf("min nan\n");
}
if (scale != scale) {
printf("scale nan\n");
}
//printf("fit: %f %f\n", max_min, max_scale);
y[i].d = GGML_FP32_TO_FP16(scale);
y[i].dmin = GGML_FP32_TO_FP16(min);
//printf("%f %f, %f %f\n", max_min, min * 63.0f, max_scale, scale * 63.0f);
max_scale = GGML_FP16_TO_FP32(y[i].d) * 63.0f;
max_min = GGML_FP16_TO_FP32(y[i].dmin) * 63.0f;
ql_loop++;
if (ql_loop == 1) {
goto quant_loop;
}
#else
const float s_factor = 15.f;
float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f;