Quantization loop
This commit is contained in:
parent
0b6207ef61
commit
d1af0a3a94
1 changed file with 67 additions and 42 deletions
109
ggml-quants.c
109
ggml-quants.c
|
@ -498,7 +498,7 @@ static void lstsq_q_k(const float * restrict q, const float * restrict x, const
|
|||
qx += q[i]*x[i];
|
||||
}
|
||||
float denom = qs*qs - q2*s2;
|
||||
if (s2 == 0.0f) {
|
||||
if (s2 == 0.0f && q2 != 0.0f) {
|
||||
// All s are zero.
|
||||
*min = 0.0f;
|
||||
*d = qx / q2;
|
||||
|
@ -1859,8 +1859,6 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
|
|||
}
|
||||
}
|
||||
|
||||
redo_pos4:
|
||||
|
||||
for (int j = 0; j < QK_K/32; j++) {
|
||||
uint8_t q[QK_K/32];
|
||||
quantize_1(&x[32*j], 32, 4, q, &mins[j], &scales[j]);
|
||||
|
@ -1883,45 +1881,57 @@ redo_pos4:
|
|||
}
|
||||
}
|
||||
|
||||
int ql_loop = 0;
|
||||
quant_loop: ;
|
||||
#if QK_K == 256
|
||||
float inv_scale = max_scale == 0.0f ? 0.0f : 63.f/max_scale;
|
||||
float inv_min = max_min == 0.0f ? 0.0f : 63.f/max_min;
|
||||
int all_zero_lm = 1;
|
||||
for (int j = 0; j < QK_K/32; ++j) {
|
||||
uint8_t ls = nearest_int(inv_scale*scales[j]);
|
||||
uint8_t lm = nearest_int(inv_min*mins[j]);
|
||||
uint8_t best_lm = lm;
|
||||
uint8_t best_ls = ls;
|
||||
ls = MIN(63, ls);
|
||||
lm = MIN(63, lm);
|
||||
float rms = 0.0f;
|
||||
float rms2 = 0.0f;
|
||||
float rms3 = 0.0f;
|
||||
uint8_t lm2 = MIN(63, MAX(0, lm - 1));
|
||||
uint8_t lm3 = MIN(63, MAX(0, lm + 1));
|
||||
float best_rms = FLT_MAX;
|
||||
const float d1 = max_scale / 63.0f;
|
||||
const float dmin1 = max_min / 63.0f;
|
||||
for (int ii = 0; ii < 32; ii++) {
|
||||
const float d = d1 * ls;
|
||||
const float dm1 = dmin1 * lm;
|
||||
const float dm2 = dmin1 * lm2;
|
||||
const float dm3 = dmin1 * lm3;
|
||||
if (!d) continue;
|
||||
int l1 = nearest_int((x[32*j + ii] + dm1)/d);
|
||||
l1 = MAX(0, MIN(15, l1));
|
||||
int l2 = nearest_int((x[32*j + ii] + dm2)/d);
|
||||
l2 = MAX(0, MIN(15, l2));
|
||||
int l3 = nearest_int((x[32*j + ii] + dm3)/d);
|
||||
l3 = MAX(0, MIN(15, l3));
|
||||
rms += ((d*l1 - dm1) - x[32*j + ii]) * ((d*l1 - dm1) - x[32*j + ii]);
|
||||
rms2 += ((d*l2 - dm2) - x[32*j + ii]) * ((d*l2 - dm2) - x[32*j + ii]);
|
||||
rms3 += ((d*l3 - dm3) - x[32*j + ii]) * ((d*l3 - dm3) - x[32*j + ii]);
|
||||
int limit = 1;
|
||||
if (ql_loop) limit = 4;
|
||||
for (int lst = MAX(0, ls-limit); lst <= MIN(63, ls+limit); lst++) {
|
||||
for (int lmt = MAX(0, lm-limit); lmt <= MIN(63, lm+limit); lmt++) {
|
||||
float rms = 0.0f;
|
||||
for (int ii = 0; ii < 32; ii++) {
|
||||
const float d = d1 * lst;
|
||||
const float dm1 = dmin1 * lmt;
|
||||
int l1 = 0;
|
||||
if (d) {
|
||||
l1 = nearest_int((x[32*j + ii] + dm1)/d);
|
||||
l1 = MAX(0, MIN(15, l1));
|
||||
}
|
||||
float e = ((d*l1 - dm1) - x[32*j + ii]);
|
||||
rms += e*e;
|
||||
}
|
||||
if (rms < best_rms) {
|
||||
best_lm = lmt;
|
||||
best_ls = lst;
|
||||
best_rms = rms;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (rms2 < rms && rms2 < rms3) {
|
||||
lm = lm2;
|
||||
}
|
||||
if (rms3 < rms && rms3 < rms2) {
|
||||
lm = lm3;
|
||||
}
|
||||
if (lm != 0) all_zero_lm = 0;
|
||||
//if (lm != best_lm) {
|
||||
// printf("best %d, orig %d\n", best_lm, lm);
|
||||
//}
|
||||
lm = best_lm;
|
||||
ls = best_ls;
|
||||
//if (rms2 < rms && rms2 < rms3) {
|
||||
// printf("rms2 %f %f %f, lm %d %d %d\n", rms, rms2, rms3, lm, lm2, lm3);
|
||||
// lm = lm2;
|
||||
//}
|
||||
//if (rms3 < rms && rms3 < rms2) {
|
||||
// printf("rms3 %f %f %f, lm %d %d %d\n", rms, rms2, rms3, lm, lm2, lm3);
|
||||
// lm = lm3;
|
||||
//}
|
||||
if (j < 4) {
|
||||
y[i].scales[j] = ls;
|
||||
y[i].scales[j+4] = lm;
|
||||
|
@ -1932,13 +1942,14 @@ redo_pos4:
|
|||
}
|
||||
}
|
||||
|
||||
if (all_zero_lm && !all_positive) {
|
||||
all_positive = 1;
|
||||
//printf("**********red_pos4\n");
|
||||
goto redo_pos4;
|
||||
} else if (all_zero_lm) {
|
||||
//printf("all_zero_lm, all_pos %d, max_scale %f, max_min %f\n", all_positive, max_scale, max_min);
|
||||
}
|
||||
//if (all_zero_lm && !all_positive && !ql_loop) {
|
||||
// all_positive = 1;
|
||||
// //printf("**********red_pos4\n");
|
||||
// goto redo_pos4;
|
||||
//}
|
||||
//} else if (all_zero_lm) {
|
||||
// //printf("all_zero_lm, all_pos %d, max_scale %f, max_min %f\n", all_positive, max_scale, max_min);
|
||||
//}
|
||||
|
||||
y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);
|
||||
y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);
|
||||
|
@ -1962,12 +1973,26 @@ redo_pos4:
|
|||
}
|
||||
}
|
||||
|
||||
float min;
|
||||
float d;
|
||||
lstsq_q_k(q_fit, x, q_m, 32, &min, &d);
|
||||
y[i].d = GGML_FP32_TO_FP16(d);
|
||||
//printf("%d orig: %f %f, ", ql_loop, max_min, max_scale);
|
||||
float min, scale;
|
||||
lstsq_q_k(q_fit, x, q_m, 32, &min, &scale);
|
||||
if (min != min) {
|
||||
printf("min nan\n");
|
||||
}
|
||||
if (scale != scale) {
|
||||
printf("scale nan\n");
|
||||
}
|
||||
//printf("fit: %f %f\n", max_min, max_scale);
|
||||
y[i].d = GGML_FP32_TO_FP16(scale);
|
||||
y[i].dmin = GGML_FP32_TO_FP16(min);
|
||||
//printf("%f %f, %f %f\n", max_min, min * 63.0f, max_scale, scale * 63.0f);
|
||||
max_scale = GGML_FP16_TO_FP32(y[i].d) * 63.0f;
|
||||
max_min = GGML_FP16_TO_FP32(y[i].dmin) * 63.0f;
|
||||
|
||||
ql_loop++;
|
||||
if (ql_loop == 1) {
|
||||
goto quant_loop;
|
||||
}
|
||||
#else
|
||||
const float s_factor = 15.f;
|
||||
float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue