diff --git a/k_quants.c b/k_quants.c index 29944aa25..82bf81697 100644 --- a/k_quants.c +++ b/k_quants.c @@ -847,8 +847,8 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict uint8_t L[QK_K]; float mins[QK_K/32]; float scales[QK_K/32]; - //float weights[32]; - //uint8_t Laux[32]; + float weights[32]; + uint8_t Laux[32]; #else int8_t L[QK_K]; float scales[QK_K/16]; @@ -861,12 +861,12 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict float max_scale = 0; // as we are deducting the min, scales are always positive float max_min = 0; for (int j = 0; j < QK_K/32; ++j) { - scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f); - //float sum_x2 = 0; - //for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l]; - //float av_x = sqrtf(sum_x2/32); - //for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]); - //scales[j] = make_qkx2_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.5f, 0.1f, 15, false); + //scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f); + float sum_x2 = 0; + for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l]; + float av_x = sqrtf(sum_x2/32); + for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]); + scales[j] = make_qkx2_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.5f, 0.1f, 15, false); float scale = scales[j]; if (scale > max_scale) { max_scale = scale;