diff --git a/ggml-quants.c b/ggml-quants.c index aa793ffc5..06481048f 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -1892,6 +1892,35 @@ redo_pos4: uint8_t lm = nearest_int(inv_min*mins[j]); ls = MIN(63, ls); lm = MIN(63, lm); + float rms = 0.0f; + float rms2 = 0.0f; + float rms3 = 0.0f; + uint8_t lm2 = MIN(63, MAX(0, lm - 1)); + uint8_t lm3 = MIN(63, MAX(0, lm + 1)); + const float d1 = max_scale / 63.0f; + const float dmin1 = max_min / 63.0f; + for (int ii = 0; ii < 32; ii++) { + const float d = d1 * ls; + const float dm1 = dmin1 * lm; + const float dm2 = dmin1 * lm2; + const float dm3 = dmin1 * lm3; + if (!d) continue; + int l1 = nearest_int((x[32*j + ii] + dm1)/d); + l1 = MAX(0, MIN(15, l1)); + int l2 = nearest_int((x[32*j + ii] + dm2)/d); + l2 = MAX(0, MIN(15, l2)); + int l3 = nearest_int((x[32*j + ii] + dm3)/d); + l3 = MAX(0, MIN(15, l3)); + rms += ((d*l1 - dm1) - x[32*j + ii]) * ((d*l1 - dm1) - x[32*j + ii]); + rms2 += ((d*l2 - dm2) - x[32*j + ii]) * ((d*l2 - dm2) - x[32*j + ii]); + rms3 += ((d*l3 - dm3) - x[32*j + ii]) * ((d*l3 - dm3) - x[32*j + ii]); + } + if (rms2 < rms && rms2 < rms3) { + lm = lm2; + } + if (rms3 < rms && rms3 < rms2) { + lm = lm3; + } if (lm != 0) all_zero_lm = 0; if (j < 4) { y[i].scales[j] = ls;