Adding make_qkx2_quants
With it, we get PPL = 5.8828 for L2-7B Q4_K_S.
This commit is contained in:
parent
ec9cb753a6
commit
4f8dcb1653
2 changed files with 83 additions and 3 deletions
84
k_quants.c
84
k_quants.c
|
@ -225,9 +225,13 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
|
||||||
int ntry, float alpha) {
|
int ntry, float alpha) {
|
||||||
float min = x[0];
|
float min = x[0];
|
||||||
float max = x[0];
|
float max = x[0];
|
||||||
|
float sum_x = 0;
|
||||||
|
float sum_x2 = 0;
|
||||||
for (int i = 1; i < n; ++i) {
|
for (int i = 1; i < n; ++i) {
|
||||||
if (x[i] < min) min = x[i];
|
if (x[i] < min) min = x[i];
|
||||||
if (x[i] > max) max = x[i];
|
if (x[i] > max) max = x[i];
|
||||||
|
sum_x += x[i];
|
||||||
|
sum_x2 += x[i]*x[i];
|
||||||
}
|
}
|
||||||
if (max == min) {
|
if (max == min) {
|
||||||
for (int i = 0; i < n; ++i) L[i] = 0;
|
for (int i = 0; i < n; ++i) L[i] = 0;
|
||||||
|
@ -264,6 +268,76 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
|
||||||
return scale;
|
return scale;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static float make_qkx2_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min,
|
||||||
|
uint8_t * restrict Laux) {
|
||||||
|
float min = x[0];
|
||||||
|
float max = x[0];
|
||||||
|
float sum_x = 0, sum_x2 = 0;
|
||||||
|
for (int i = 1; i < n; ++i) {
|
||||||
|
if (x[i] < min) min = x[i];
|
||||||
|
if (x[i] > max) max = x[i];
|
||||||
|
sum_x += x[i];
|
||||||
|
sum_x2 += x[i] * x[i];
|
||||||
|
}
|
||||||
|
if (min > 0) min = 0;
|
||||||
|
if (max == min) {
|
||||||
|
for (int i = 0; i < n; ++i) L[i] = 0;
|
||||||
|
*the_min = -min;
|
||||||
|
return 0.f;
|
||||||
|
}
|
||||||
|
float num = sum_x2 * n - sum_x * sum_x * n / (n-1);
|
||||||
|
float iscale = nmax/(max - min);
|
||||||
|
float scale = 1/iscale;
|
||||||
|
float best_mse = 0;
|
||||||
|
for (int i = 0; i < n; ++i) {
|
||||||
|
int l = nearest_int(iscale*(x[i] - min));
|
||||||
|
L[i] = MAX(0, MIN(nmax, l));
|
||||||
|
float diff = scale * L[i] + min - x[i];
|
||||||
|
float w = x[i] * x[i];
|
||||||
|
best_mse += w * diff * diff;
|
||||||
|
}
|
||||||
|
if (num <= 0) {
|
||||||
|
*the_min = -min;
|
||||||
|
return scale;
|
||||||
|
}
|
||||||
|
for (int is = -5; is <= 10; ++is) {
|
||||||
|
iscale = (0.1f*is + nmax)/(max - min);
|
||||||
|
int sum_l = 0, sum_l2 = 0;
|
||||||
|
for (int i = 0; i < n; ++i) {
|
||||||
|
int l = nearest_int(iscale*(x[i] - min));
|
||||||
|
l = MAX(0, MIN(nmax, l));
|
||||||
|
Laux[i] = l;
|
||||||
|
sum_l += l;
|
||||||
|
sum_l2 += l*l;
|
||||||
|
}
|
||||||
|
int den = sum_l2 * n - sum_l * sum_l;
|
||||||
|
if (den > 0) {
|
||||||
|
float this_scale = sqrtf(num / den);
|
||||||
|
float this_min = (sum_x - this_scale * sum_l)/n;
|
||||||
|
if (this_min > 0) {
|
||||||
|
this_min = 0;
|
||||||
|
this_scale = sqrtf(sum_x2 / sum_l2);
|
||||||
|
}
|
||||||
|
float mse = 0;
|
||||||
|
for (int i = 0; i < n; ++i) {
|
||||||
|
float diff = this_scale * Laux[i] + this_min - x[i];
|
||||||
|
float w = x[i] * x[i];
|
||||||
|
mse += w * diff * diff;
|
||||||
|
}
|
||||||
|
if (mse < best_mse) {
|
||||||
|
for (int i = 0; i < n; ++i) {
|
||||||
|
L[i] = Laux[i];
|
||||||
|
}
|
||||||
|
best_mse = mse;
|
||||||
|
scale = this_scale;
|
||||||
|
min = this_min;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*the_min = -min;
|
||||||
|
return scale;
|
||||||
|
}
|
||||||
|
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
|
static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
|
||||||
if (j < 4) {
|
if (j < 4) {
|
||||||
|
@ -282,6 +356,7 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
|
||||||
const int nb = k / QK_K;
|
const int nb = k / QK_K;
|
||||||
|
|
||||||
uint8_t L[QK_K];
|
uint8_t L[QK_K];
|
||||||
|
uint8_t Laux[16];
|
||||||
float mins[QK_K/16];
|
float mins[QK_K/16];
|
||||||
float scales[QK_K/16];
|
float scales[QK_K/16];
|
||||||
|
|
||||||
|
@ -292,7 +367,8 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
|
||||||
float max_scale = 0; // as we are deducting the min, scales are always positive
|
float max_scale = 0; // as we are deducting the min, scales are always positive
|
||||||
float max_min = 0;
|
float max_min = 0;
|
||||||
for (int j = 0; j < QK_K/16; ++j) {
|
for (int j = 0; j < QK_K/16; ++j) {
|
||||||
scales[j] = make_qkx1_quants(16, 3, x + 16*j, L + 16*j, &mins[j], 5, 0.f);
|
//scales[j] = make_qkx1_quants(16, 3, x + 16*j, L + 16*j, &mins[j], 5, 0.f);
|
||||||
|
scales[j] = make_qkx2_quants(16, 3, x + 16*j, L + 16*j, &mins[j], Laux);
|
||||||
float scale = scales[j];
|
float scale = scales[j];
|
||||||
if (scale > max_scale) {
|
if (scale > max_scale) {
|
||||||
max_scale = scale;
|
max_scale = scale;
|
||||||
|
@ -638,6 +714,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
|
||||||
const int nb = k / QK_K;
|
const int nb = k / QK_K;
|
||||||
|
|
||||||
uint8_t L[QK_K];
|
uint8_t L[QK_K];
|
||||||
|
uint8_t Laux[32];
|
||||||
float mins[QK_K/32];
|
float mins[QK_K/32];
|
||||||
float scales[QK_K/32];
|
float scales[QK_K/32];
|
||||||
|
|
||||||
|
@ -646,7 +723,8 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
|
||||||
float max_scale = 0; // as we are deducting the min, scales are always positive
|
float max_scale = 0; // as we are deducting the min, scales are always positive
|
||||||
float max_min = 0;
|
float max_min = 0;
|
||||||
for (int j = 0; j < QK_K/32; ++j) {
|
for (int j = 0; j < QK_K/32; ++j) {
|
||||||
scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
|
//scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
|
||||||
|
scales[j] = make_qkx2_quants(32, 15, x + 32*j, L + 32*j, &mins[j], Laux);
|
||||||
float scale = scales[j];
|
float scale = scales[j];
|
||||||
if (scale > max_scale) {
|
if (scale > max_scale) {
|
||||||
max_scale = scale;
|
max_scale = scale;
|
||||||
|
@ -797,6 +875,7 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
|
||||||
|
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
uint8_t L[QK_K];
|
uint8_t L[QK_K];
|
||||||
|
//uint8_t Laux[32];
|
||||||
float mins[QK_K/32];
|
float mins[QK_K/32];
|
||||||
float scales[QK_K/32];
|
float scales[QK_K/32];
|
||||||
#else
|
#else
|
||||||
|
@ -812,6 +891,7 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
|
||||||
float max_min = 0;
|
float max_min = 0;
|
||||||
for (int j = 0; j < QK_K/32; ++j) {
|
for (int j = 0; j < QK_K/32; ++j) {
|
||||||
scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
|
scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
|
||||||
|
//scales[j] = make_qkx2_quants(32, 31, x + 32*j, L + 32*j, &mins[j], Laux);
|
||||||
float scale = scales[j];
|
float scale = scales[j];
|
||||||
if (scale > max_scale) {
|
if (scale > max_scale) {
|
||||||
max_scale = scale;
|
max_scale = scale;
|
||||||
|
|
|
@ -3740,7 +3740,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
||||||
else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
|
else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
|
||||||
use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K;
|
use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K;
|
||||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_feed_forward_w2 < 4) new_type = GGML_TYPE_Q5_K;
|
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_feed_forward_w2 < 4) new_type = GGML_TYPE_Q5_K;
|
||||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S && i_feed_forward_w2 < 4) new_type = GGML_TYPE_Q5_K;
|
//else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S && i_feed_forward_w2 < 4) new_type = GGML_TYPE_Q5_K;
|
||||||
++i_feed_forward_w2;
|
++i_feed_forward_w2;
|
||||||
} else if (name.find("attn_output.weight") != std::string::npos) {
|
} else if (name.find("attn_output.weight") != std::string::npos) {
|
||||||
if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;
|
if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue