From baee7684df89451fadb7b48c7955ee6a807f0b61 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 18 Apr 2023 17:24:45 +0200 Subject: [PATCH] Adding a POC dot product for Q4_1 quantization On my Mac, the direct Q4_1 product is marginally slower (~69 vs ~55 us for Q4_0). The SIMD-ified ggml version is now almost 2X slower (~121 us). On a Ryzen 7950X CPU, the direct product for Q4_1 quantization is faster than the AVX2 implementation (~60 vs ~62 us). --- pocs/vdot/vdot.cpp | 75 +++++++++++++++++++++++++++++++++++++++------- 1 file changed, 65 insertions(+), 10 deletions(-) diff --git a/pocs/vdot/vdot.cpp b/pocs/vdot/vdot.cpp index b89a42145..26bf50c9a 100644 --- a/pocs/vdot/vdot.cpp +++ b/pocs/vdot/vdot.cpp @@ -36,6 +36,14 @@ typedef struct { } block_q4_0; static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding"); +#define QK4_1 32 +typedef struct { + float d; // delta + float m; // min + uint8_t qs[QK4_1 / 2]; // nibbles / quants +} block_q4_1; +static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding"); + // Copy-pasted from ggml.c #define QK8_0 32 typedef struct { @@ -125,6 +133,32 @@ inline double dot3(int n, const block_q4_0* x, const float* y) { return sum; } +inline double dot41(int n, const block_q4_1* x, const float* y) { + const static float kValues[16] = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}; + constexpr uint32_t kMask1 = 0x0f0f0f0f; + uint32_t u1, u2; + auto q1 = (const uint8_t*)&u1; + auto q2 = (const uint8_t*)&u2; + double sum = 0; + for (int i=0; iqs; + float s = 0, s1 = 0; + for (int k=0; k<4; ++k) { + u1 = u[k] & kMask1; + u2 = (u[k] >> 4) & kMask1; + s += y[0]*kValues[q1[0]] + y[1]*kValues[q2[0]] + + y[2]*kValues[q1[1]] + y[3]*kValues[q2[1]] + + y[4]*kValues[q1[2]] + y[5]*kValues[q2[2]] + + y[6]*kValues[q1[3]] + y[7]*kValues[q2[3]]; + s1 += y[0] + y[1] + y[2] + y[3] + y[4] + y[5] + y[6] + y[7]; + y += 8; + } + sum += s*x->d + s1*x->m; + ++x; + } + return sum; +} + // Copy-pasted from ggml.c static void quantize_row_q8_0_reference(const float *x, block_q8_0 *y, int k) { assert(k % QK8_0 == 0); @@ -184,20 +218,30 @@ int main(int argc, char** argv) { int nloop = argc > 1 ? atoi(argv[1]) : 10; bool scalar = argc > 2 ? atoi(argv[2]) : false; + bool useQ4_1 = argc > 3 ? atoi(argv[3]) : false; + + if (scalar && useQ4_1) { + printf("It is not possible to use Q4_1 quantization and scalar implementations\n"); + return 1; + } std::mt19937 rndm(1234); - auto funcs = ggml_internal_get_quantize_fn(GGML_TYPE_Q4_0); - - int n4 = kVecSize / QK4_0; n4 = 64*((n4 + 63)/64); - int n8 = kVecSize / QK8_0; n8 = 64*((n8 + 63)/64); std::vector x1(kVecSize), y1(kVecSize); - std::vector q4(n4); + int n4 = useQ4_1 ? kVecSize / QK4_1 : kVecSize / QK4_0; n4 = 64*((n4 + 63)/64); + int n8 = kVecSize / QK8_0; n8 = 64*((n8 + 63)/64); + + auto funcs = useQ4_1 ? ggml_internal_get_quantize_fn(GGML_TYPE_Q4_1) : ggml_internal_get_quantize_fn(GGML_TYPE_Q4_0); + + std::vector q40; + std::vector q41; + if (useQ4_1) q41.resize(n4); + else q40.resize(n4); std::vector q8(n8); std::vector H(16, 0); double sumt = 0, sumt2 = 0, maxt = 0; double sumqt = 0, sumqt2 = 0, maxqt = 0; - double sum = 0, sumq = 0; + double sum = 0, sumq = 0, exactSum = 0; for (int iloop=0; iloop(t2-t1).count(); sumt += t; sumt2 += t*t; maxt = std::max(maxt, t); @@ -223,11 +275,12 @@ int main(int argc, char** argv) { float result; if (scalar) { quantize_row_q8_0_reference(y1.data(), q8.data(), kVecSize); - dot_q4_q8(kVecSize, &result, q4.data(), q8.data()); + dot_q4_q8(kVecSize, &result, q40.data(), q8.data()); } else { funcs.quantize_row_q_dot(y1.data(), q8.data(), kVecSize); - funcs.vec_dot_q(kVecSize, &result, q4.data(), q8.data()); + if (useQ4_1) funcs.vec_dot_q(kVecSize, &result, q41.data(), q8.data()); + else funcs.vec_dot_q(kVecSize, &result, q40.data(), q8.data()); } sumq += result; t2 = std::chrono::high_resolution_clock::now(); @@ -239,6 +292,8 @@ int main(int argc, char** argv) { // Report the time (and the average of the dot products so the compiler does not come up with the idea // of optimizing away the function calls after figuring that the result is not used). sum /= nloop; sumq /= nloop; + exactSum /= nloop; + printf("Exact result: = %g\n",exactSum); printf(" = %g, %g\n",sum,sumq); sumt /= nloop; sumt2 /= nloop; sumt2 -= sumt*sumt; if (sumt2 > 0) sumt2 = sqrt(sumt2);