iq2_xxs: tuning quantization

This commit is contained in:
Iwan Kawrakow 2024-01-27 12:32:55 +02:00
parent bf9349c610
commit 90faca24fb
2 changed files with 47 additions and 42 deletions

View file

@ -9645,8 +9645,8 @@ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict
} }
float best = 0; float best = 0;
float scale = max/(2*kMaxQ-1); float scale = max/(2*kMaxQ-1);
for (int is = -9; is <= 9; ++is) { for (int is = -15; is <= 15; ++is) {
float id = (2*kMaxQ-1+is*0.1f)/max; float id = (2*kMaxQ-1+is*0.2f)/max;
float this_scale = 1/id; float this_scale = 1/id;
for (int k = 0; k < 8; ++k) { for (int k = 0; k < 8; ++k) {
for (int i = 0; i < 4; ++i) { for (int i = 0; i < 4; ++i) {
@ -9737,49 +9737,51 @@ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict
float d = max_scale/31; float d = max_scale/31;
y[ibl].d = GGML_FP32_TO_FP16(d); y[ibl].d = GGML_FP32_TO_FP16(d);
float id = 1/d; float id = 1/d;
//float sumqx = 0, sumq2 = 0; float sumqx = 0, sumq2 = 0;
for (int ib = 0; ib < QK_K/32; ++ib) { for (int ib = 0; ib < QK_K/32; ++ib) {
int l = nearest_int(0.5f*(id*scales[ib]-1)); int l = nearest_int(0.5f*(id*scales[ib]-1));
l = MAX(0, MIN(15, l)); l = MAX(0, MIN(15, l));
scales_and_signs[ib] |= ((uint32_t)l << 28); scales_and_signs[ib] |= ((uint32_t)l << 28);
//const float * xb = xbl + 32*ib; if (false) {
//const float * qw = quant_weights + QK_K*ibl + 32*ib; const float * xb = xbl + 32*ib;
//for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); if (quant_weights) {
//const uint8_t * aux8 = (const uint8_t *)(q2 + 2*ib); const float * qw = quant_weights + QK_K*ibl + 32*ib;
//const float db = d * (1 + 2*l); for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
//uint32_t u = 0; } else {
//for (int k = 0; k < 4; ++k) { for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i];
// const int8_t * signs = keven_signs_q2xs + 8*((q2[2*ib+1] >> 7*k) & 127); }
// const float * xk = xb + 8*k; const float db = d * (1 + 2*l);
// const float * wk = weight + 8*k; for (int k = 0; k < 8; ++k) {
// const uint8_t * grid = (const uint8_t *)(kgrid_q2xs + aux8[k]); const int8_t * signs = keven_signs_q2xs + 8*((scales_and_signs[ib] >> 7*(k/2)) & 127) + 4*(k%2);
// float best_mse = 0; int best_index = aux8[k]; const float * xk = xb + 4*k;
// for (int j = 0; j < 8; ++j) { const float * wk = weight + 4*k;
// float diff = db * grid[j] * signs[j] - xk[j]; const uint8_t * grid = (const uint8_t *)(kgrid_q3xs + q3[8*ib+k]);
// best_mse += wk[j] * diff * diff; float best_mse = 0; int best_index = q3[8*ib+k];
// } for (int j = 0; j < 4; ++j) {
// for (int idx = 0; idx < 256; ++idx) { float diff = db * grid[j] * signs[j] - xk[j];
// grid = (const uint8_t *)(kgrid_q2xs + idx); best_mse += wk[j] * diff * diff;
// float mse = 0; }
// for (int j = 0; j < 8; ++j) { for (int idx = 0; idx < 256; ++idx) {
// float diff = db * grid[j] * signs[j] - xk[j]; grid = (const uint8_t *)(kgrid_q3xs + idx);
// mse += wk[j] * diff * diff; float mse = 0;
// } for (int j = 0; j < 4; ++j) {
// if (mse < best_mse) { float diff = db * grid[j] * signs[j] - xk[j];
// best_mse = mse; best_index = idx; mse += wk[j] * diff * diff;
// } }
// } if (mse < best_mse) {
// u |= (best_index << 8*k); best_mse = mse; best_index = idx;
// grid = (const uint8_t *)(kgrid_q2xs + best_index); }
// //grid = (const uint8_t *)(kgrid_q2xs + aux8[k]); }
// for (int j = 0; j < 8; ++j) { q3[8*ib+k] = best_index;
// float q = db * grid[j] * signs[j]; grid = (const uint8_t *)(kgrid_q3xs + best_index);
// sumqx += wk[j] * q * xk[j]; for (int j = 0; j < 4; ++j) {
// sumq2 += wk[j] * q * q; float q = db * grid[j] * signs[j];
// } sumqx += wk[j] * q * xk[j];
//} sumq2 += wk[j] * q * q;
//q2[2*ib] = u; }
//if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2); }
if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2);
}
} }
memcpy(y[ibl].qs, q3, 3*QK_K/8); memcpy(y[ibl].qs, q3, 3*QK_K/8);
} }

View file

@ -8992,6 +8992,9 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && qs.model.hparams.n_gqa() >= 4) { else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && qs.model.hparams.n_gqa() >= 4) {
new_type = GGML_TYPE_Q4_K; new_type = GGML_TYPE_Q4_K;
} }
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS && qs.model.hparams.n_gqa() >= 4) {
new_type = GGML_TYPE_Q4_K;
}
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) { else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K; new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
} }
@ -9026,7 +9029,7 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
auto info = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str()); auto info = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str());
int i_layer = info.first, n_layer = info.second; int i_layer = info.first, n_layer = info.second;
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS) { else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS) {// || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
if (i_layer < n_layer/8) new_type = GGML_TYPE_Q4_K; if (i_layer < n_layer/8) new_type = GGML_TYPE_Q4_K;
} }
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) { else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {