ggml : fix Q8_0 to use 255 values out of 256

This commit is contained in:
Georgi Gerganov 2023-04-25 23:23:05 +03:00
parent 91bfa51dca
commit 4ddb983a02
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 6 additions and 10 deletions

10
ggml.c
View file

@ -1295,17 +1295,13 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max
float max = 0.0f;
for (int l = 0; l < QK8_0; l++) {
const float v = x[i*QK8_0 + l];
if (amax < fabsf(v)) {
amax = fabsf(v);
max = v;
}
amax = MAX(amax, fabsf(v));
}
const float d = max / -128;
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
y[i].d = d;
@ -1313,7 +1309,7 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
for (int l = 0; l < QK8_0; ++l) {
const float v0 = x[i*QK8_0 + l]*id;
y[i].qs[l] = MIN(127, roundf(v0));
y[i].qs[l] = roundf(v0);
}
}
}

View file

@ -72,7 +72,7 @@ float dot_product_error(quantize_fns_t & qfns, size_t test_size, const float * t
std::vector<uint8_t> tmp_q1(2*test_size);
std::vector<uint8_t> tmp_q2(2*test_size);
qfns.quantize_row_q(test_data1, tmp_q1.data(), test_size);
qfns.quantize_row_q (test_data1, tmp_q1.data(), test_size);
qfns.quantize_row_q_dot(test_data2, tmp_q2.data(), test_size);
float result = INFINITY;
@ -125,7 +125,7 @@ int main(int argc, char * argv[]) {
failed = !(total_error < MAX_QUANTIZATION_TOTAL_ERROR);
num_failed += failed;
if (failed || verbose) {
printf("%5s absolute quantization error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], total_error);
printf("%5s absolute quantization error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], total_error);
}
const float reference_error = reference_quantization_error(qfns, test_size, test_data.data());
@ -139,7 +139,7 @@ int main(int argc, char * argv[]) {
failed = !(vec_dot_error < MAX_DOT_PRODUCT_ERROR);
num_failed += failed;
if (failed || verbose) {
printf("%5s dot product error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], vec_dot_error);
printf("%5s dot product error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], vec_dot_error);
}
}
}