Quantization loop

This commit is contained in:
Henrik Forstén 2023-12-27 09:44:38 +02:00
parent 0b6207ef61
commit d1af0a3a94

View file

@ -498,7 +498,7 @@ static void lstsq_q_k(const float * restrict q, const float * restrict x, const
qx += q[i]*x[i]; qx += q[i]*x[i];
} }
float denom = qs*qs - q2*s2; float denom = qs*qs - q2*s2;
if (s2 == 0.0f) { if (s2 == 0.0f && q2 != 0.0f) {
// All s are zero. // All s are zero.
*min = 0.0f; *min = 0.0f;
*d = qx / q2; *d = qx / q2;
@ -1859,8 +1859,6 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
} }
} }
redo_pos4:
for (int j = 0; j < QK_K/32; j++) { for (int j = 0; j < QK_K/32; j++) {
uint8_t q[QK_K/32]; uint8_t q[QK_K/32];
quantize_1(&x[32*j], 32, 4, q, &mins[j], &scales[j]); quantize_1(&x[32*j], 32, 4, q, &mins[j], &scales[j]);
@ -1883,45 +1881,57 @@ redo_pos4:
} }
} }
int ql_loop = 0;
quant_loop: ;
#if QK_K == 256 #if QK_K == 256
float inv_scale = max_scale == 0.0f ? 0.0f : 63.f/max_scale; float inv_scale = max_scale == 0.0f ? 0.0f : 63.f/max_scale;
float inv_min = max_min == 0.0f ? 0.0f : 63.f/max_min; float inv_min = max_min == 0.0f ? 0.0f : 63.f/max_min;
int all_zero_lm = 1;
for (int j = 0; j < QK_K/32; ++j) { for (int j = 0; j < QK_K/32; ++j) {
uint8_t ls = nearest_int(inv_scale*scales[j]); uint8_t ls = nearest_int(inv_scale*scales[j]);
uint8_t lm = nearest_int(inv_min*mins[j]); uint8_t lm = nearest_int(inv_min*mins[j]);
uint8_t best_lm = lm;
uint8_t best_ls = ls;
ls = MIN(63, ls); ls = MIN(63, ls);
lm = MIN(63, lm); lm = MIN(63, lm);
float rms = 0.0f; float best_rms = FLT_MAX;
float rms2 = 0.0f;
float rms3 = 0.0f;
uint8_t lm2 = MIN(63, MAX(0, lm - 1));
uint8_t lm3 = MIN(63, MAX(0, lm + 1));
const float d1 = max_scale / 63.0f; const float d1 = max_scale / 63.0f;
const float dmin1 = max_min / 63.0f; const float dmin1 = max_min / 63.0f;
for (int ii = 0; ii < 32; ii++) { int limit = 1;
const float d = d1 * ls; if (ql_loop) limit = 4;
const float dm1 = dmin1 * lm; for (int lst = MAX(0, ls-limit); lst <= MIN(63, ls+limit); lst++) {
const float dm2 = dmin1 * lm2; for (int lmt = MAX(0, lm-limit); lmt <= MIN(63, lm+limit); lmt++) {
const float dm3 = dmin1 * lm3; float rms = 0.0f;
if (!d) continue; for (int ii = 0; ii < 32; ii++) {
int l1 = nearest_int((x[32*j + ii] + dm1)/d); const float d = d1 * lst;
l1 = MAX(0, MIN(15, l1)); const float dm1 = dmin1 * lmt;
int l2 = nearest_int((x[32*j + ii] + dm2)/d); int l1 = 0;
l2 = MAX(0, MIN(15, l2)); if (d) {
int l3 = nearest_int((x[32*j + ii] + dm3)/d); l1 = nearest_int((x[32*j + ii] + dm1)/d);
l3 = MAX(0, MIN(15, l3)); l1 = MAX(0, MIN(15, l1));
rms += ((d*l1 - dm1) - x[32*j + ii]) * ((d*l1 - dm1) - x[32*j + ii]); }
rms2 += ((d*l2 - dm2) - x[32*j + ii]) * ((d*l2 - dm2) - x[32*j + ii]); float e = ((d*l1 - dm1) - x[32*j + ii]);
rms3 += ((d*l3 - dm3) - x[32*j + ii]) * ((d*l3 - dm3) - x[32*j + ii]); rms += e*e;
}
if (rms < best_rms) {
best_lm = lmt;
best_ls = lst;
best_rms = rms;
}
}
} }
if (rms2 < rms && rms2 < rms3) { //if (lm != best_lm) {
lm = lm2; // printf("best %d, orig %d\n", best_lm, lm);
} //}
if (rms3 < rms && rms3 < rms2) { lm = best_lm;
lm = lm3; ls = best_ls;
} //if (rms2 < rms && rms2 < rms3) {
if (lm != 0) all_zero_lm = 0; // printf("rms2 %f %f %f, lm %d %d %d\n", rms, rms2, rms3, lm, lm2, lm3);
// lm = lm2;
//}
//if (rms3 < rms && rms3 < rms2) {
// printf("rms3 %f %f %f, lm %d %d %d\n", rms, rms2, rms3, lm, lm2, lm3);
// lm = lm3;
//}
if (j < 4) { if (j < 4) {
y[i].scales[j] = ls; y[i].scales[j] = ls;
y[i].scales[j+4] = lm; y[i].scales[j+4] = lm;
@ -1932,13 +1942,14 @@ redo_pos4:
} }
} }
if (all_zero_lm && !all_positive) { //if (all_zero_lm && !all_positive && !ql_loop) {
all_positive = 1; // all_positive = 1;
//printf("**********red_pos4\n"); // //printf("**********red_pos4\n");
goto redo_pos4; // goto redo_pos4;
} else if (all_zero_lm) { //}
//printf("all_zero_lm, all_pos %d, max_scale %f, max_min %f\n", all_positive, max_scale, max_min); //} else if (all_zero_lm) {
} // //printf("all_zero_lm, all_pos %d, max_scale %f, max_min %f\n", all_positive, max_scale, max_min);
//}
y[i].d = GGML_FP32_TO_FP16(max_scale/63.f); y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);
y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f); y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);
@ -1962,12 +1973,26 @@ redo_pos4:
} }
} }
float min; //printf("%d orig: %f %f, ", ql_loop, max_min, max_scale);
float d; float min, scale;
lstsq_q_k(q_fit, x, q_m, 32, &min, &d); lstsq_q_k(q_fit, x, q_m, 32, &min, &scale);
y[i].d = GGML_FP32_TO_FP16(d); if (min != min) {
printf("min nan\n");
}
if (scale != scale) {
printf("scale nan\n");
}
//printf("fit: %f %f\n", max_min, max_scale);
y[i].d = GGML_FP32_TO_FP16(scale);
y[i].dmin = GGML_FP32_TO_FP16(min); y[i].dmin = GGML_FP32_TO_FP16(min);
//printf("%f %f, %f %f\n", max_min, min * 63.0f, max_scale, scale * 63.0f);
max_scale = GGML_FP16_TO_FP32(y[i].d) * 63.0f;
max_min = GGML_FP16_TO_FP32(y[i].dmin) * 63.0f;
ql_loop++;
if (ql_loop == 1) {
goto quant_loop;
}
#else #else
const float s_factor = 15.f; const float s_factor = 15.f;
float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f; float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f;