Another minor improvement
This commit is contained in:
parent
4f8dcb1653
commit
e9f1340c20
1 changed files with 17 additions and 11 deletions
28
k_quants.c
28
k_quants.c
|
@ -269,7 +269,7 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
|
||||||
}
|
}
|
||||||
|
|
||||||
static float make_qkx2_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min,
|
static float make_qkx2_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min,
|
||||||
uint8_t * restrict Laux) {
|
uint8_t * restrict Laux, bool use_mad) {
|
||||||
float min = x[0];
|
float min = x[0];
|
||||||
float max = x[0];
|
float max = x[0];
|
||||||
float sum_x = 0, sum_x2 = 0;
|
float sum_x = 0, sum_x2 = 0;
|
||||||
|
@ -288,13 +288,17 @@ static float make_qkx2_quants(int n, int nmax, const float * restrict x, uint8_t
|
||||||
float num = sum_x2 * n - sum_x * sum_x * n / (n-1);
|
float num = sum_x2 * n - sum_x * sum_x * n / (n-1);
|
||||||
float iscale = nmax/(max - min);
|
float iscale = nmax/(max - min);
|
||||||
float scale = 1/iscale;
|
float scale = 1/iscale;
|
||||||
float best_mse = 0;
|
float best_mad = 0;
|
||||||
for (int i = 0; i < n; ++i) {
|
for (int i = 0; i < n; ++i) {
|
||||||
int l = nearest_int(iscale*(x[i] - min));
|
int l = nearest_int(iscale*(x[i] - min));
|
||||||
L[i] = MAX(0, MIN(nmax, l));
|
L[i] = MAX(0, MIN(nmax, l));
|
||||||
float diff = scale * L[i] + min - x[i];
|
float diff = scale * L[i] + min - x[i];
|
||||||
float w = x[i] * x[i];
|
float w = x[i] * x[i];
|
||||||
best_mse += w * diff * diff;
|
if (use_mad) {
|
||||||
|
best_mad += w * fabsf(diff);
|
||||||
|
} else {
|
||||||
|
best_mad += w * diff * diff;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (num <= 0) {
|
if (num <= 0) {
|
||||||
*the_min = -min;
|
*the_min = -min;
|
||||||
|
@ -318,17 +322,21 @@ static float make_qkx2_quants(int n, int nmax, const float * restrict x, uint8_t
|
||||||
this_min = 0;
|
this_min = 0;
|
||||||
this_scale = sqrtf(sum_x2 / sum_l2);
|
this_scale = sqrtf(sum_x2 / sum_l2);
|
||||||
}
|
}
|
||||||
float mse = 0;
|
float mad = 0;
|
||||||
for (int i = 0; i < n; ++i) {
|
for (int i = 0; i < n; ++i) {
|
||||||
float diff = this_scale * Laux[i] + this_min - x[i];
|
float diff = this_scale * Laux[i] + this_min - x[i];
|
||||||
float w = x[i] * x[i];
|
float w = x[i] * x[i];
|
||||||
mse += w * diff * diff;
|
if (use_mad) {
|
||||||
|
mad += w * fabsf(diff);
|
||||||
|
} else {
|
||||||
|
mad += w * diff * diff;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (mse < best_mse) {
|
if (mad < best_mad) {
|
||||||
for (int i = 0; i < n; ++i) {
|
for (int i = 0; i < n; ++i) {
|
||||||
L[i] = Laux[i];
|
L[i] = Laux[i];
|
||||||
}
|
}
|
||||||
best_mse = mse;
|
best_mad = mad;
|
||||||
scale = this_scale;
|
scale = this_scale;
|
||||||
min = this_min;
|
min = this_min;
|
||||||
}
|
}
|
||||||
|
@ -368,7 +376,7 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
|
||||||
float max_min = 0;
|
float max_min = 0;
|
||||||
for (int j = 0; j < QK_K/16; ++j) {
|
for (int j = 0; j < QK_K/16; ++j) {
|
||||||
//scales[j] = make_qkx1_quants(16, 3, x + 16*j, L + 16*j, &mins[j], 5, 0.f);
|
//scales[j] = make_qkx1_quants(16, 3, x + 16*j, L + 16*j, &mins[j], 5, 0.f);
|
||||||
scales[j] = make_qkx2_quants(16, 3, x + 16*j, L + 16*j, &mins[j], Laux);
|
scales[j] = make_qkx2_quants(16, 3, x + 16*j, L + 16*j, &mins[j], Laux, false);
|
||||||
float scale = scales[j];
|
float scale = scales[j];
|
||||||
if (scale > max_scale) {
|
if (scale > max_scale) {
|
||||||
max_scale = scale;
|
max_scale = scale;
|
||||||
|
@ -724,7 +732,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
|
||||||
float max_min = 0;
|
float max_min = 0;
|
||||||
for (int j = 0; j < QK_K/32; ++j) {
|
for (int j = 0; j < QK_K/32; ++j) {
|
||||||
//scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
|
//scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
|
||||||
scales[j] = make_qkx2_quants(32, 15, x + 32*j, L + 32*j, &mins[j], Laux);
|
scales[j] = make_qkx2_quants(32, 15, x + 32*j, L + 32*j, &mins[j], Laux, true);
|
||||||
float scale = scales[j];
|
float scale = scales[j];
|
||||||
if (scale > max_scale) {
|
if (scale > max_scale) {
|
||||||
max_scale = scale;
|
max_scale = scale;
|
||||||
|
@ -875,7 +883,6 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
|
||||||
|
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
uint8_t L[QK_K];
|
uint8_t L[QK_K];
|
||||||
//uint8_t Laux[32];
|
|
||||||
float mins[QK_K/32];
|
float mins[QK_K/32];
|
||||||
float scales[QK_K/32];
|
float scales[QK_K/32];
|
||||||
#else
|
#else
|
||||||
|
@ -891,7 +898,6 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
|
||||||
float max_min = 0;
|
float max_min = 0;
|
||||||
for (int j = 0; j < QK_K/32; ++j) {
|
for (int j = 0; j < QK_K/32; ++j) {
|
||||||
scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
|
scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
|
||||||
//scales[j] = make_qkx2_quants(32, 31, x + 32*j, L + 32*j, &mins[j], Laux);
|
|
||||||
float scale = scales[j];
|
float scale = scales[j];
|
||||||
if (scale > max_scale) {
|
if (scale > max_scale) {
|
||||||
max_scale = scale;
|
max_scale = scale;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue