Improve Q4_0 MSE
Somehow I had it hard-wired in my brain that quants need to be in -7...7 to be comparable to the original Q4_0. But this is clearly not the case, and if we relax this requirement, this simple change brings the RMSE down to 0.001966 at the expense of a somewhat longer computation (~67 seconds vs 49 seconds for the 7B model on M2 Max). The perplexity test is still running, but it looks like the improvement compared to the previous version will be quite modest (~0.03) despite the significant improvement in MSE. The change does not affect Q4_1, as there we already use the full range of 16 possible int values.
This commit is contained in:
parent
b6df974577
commit
931ae36050
1 changed file with 15 additions and 5 deletions
|
@ -23,7 +23,7 @@ inline int toNearestInt(float fval) {
|
||||||
return (i & 0x007fffff) - 0x00400000;
|
return (i & 0x007fffff) - 0x00400000;
|
||||||
}
|
}
|
||||||
|
|
||||||
float kQuantize0(int n, const float* X, int8_t* L, std::vector<std::pair<float,int>>& work, int nmin, int nmax) {
|
std::pair<float, float> kQuantize0(int n, const float* X, int8_t* L, std::vector<std::pair<float,int>>& work, int nmin, int nmax) {
|
||||||
work.clear();
|
work.clear();
|
||||||
work.reserve(n*(nmax+2));
|
work.reserve(n*(nmax+2));
|
||||||
float max = 0; int imax = -1;
|
float max = 0; int imax = -1;
|
||||||
|
@ -33,7 +33,7 @@ float kQuantize0(int n, const float* X, int8_t* L, std::vector<std::pair<float,i
|
||||||
}
|
}
|
||||||
if (imax < 0) { // all X are zero
|
if (imax < 0) { // all X are zero
|
||||||
for (int i=0; i<n; ++i) L[i] = 0;
|
for (int i=0; i<n; ++i) L[i] = 0;
|
||||||
return 1.f;
|
return {1.f, 0.f};
|
||||||
}
|
}
|
||||||
float maxi = 1/max;
|
float maxi = 1/max;
|
||||||
int kmin, kmax;
|
int kmin, kmax;
|
||||||
|
@ -46,7 +46,7 @@ float kQuantize0(int n, const float* X, int8_t* L, std::vector<std::pair<float,i
|
||||||
}
|
}
|
||||||
auto df0 = suml20/scale0 - sumlx0;
|
auto df0 = suml20/scale0 - sumlx0;
|
||||||
if (df0 > 0) {
|
if (df0 > 0) {
|
||||||
kmin = nmax-2; kmax = nmax + 1;
|
kmin = nmax-2; kmax = nmax+1;
|
||||||
} else {
|
} else {
|
||||||
kmin = nmax/2; kmax = nmax+1;
|
kmin = nmax/2; kmax = nmax+1;
|
||||||
}
|
}
|
||||||
|
@ -97,7 +97,7 @@ float kQuantize0(int n, const float* X, int8_t* L, std::vector<std::pair<float,i
|
||||||
lasts = s;
|
lasts = s;
|
||||||
}
|
}
|
||||||
for (int i=0; i<n; ++i) L[i] = std::max(nmin, std::min(nmax, toNearestInt(bests*X[i])));
|
for (int i=0; i<n; ++i) L[i] = std::max(nmin, std::min(nmax, toNearestInt(bests*X[i])));
|
||||||
return bestSumlx/bestSuml2;
|
return {bestSumlx/bestSuml2, bestSumlx*bestSumlx/bestSuml2};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<float, float> kQuantize1(int n, const float* X, int8_t* L, std::vector<float>& tmpX,
|
std::pair<float, float> kQuantize1(int n, const float* X, int8_t* L, std::vector<float>& tmpX,
|
||||||
|
@ -137,7 +137,17 @@ void kQuantizeQ4(const float* X, void* buffer, int k, int type) {
|
||||||
auto processOne = [type] (const float* X, int8_t* L, char* y, std::vector<std::pair<float, int>>& work, std::vector<float>& tmpX) {
|
auto processOne = [type] (const float* X, int8_t* L, char* y, std::vector<std::pair<float, int>>& work, std::vector<float>& tmpX) {
|
||||||
auto q = (uint8_t*)y;
|
auto q = (uint8_t*)y;
|
||||||
if (type == 0) {
|
if (type == 0) {
|
||||||
float scale = kQuantize0(QK, X, L, work, -7, 7);
|
if (int(tmpX.size()) < QK) tmpX.resize(QK);
|
||||||
|
auto r1 = kQuantize0(QK, X, L, work, -8, 7);
|
||||||
|
for (int i=0; i<QK; ++i) tmpX[i] = -X[i];
|
||||||
|
int8_t L2[QK];
|
||||||
|
auto r2 = kQuantize0(QK, tmpX.data(), L2, work, -8, 7);
|
||||||
|
float scale = r1.first;
|
||||||
|
if (r2.second > r1.first) {
|
||||||
|
scale = -r2.first;
|
||||||
|
std::memcpy(L, L2, QK);
|
||||||
|
}
|
||||||
|
//float scale = kQuantize0(QK, X, L, work, -7, 7);
|
||||||
std::memcpy(q, &scale, sizeof(scale)); q += sizeof(scale);
|
std::memcpy(q, &scale, sizeof(scale)); q += sizeof(scale);
|
||||||
for (int k=0; k<QK/2; ++k) q[k] = (L[2*k] + 8) | ((L[2*k+1] + 8) << 4);
|
for (int k=0; k<QK/2; ++k) q[k] = (L[2*k] + 8) | ((L[2*k+1] + 8) << 4);
|
||||||
} else {
|
} else {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue