diff --git a/k_quants.c b/k_quants.c index b5faefd6c..70c2f2f7b 100644 --- a/k_quants.c +++ b/k_quants.c @@ -269,7 +269,7 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t } static float make_qkx2_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min, - uint8_t * restrict Laux) { + uint8_t * restrict Laux, bool use_mad) { float min = x[0]; float max = x[0]; float sum_x = 0, sum_x2 = 0; @@ -288,13 +288,17 @@ static float make_qkx2_quants(int n, int nmax, const float * restrict x, uint8_t float num = sum_x2 * n - sum_x * sum_x * n / (n-1); float iscale = nmax/(max - min); float scale = 1/iscale; - float best_mse = 0; + float best_mad = 0; for (int i = 0; i < n; ++i) { int l = nearest_int(iscale*(x[i] - min)); L[i] = MAX(0, MIN(nmax, l)); float diff = scale * L[i] + min - x[i]; float w = x[i] * x[i]; - best_mse += w * diff * diff; + if (use_mad) { + best_mad += w * fabsf(diff); + } else { + best_mad += w * diff * diff; + } } if (num <= 0) { *the_min = -min; @@ -318,17 +322,21 @@ static float make_qkx2_quants(int n, int nmax, const float * restrict x, uint8_t this_min = 0; this_scale = sqrtf(sum_x2 / sum_l2); } - float mse = 0; + float mad = 0; for (int i = 0; i < n; ++i) { float diff = this_scale * Laux[i] + this_min - x[i]; float w = x[i] * x[i]; - mse += w * diff * diff; + if (use_mad) { + mad += w * fabsf(diff); + } else { + mad += w * diff * diff; + } } - if (mse < best_mse) { + if (mad < best_mad) { for (int i = 0; i < n; ++i) { L[i] = Laux[i]; } - best_mse = mse; + best_mad = mad; scale = this_scale; min = this_min; } @@ -368,7 +376,7 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict float max_min = 0; for (int j = 0; j < QK_K/16; ++j) { //scales[j] = make_qkx1_quants(16, 3, x + 16*j, L + 16*j, &mins[j], 5, 0.f); - scales[j] = make_qkx2_quants(16, 3, x + 16*j, L + 16*j, &mins[j], Laux); + scales[j] = make_qkx2_quants(16, 3, x + 16*j, L + 16*j, &mins[j], Laux, false); float scale = scales[j]; if (scale > max_scale) { max_scale = scale; @@ -724,7 +732,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict float max_min = 0; for (int j = 0; j < QK_K/32; ++j) { //scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f); - scales[j] = make_qkx2_quants(32, 15, x + 32*j, L + 32*j, &mins[j], Laux); + scales[j] = make_qkx2_quants(32, 15, x + 32*j, L + 32*j, &mins[j], Laux, true); float scale = scales[j]; if (scale > max_scale) { max_scale = scale; @@ -875,7 +883,6 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict #if QK_K == 256 uint8_t L[QK_K]; - //uint8_t Laux[32]; float mins[QK_K/32]; float scales[QK_K/32]; #else @@ -891,7 +898,6 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict float max_min = 0; for (int j = 0; j < QK_K/32; ++j) { scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f); - //scales[j] = make_qkx2_quants(32, 31, x + 32*j, L + 32*j, &mins[j], Laux); float scale = scales[j]; if (scale > max_scale) { max_scale = scale;