iq2_xxs: tuning quantization

This commit is contained in:
Iwan Kawrakow 2024-01-27 12:32:55 +02:00
parent bf9349c610
commit 90faca24fb
2 changed files with 47 additions and 42 deletions

View file

@ -9645,8 +9645,8 @@ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict
}
float best = 0;
float scale = max/(2*kMaxQ-1);
for (int is = -9; is <= 9; ++is) {
float id = (2*kMaxQ-1+is*0.1f)/max;
for (int is = -15; is <= 15; ++is) {
float id = (2*kMaxQ-1+is*0.2f)/max;
float this_scale = 1/id;
for (int k = 0; k < 8; ++k) {
for (int i = 0; i < 4; ++i) {
@ -9737,49 +9737,51 @@ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict
float d = max_scale/31;
y[ibl].d = GGML_FP32_TO_FP16(d);
float id = 1/d;
//float sumqx = 0, sumq2 = 0;
float sumqx = 0, sumq2 = 0;
for (int ib = 0; ib < QK_K/32; ++ib) {
int l = nearest_int(0.5f*(id*scales[ib]-1));
l = MAX(0, MIN(15, l));
scales_and_signs[ib] |= ((uint32_t)l << 28);
//const float * xb = xbl + 32*ib;
//const float * qw = quant_weights + QK_K*ibl + 32*ib;
//for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
//const uint8_t * aux8 = (const uint8_t *)(q2 + 2*ib);
//const float db = d * (1 + 2*l);
//uint32_t u = 0;
//for (int k = 0; k < 4; ++k) {
// const int8_t * signs = keven_signs_q2xs + 8*((q2[2*ib+1] >> 7*k) & 127);
// const float * xk = xb + 8*k;
// const float * wk = weight + 8*k;
// const uint8_t * grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
// float best_mse = 0; int best_index = aux8[k];
// for (int j = 0; j < 8; ++j) {
// float diff = db * grid[j] * signs[j] - xk[j];
// best_mse += wk[j] * diff * diff;
// }
// for (int idx = 0; idx < 256; ++idx) {
// grid = (const uint8_t *)(kgrid_q2xs + idx);
// float mse = 0;
// for (int j = 0; j < 8; ++j) {
// float diff = db * grid[j] * signs[j] - xk[j];
// mse += wk[j] * diff * diff;
// }
// if (mse < best_mse) {
// best_mse = mse; best_index = idx;
// }
// }
// u |= (best_index << 8*k);
// grid = (const uint8_t *)(kgrid_q2xs + best_index);
// //grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
// for (int j = 0; j < 8; ++j) {
// float q = db * grid[j] * signs[j];
// sumqx += wk[j] * q * xk[j];
// sumq2 += wk[j] * q * q;
// }
//}
//q2[2*ib] = u;
//if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2);
if (false) {
const float * xb = xbl + 32*ib;
if (quant_weights) {
const float * qw = quant_weights + QK_K*ibl + 32*ib;
for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
} else {
for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i];
}
const float db = d * (1 + 2*l);
for (int k = 0; k < 8; ++k) {
const int8_t * signs = keven_signs_q2xs + 8*((scales_and_signs[ib] >> 7*(k/2)) & 127) + 4*(k%2);
const float * xk = xb + 4*k;
const float * wk = weight + 4*k;
const uint8_t * grid = (const uint8_t *)(kgrid_q3xs + q3[8*ib+k]);
float best_mse = 0; int best_index = q3[8*ib+k];
for (int j = 0; j < 4; ++j) {
float diff = db * grid[j] * signs[j] - xk[j];
best_mse += wk[j] * diff * diff;
}
for (int idx = 0; idx < 256; ++idx) {
grid = (const uint8_t *)(kgrid_q3xs + idx);
float mse = 0;
for (int j = 0; j < 4; ++j) {
float diff = db * grid[j] * signs[j] - xk[j];
mse += wk[j] * diff * diff;
}
if (mse < best_mse) {
best_mse = mse; best_index = idx;
}
}
q3[8*ib+k] = best_index;
grid = (const uint8_t *)(kgrid_q3xs + best_index);
for (int j = 0; j < 4; ++j) {
float q = db * grid[j] * signs[j];
sumqx += wk[j] * q * xk[j];
sumq2 += wk[j] * q * q;
}
}
if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2);
}
}
memcpy(y[ibl].qs, q3, 3*QK_K/8);
}

View file

@ -8992,6 +8992,9 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && qs.model.hparams.n_gqa() >= 4) {
new_type = GGML_TYPE_Q4_K;
}
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS && qs.model.hparams.n_gqa() >= 4) {
new_type = GGML_TYPE_Q4_K;
}
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
}
@ -9026,7 +9029,7 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
auto info = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str());
int i_layer = info.first, n_layer = info.second;
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS) {
else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS) {// || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
if (i_layer < n_layer/8) new_type = GGML_TYPE_Q4_K;
}
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {