No second least squares pass

This commit is contained in:
Henrik Forstén 2023-12-30 18:31:46 +02:00
parent 8386034e08
commit 5a02328d1f

View file

@ -5,6 +5,7 @@
#include <string.h>
#include <assert.h>
#include <float.h>
#include <stdio.h>
#ifdef __ARM_NEON
@ -457,20 +458,6 @@ static void lstsq_q_1(const uint8_t * restrict q, const float * restrict x, int
}
}
static float lstsq_q_0(const float * restrict q, const float * restrict x, int qk) {
// Least squares fits `d * q = x` for d.
float qs2 = 0.0f;
float xq = 0.0f;
for (int i = 0; i < qk; i++) {
qs2 += q[i]*q[i];
xq += x[i]*q[i];
}
if (qs2 == 0.0f) {
return 0.0f;
}
return xq / qs2;
}
static float lstsq_q_0_u8(const uint8_t * restrict q, const float * restrict x, int qk) {
// Least squares fits `d * q = x` for d.
float qs2 = 0.0f;
@ -1445,18 +1432,20 @@ static void quantize_q_k_1(const float * x, int bits, int scale_bits, int block_
}
}
// Least squares fit min and scale.
float min, scale;
lstsq_q_k(q_fit, x, q_m, block_size, &min, &scale);
// Check for nans.
assert(min == min);
assert(scale == scale);
// Quantize to fp16 for the next pass.
max_scale = GGML_FP16_TO_FP32(GGML_FP32_TO_FP16(scale)) * max_group;
max_min = GGML_FP16_TO_FP32(GGML_FP32_TO_FP16(min)) * max_group;
if (pass == 0) {
// Least squares fit min and scale.
float min, scale;
lstsq_q_k(q_fit, x, q_m, block_size, &min, &scale);
// Check for nans.
assert(min == min);
assert(scale == scale);
// Quantize to fp16 for the next pass.
max_scale = GGML_FP16_TO_FP32(GGML_FP32_TO_FP16(scale)) * max_group;
max_min = GGML_FP16_TO_FP32(GGML_FP32_TO_FP16(min)) * max_group;
*block_scale = GGML_FP32_TO_FP16(scale);
*block_min = GGML_FP32_TO_FP16(min);
*block_scale = GGML_FP32_TO_FP16(scale);
*block_min = GGML_FP32_TO_FP16(min);
}
}
}