This commit is contained in:
Henrik Forstén 2024-01-01 13:59:42 +02:00
parent 5a02328d1f
commit 4b06507172

View file

@ -1325,16 +1325,16 @@ static void quantize_q_k_1(const float * x, int bits, int scale_bits, int block_
// If all the weight are positive we can invert the sign of min. // If all the weight are positive we can invert the sign of min.
// Otherwise blocks with all positive weights need to be quantized with zero // Otherwise blocks with all positive weights need to be quantized with zero
// min, because min scale is unsigned. // min, because min scale is unsigned.
int all_positive = 1; bool all_positive = true;
for (int j = 0; j < QK_K; j++) { for (int j = 0; j < QK_K; j++) {
if (x[j] < 0.0f) { if (x[j] < 0.0f) {
all_positive = 0; all_positive = false;
break; break;
} }
} }
float scales[QK_K]; float scales[QK_K/block_size];
float mins[QK_K]; float mins[QK_K/block_size];
for (int j = 0; j < QK_K/block_size; j++) { for (int j = 0; j < QK_K/block_size; j++) {
uint8_t q[QK_K/block_size]; uint8_t q[QK_K/block_size];
@ -1343,7 +1343,7 @@ static void quantize_q_k_1(const float * x, int bits, int scale_bits, int block_
// Flip the sign because quantize_1 assumes that min is added, but min // Flip the sign because quantize_1 assumes that min is added, but min
// is subtracted in k-quants. // is subtracted in k-quants.
mins[j] = -mins[j]; mins[j] = -mins[j];
if (!all_positive && mins[j] < 0) { if ((!all_positive && mins[j] < 0) || (all_positive && mins[j] > 0)) {
// All weights are positive in this block, but some blocks have // All weights are positive in this block, but some blocks have
// negative weights. Find new least squares scale with zero min. // negative weights. Find new least squares scale with zero min.
mins[j] = 0.0f; mins[j] = 0.0f;
@ -1366,18 +1366,18 @@ static void quantize_q_k_1(const float * x, int bits, int scale_bits, int block_
// Increasing passes would decrease RMS error by miniscule amount with // Increasing passes would decrease RMS error by miniscule amount with
// drawback of taking more time. // drawback of taking more time.
for(int pass = 0; pass < 2; pass++) { for(int pass = 0; pass < 2; pass++) {
float inv_scale = max_scale == 0.0f ? 0.0f : max_group/max_scale; float inv_scale = max_scale == 0.0f ? 0.0f : max_group/max_scale;
float inv_min = max_min == 0.0f ? 0.0f : max_group/max_min; float inv_min = max_min == 0.0f ? 0.0f : max_group/max_min;
float block_d = max_scale/max_group;
float block_dmin = max_min/max_group;
for (int j = 0; j < QK_K/block_size; ++j) { for (int j = 0; j < QK_K/block_size; ++j) {
uint8_t ls = nearest_int(inv_scale*scales[j]); uint8_t ls = MAX(0, nearest_int(inv_scale*scales[j]));
uint8_t lm = nearest_int(inv_min*mins[j]); uint8_t lm = MAX(0, nearest_int(inv_min*mins[j]));
uint8_t best_lm = lm;
uint8_t best_ls = ls;
ls = MIN(max_group, ls); ls = MIN(max_group, ls);
lm = MIN(max_group, lm); lm = MIN(max_group, lm);
uint8_t best_lm = lm;
uint8_t best_ls = ls;
float best_rms = FLT_MAX; float best_rms = FLT_MAX;
const float d1 = max_scale / max_group;
const float dmin1 = max_min / max_group;
int limit = 1; int limit = 1;
// Increase limit for minor RMS error decrease while increasing the // Increase limit for minor RMS error decrease while increasing the
// quantization run time. // quantization run time.
@ -1390,14 +1390,14 @@ static void quantize_q_k_1(const float * x, int bits, int scale_bits, int block_
for (int lmt = MAX(0, lm-limit); lmt <= MIN(max_group, lm+limit); lmt++) { for (int lmt = MAX(0, lm-limit); lmt <= MIN(max_group, lm+limit); lmt++) {
float rms = 0.0f; float rms = 0.0f;
for (int ii = 0; ii < block_size; ii++) { for (int ii = 0; ii < block_size; ii++) {
const float d = d1 * lst; const float d = block_d * lst;
const float dm1 = dmin1 * lmt; const float dm = block_dmin * lmt;
int l1 = 0; int l1 = 0;
if (d) { if (d) {
l1 = nearest_int((x[block_size*j + ii] + dm1)/d); l1 = nearest_int((x[block_size*j + ii] + dm)/d);
l1 = MAX(0, MIN((1 << bits) - 1, l1)); l1 = MAX(0, MIN((1 << bits) - 1, l1));
} }
const float e = (d*l1 - dm1) - x[block_size*j + ii]; const float e = (d*l1 - dm) - x[block_size*j + ii];
rms += e*e; rms += e*e;
} }
if (rms < best_rms) { if (rms < best_rms) {
@ -1411,8 +1411,6 @@ static void quantize_q_k_1(const float * x, int bits, int scale_bits, int block_
block_mins[j] = best_lm; block_mins[j] = best_lm;
} }
float block_d = max_scale/max_group;
float block_dmin = max_min/max_group;
float q_fit[QK_K]; float q_fit[QK_K];
float q_m[QK_K/block_size]; float q_m[QK_K/block_size];