From 5a02328d1f872e2c73bd5e5cbb69a9dc792dd547 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20Forst=C3=A9n?= Date: Sat, 30 Dec 2023 18:31:46 +0200 Subject: [PATCH] No second least squares pass --- ggml-quants.c | 39 ++++++++++++++------------------------- 1 file changed, 14 insertions(+), 25 deletions(-) diff --git a/ggml-quants.c b/ggml-quants.c index eb7da29ad..c2628badf 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -5,6 +5,7 @@ #include #include #include +#include #ifdef __ARM_NEON @@ -457,20 +458,6 @@ static void lstsq_q_1(const uint8_t * restrict q, const float * restrict x, int } } -static float lstsq_q_0(const float * restrict q, const float * restrict x, int qk) { - // Least squares fits `d * q = x` for d. - float qs2 = 0.0f; - float xq = 0.0f; - for (int i = 0; i < qk; i++) { - qs2 += q[i]*q[i]; - xq += x[i]*q[i]; - } - if (qs2 == 0.0f) { - return 0.0f; - } - return xq / qs2; -} - static float lstsq_q_0_u8(const uint8_t * restrict q, const float * restrict x, int qk) { // Least squares fits `d * q = x` for d. float qs2 = 0.0f; @@ -1445,18 +1432,20 @@ static void quantize_q_k_1(const float * x, int bits, int scale_bits, int block_ } } - // Least squares fit min and scale. - float min, scale; - lstsq_q_k(q_fit, x, q_m, block_size, &min, &scale); - // Check for nans. - assert(min == min); - assert(scale == scale); - // Quantize to fp16 for the next pass. - max_scale = GGML_FP16_TO_FP32(GGML_FP32_TO_FP16(scale)) * max_group; - max_min = GGML_FP16_TO_FP32(GGML_FP32_TO_FP16(min)) * max_group; + if (pass == 0) { + // Least squares fit min and scale. + float min, scale; + lstsq_q_k(q_fit, x, q_m, block_size, &min, &scale); + // Check for nans. + assert(min == min); + assert(scale == scale); + // Quantize to fp16 for the next pass. + max_scale = GGML_FP16_TO_FP32(GGML_FP32_TO_FP16(scale)) * max_group; + max_min = GGML_FP16_TO_FP32(GGML_FP32_TO_FP16(min)) * max_group; - *block_scale = GGML_FP32_TO_FP16(scale); - *block_min = GGML_FP32_TO_FP16(min); + *block_scale = GGML_FP32_TO_FP16(scale); + *block_min = GGML_FP32_TO_FP16(min); + } } }