Adding make_qkx2_quants

With it, we get PPL = 5.8828 for L2-7B Q4_K_S.
This commit is contained in:
Iwan Kawrakow 2023-08-14 16:06:00 +03:00
parent ec9cb753a6
commit 4f8dcb1653
2 changed files with 83 additions and 3 deletions

View file

@ -225,9 +225,13 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
int ntry, float alpha) { int ntry, float alpha) {
float min = x[0]; float min = x[0];
float max = x[0]; float max = x[0];
float sum_x = 0;
float sum_x2 = 0;
for (int i = 1; i < n; ++i) { for (int i = 1; i < n; ++i) {
if (x[i] < min) min = x[i]; if (x[i] < min) min = x[i];
if (x[i] > max) max = x[i]; if (x[i] > max) max = x[i];
sum_x += x[i];
sum_x2 += x[i]*x[i];
} }
if (max == min) { if (max == min) {
for (int i = 0; i < n; ++i) L[i] = 0; for (int i = 0; i < n; ++i) L[i] = 0;
@ -264,6 +268,76 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
return scale; return scale;
} }
static float make_qkx2_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min,
uint8_t * restrict Laux) {
float min = x[0];
float max = x[0];
float sum_x = 0, sum_x2 = 0;
for (int i = 1; i < n; ++i) {
if (x[i] < min) min = x[i];
if (x[i] > max) max = x[i];
sum_x += x[i];
sum_x2 += x[i] * x[i];
}
if (min > 0) min = 0;
if (max == min) {
for (int i = 0; i < n; ++i) L[i] = 0;
*the_min = -min;
return 0.f;
}
float num = sum_x2 * n - sum_x * sum_x * n / (n-1);
float iscale = nmax/(max - min);
float scale = 1/iscale;
float best_mse = 0;
for (int i = 0; i < n; ++i) {
int l = nearest_int(iscale*(x[i] - min));
L[i] = MAX(0, MIN(nmax, l));
float diff = scale * L[i] + min - x[i];
float w = x[i] * x[i];
best_mse += w * diff * diff;
}
if (num <= 0) {
*the_min = -min;
return scale;
}
for (int is = -5; is <= 10; ++is) {
iscale = (0.1f*is + nmax)/(max - min);
int sum_l = 0, sum_l2 = 0;
for (int i = 0; i < n; ++i) {
int l = nearest_int(iscale*(x[i] - min));
l = MAX(0, MIN(nmax, l));
Laux[i] = l;
sum_l += l;
sum_l2 += l*l;
}
int den = sum_l2 * n - sum_l * sum_l;
if (den > 0) {
float this_scale = sqrtf(num / den);
float this_min = (sum_x - this_scale * sum_l)/n;
if (this_min > 0) {
this_min = 0;
this_scale = sqrtf(sum_x2 / sum_l2);
}
float mse = 0;
for (int i = 0; i < n; ++i) {
float diff = this_scale * Laux[i] + this_min - x[i];
float w = x[i] * x[i];
mse += w * diff * diff;
}
if (mse < best_mse) {
for (int i = 0; i < n; ++i) {
L[i] = Laux[i];
}
best_mse = mse;
scale = this_scale;
min = this_min;
}
}
}
*the_min = -min;
return scale;
}
#if QK_K == 256 #if QK_K == 256
static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) { static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
if (j < 4) { if (j < 4) {
@ -282,6 +356,7 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
const int nb = k / QK_K; const int nb = k / QK_K;
uint8_t L[QK_K]; uint8_t L[QK_K];
uint8_t Laux[16];
float mins[QK_K/16]; float mins[QK_K/16];
float scales[QK_K/16]; float scales[QK_K/16];
@ -292,7 +367,8 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
float max_scale = 0; // as we are deducting the min, scales are always positive float max_scale = 0; // as we are deducting the min, scales are always positive
float max_min = 0; float max_min = 0;
for (int j = 0; j < QK_K/16; ++j) { for (int j = 0; j < QK_K/16; ++j) {
scales[j] = make_qkx1_quants(16, 3, x + 16*j, L + 16*j, &mins[j], 5, 0.f); //scales[j] = make_qkx1_quants(16, 3, x + 16*j, L + 16*j, &mins[j], 5, 0.f);
scales[j] = make_qkx2_quants(16, 3, x + 16*j, L + 16*j, &mins[j], Laux);
float scale = scales[j]; float scale = scales[j];
if (scale > max_scale) { if (scale > max_scale) {
max_scale = scale; max_scale = scale;
@ -638,6 +714,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
const int nb = k / QK_K; const int nb = k / QK_K;
uint8_t L[QK_K]; uint8_t L[QK_K];
uint8_t Laux[32];
float mins[QK_K/32]; float mins[QK_K/32];
float scales[QK_K/32]; float scales[QK_K/32];
@ -646,7 +723,8 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
float max_scale = 0; // as we are deducting the min, scales are always positive float max_scale = 0; // as we are deducting the min, scales are always positive
float max_min = 0; float max_min = 0;
for (int j = 0; j < QK_K/32; ++j) { for (int j = 0; j < QK_K/32; ++j) {
scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f); //scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
scales[j] = make_qkx2_quants(32, 15, x + 32*j, L + 32*j, &mins[j], Laux);
float scale = scales[j]; float scale = scales[j];
if (scale > max_scale) { if (scale > max_scale) {
max_scale = scale; max_scale = scale;
@ -797,6 +875,7 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
#if QK_K == 256 #if QK_K == 256
uint8_t L[QK_K]; uint8_t L[QK_K];
//uint8_t Laux[32];
float mins[QK_K/32]; float mins[QK_K/32];
float scales[QK_K/32]; float scales[QK_K/32];
#else #else
@ -812,6 +891,7 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
float max_min = 0; float max_min = 0;
for (int j = 0; j < QK_K/32; ++j) { for (int j = 0; j < QK_K/32; ++j) {
scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f); scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
//scales[j] = make_qkx2_quants(32, 31, x + 32*j, L + 32*j, &mins[j], Laux);
float scale = scales[j]; float scale = scales[j];
if (scale > max_scale) { if (scale > max_scale) {
max_scale = scale; max_scale = scale;

View file

@ -3740,7 +3740,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) && else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K; use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K;
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_feed_forward_w2 < 4) new_type = GGML_TYPE_Q5_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_feed_forward_w2 < 4) new_type = GGML_TYPE_Q5_K;
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S && i_feed_forward_w2 < 4) new_type = GGML_TYPE_Q5_K; //else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S && i_feed_forward_w2 < 4) new_type = GGML_TYPE_Q5_K;
++i_feed_forward_w2; ++i_feed_forward_w2;
} else if (name.find("attn_output.weight") != std::string::npos) { } else if (name.find("attn_output.weight") != std::string::npos) {
if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K; if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;