Adding a POC dot product for Q4_1 quantization

On my Mac, the direct Q4_1 product is marginally slower
(~69 vs ~55 us for Q4_0). The SIMD-ified ggml version
is now almost 2X slower (~121 us).

On a Ryzen 7950X CPU, the direct product for Q4_1 quantization
is faster than the AVX2 implementation (~60 vs ~62 us).
This commit is contained in:
Iwan Kawrakow 2023-04-18 17:24:45 +02:00
parent 42031dac73
commit baee7684df

View file

@ -36,6 +36,14 @@ typedef struct {
} block_q4_0; } block_q4_0;
static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding"); static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
#define QK4_1 32
typedef struct {
float d; // delta
float m; // min
uint8_t qs[QK4_1 / 2]; // nibbles / quants
} block_q4_1;
static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
// Copy-pasted from ggml.c // Copy-pasted from ggml.c
#define QK8_0 32 #define QK8_0 32
typedef struct { typedef struct {
@ -125,6 +133,32 @@ inline double dot3(int n, const block_q4_0* x, const float* y) {
return sum; return sum;
} }
inline double dot41(int n, const block_q4_1* x, const float* y) {
const static float kValues[16] = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f};
constexpr uint32_t kMask1 = 0x0f0f0f0f;
uint32_t u1, u2;
auto q1 = (const uint8_t*)&u1;
auto q2 = (const uint8_t*)&u2;
double sum = 0;
for (int i=0; i<n; ++i) {
auto u = (const uint32_t*)x->qs;
float s = 0, s1 = 0;
for (int k=0; k<4; ++k) {
u1 = u[k] & kMask1;
u2 = (u[k] >> 4) & kMask1;
s += y[0]*kValues[q1[0]] + y[1]*kValues[q2[0]] +
y[2]*kValues[q1[1]] + y[3]*kValues[q2[1]] +
y[4]*kValues[q1[2]] + y[5]*kValues[q2[2]] +
y[6]*kValues[q1[3]] + y[7]*kValues[q2[3]];
s1 += y[0] + y[1] + y[2] + y[3] + y[4] + y[5] + y[6] + y[7];
y += 8;
}
sum += s*x->d + s1*x->m;
++x;
}
return sum;
}
// Copy-pasted from ggml.c // Copy-pasted from ggml.c
static void quantize_row_q8_0_reference(const float *x, block_q8_0 *y, int k) { static void quantize_row_q8_0_reference(const float *x, block_q8_0 *y, int k) {
assert(k % QK8_0 == 0); assert(k % QK8_0 == 0);
@ -184,20 +218,30 @@ int main(int argc, char** argv) {
int nloop = argc > 1 ? atoi(argv[1]) : 10; int nloop = argc > 1 ? atoi(argv[1]) : 10;
bool scalar = argc > 2 ? atoi(argv[2]) : false; bool scalar = argc > 2 ? atoi(argv[2]) : false;
bool useQ4_1 = argc > 3 ? atoi(argv[3]) : false;
if (scalar && useQ4_1) {
printf("It is not possible to use Q4_1 quantization and scalar implementations\n");
return 1;
}
std::mt19937 rndm(1234); std::mt19937 rndm(1234);
auto funcs = ggml_internal_get_quantize_fn(GGML_TYPE_Q4_0);
int n4 = kVecSize / QK4_0; n4 = 64*((n4 + 63)/64);
int n8 = kVecSize / QK8_0; n8 = 64*((n8 + 63)/64);
std::vector<float> x1(kVecSize), y1(kVecSize); std::vector<float> x1(kVecSize), y1(kVecSize);
std::vector<block_q4_0> q4(n4); int n4 = useQ4_1 ? kVecSize / QK4_1 : kVecSize / QK4_0; n4 = 64*((n4 + 63)/64);
int n8 = kVecSize / QK8_0; n8 = 64*((n8 + 63)/64);
auto funcs = useQ4_1 ? ggml_internal_get_quantize_fn(GGML_TYPE_Q4_1) : ggml_internal_get_quantize_fn(GGML_TYPE_Q4_0);
std::vector<block_q4_0> q40;
std::vector<block_q4_1> q41;
if (useQ4_1) q41.resize(n4);
else q40.resize(n4);
std::vector<block_q8_0> q8(n8); std::vector<block_q8_0> q8(n8);
std::vector<int64_t> H(16, 0); std::vector<int64_t> H(16, 0);
double sumt = 0, sumt2 = 0, maxt = 0; double sumt = 0, sumt2 = 0, maxt = 0;
double sumqt = 0, sumqt2 = 0, maxqt = 0; double sumqt = 0, sumqt2 = 0, maxqt = 0;
double sum = 0, sumq = 0; double sum = 0, sumq = 0, exactSum = 0;
for (int iloop=0; iloop<nloop; ++iloop) { for (int iloop=0; iloop<nloop; ++iloop) {
// Fill vector x with random numbers // Fill vector x with random numbers
@ -206,14 +250,22 @@ int main(int argc, char** argv) {
// Fill vector y with random numbers // Fill vector y with random numbers
fillRandomGaussianFloats(y1, rndm); fillRandomGaussianFloats(y1, rndm);
// Compute the exact dot product
for (int k=0; k<kVecSize; ++k) exactSum += x1[k]*y1[k];
// quantize x. // quantize x.
// Note, we do not include this in the timing as in practical application // Note, we do not include this in the timing as in practical application
// we already have the quantized model weights. // we already have the quantized model weights.
ggml_quantize_q4_0(x1.data(), q4.data(), kVecSize, QK4_0, H.data()); if (useQ4_1) {
funcs.quantize_row_q(x1.data(), q41.data(), kVecSize);
} else {
funcs.quantize_row_q(x1.data(), q40.data(), kVecSize);
}
// Now measure time the dot product needs using the "scalar" version above // Now measure time the dot product needs using the "scalar" version above
auto t1 = std::chrono::high_resolution_clock::now(); auto t1 = std::chrono::high_resolution_clock::now();
sum += dot(kVecSize / QK4_0, q4.data(), y1.data()); if (useQ4_1) sum += dot41(kVecSize / QK4_1, q41.data(), y1.data());
else sum += dot(kVecSize / QK4_0, q40.data(), y1.data());
auto t2 = std::chrono::high_resolution_clock::now(); auto t2 = std::chrono::high_resolution_clock::now();
auto t = 1e-3*std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count(); auto t = 1e-3*std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count();
sumt += t; sumt2 += t*t; maxt = std::max(maxt, t); sumt += t; sumt2 += t*t; maxt = std::max(maxt, t);
@ -223,11 +275,12 @@ int main(int argc, char** argv) {
float result; float result;
if (scalar) { if (scalar) {
quantize_row_q8_0_reference(y1.data(), q8.data(), kVecSize); quantize_row_q8_0_reference(y1.data(), q8.data(), kVecSize);
dot_q4_q8(kVecSize, &result, q4.data(), q8.data()); dot_q4_q8(kVecSize, &result, q40.data(), q8.data());
} }
else { else {
funcs.quantize_row_q_dot(y1.data(), q8.data(), kVecSize); funcs.quantize_row_q_dot(y1.data(), q8.data(), kVecSize);
funcs.vec_dot_q(kVecSize, &result, q4.data(), q8.data()); if (useQ4_1) funcs.vec_dot_q(kVecSize, &result, q41.data(), q8.data());
else funcs.vec_dot_q(kVecSize, &result, q40.data(), q8.data());
} }
sumq += result; sumq += result;
t2 = std::chrono::high_resolution_clock::now(); t2 = std::chrono::high_resolution_clock::now();
@ -239,6 +292,8 @@ int main(int argc, char** argv) {
// Report the time (and the average of the dot products so the compiler does not come up with the idea // Report the time (and the average of the dot products so the compiler does not come up with the idea
// of optimizing away the function calls after figuring that the result is not used). // of optimizing away the function calls after figuring that the result is not used).
sum /= nloop; sumq /= nloop; sum /= nloop; sumq /= nloop;
exactSum /= nloop;
printf("Exact result: <dot> = %g\n",exactSum);
printf("<dot> = %g, %g\n",sum,sumq); printf("<dot> = %g, %g\n",sum,sumq);
sumt /= nloop; sumt2 /= nloop; sumt2 -= sumt*sumt; sumt /= nloop; sumt2 /= nloop; sumt2 -= sumt*sumt;
if (sumt2 > 0) sumt2 = sqrt(sumt2); if (sumt2 > 0) sumt2 = sqrt(sumt2);