From 859fee6dfb00fab7ce6bc215b4adae78d82f4759 Mon Sep 17 00:00:00 2001 From: Pavol Rusnak Date: Wed, 26 Apr 2023 18:43:27 +0200 Subject: [PATCH 1/6] quantize : use `map` to assign quantization type from `string` (#1191) instead of `int` (while `int` option still being supported) This allows the following usage: `./quantize ggml-model-f16.bin ggml-model-q4_0.bin q4_0` instead of: `./quantize ggml-model-f16.bin ggml-model-q4_0.bin 2` --- .devops/tools.sh | 2 +- README.md | 4 ++-- examples/quantize/quantize.cpp | 30 ++++++++++++++++++++++++------ 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/.devops/tools.sh b/.devops/tools.sh index b0196b60d..ece9e4efa 100755 --- a/.devops/tools.sh +++ b/.devops/tools.sh @@ -23,7 +23,7 @@ elif [[ $arg1 == '--all-in-one' || $arg1 == '-a' ]]; then echo "Skip model quantization, it already exists: ${i/f16/q4_0}" else echo "Converting PTH to GGML: $i into ${i/f16/q4_0}..." - ./quantize "$i" "${i/f16/q4_0}" 2 + ./quantize "$i" "${i/f16/q4_0}" q4_0 fi done else diff --git a/README.md b/README.md index 44cf72124..509df61a1 100644 --- a/README.md +++ b/README.md @@ -203,8 +203,8 @@ python3 -m pip install -r requirements.txt # convert the 7B model to ggml FP16 format python3 convert.py models/7B/ -# quantize the model to 4-bits (using method 2 = q4_0) -./quantize ./models/7B/ggml-model-f16.bin ./models/7B/ggml-model-q4_0.bin 2 +# quantize the model to 4-bits (using q4_0 method) +./quantize ./models/7B/ggml-model-f16.bin ./models/7B/ggml-model-q4_0.bin q4_0 # run the inference ./main -m ./models/7B/ggml-model-q4_0.bin -n 128 diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index ad39a805d..ec7f91aae 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -2,8 +2,17 @@ #include "llama.h" #include <cstdio> +#include <map> #include <string> +static const std::map<std::string, enum llama_ftype> LLAMA_FTYPE_MAP = { + {"q4_0", LLAMA_FTYPE_MOSTLY_Q4_0}, + {"q4_1", LLAMA_FTYPE_MOSTLY_Q4_1}, + {"q4_2", LLAMA_FTYPE_MOSTLY_Q4_2}, + {"q4_3", LLAMA_FTYPE_MOSTLY_Q4_3}, + {"q8_0", LLAMA_FTYPE_MOSTLY_Q8_0}, +}; + // usage: // ./quantize models/llama/ggml-model.bin models/llama/ggml-model-quant.bin type // @@ -12,11 +21,9 @@ int main(int argc, char ** argv) { if (argc < 4) { fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type [nthread]\n", argv[0]); - fprintf(stderr, " type = %d - q4_0\n", LLAMA_FTYPE_MOSTLY_Q4_0); - fprintf(stderr, " type = %d - q4_1\n", LLAMA_FTYPE_MOSTLY_Q4_1); - fprintf(stderr, " type = %d - q4_2\n", LLAMA_FTYPE_MOSTLY_Q4_2); - fprintf(stderr, " type = %d - q4_3\n", LLAMA_FTYPE_MOSTLY_Q4_3); - fprintf(stderr, " type = %d - q8_0\n", LLAMA_FTYPE_MOSTLY_Q8_0); + for (auto it = LLAMA_FTYPE_MAP.begin(); it != LLAMA_FTYPE_MAP.end(); it++) { + fprintf(stderr, " type = \"%s\" or %d\n", it->first.c_str(), it->second); + } return 1; } @@ -30,7 +37,18 @@ int main(int argc, char ** argv) { const std::string fname_inp = argv[1]; const std::string fname_out = argv[2]; - const enum llama_ftype ftype = (enum llama_ftype)atoi(argv[3]); + enum llama_ftype ftype; + if (argv[3][0] == 'q') { + auto it = LLAMA_FTYPE_MAP.find(argv[3]); + if (it == LLAMA_FTYPE_MAP.end()) { + fprintf(stderr, "%s: unknown ftype '%s'\n", __func__, argv[3]); + return 1; + } + ftype = it->second; + } else { + ftype = (enum llama_ftype)atoi(argv[3]); + } + int nthread = argc > 4 ?
atoi(argv[4]) : 0; const int64_t t_main_start_us = ggml_time_us(); From ea3ad7eb60cfb44526a58122e8019850f437cd1b Mon Sep 17 00:00:00 2001 From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com> Date: Wed, 26 Apr 2023 22:03:03 +0200 Subject: [PATCH 2/6] Updating build instructions to include BLAS support (#1183) * Updated build information First update to the build instructions to include BLAS. * Update README.md * Update information about BLAS * Better BLAS explanation Adding a clearer BLAS explanation and adding a link to download the CUDA toolkit. * Better BLAS explanation * BLAS for Mac Specifying that BLAS is already supported on Macs using the Accelerate Framework. * Clarify the effect of BLAS * Windows Make instructions Added the instructions to build with Make on Windows * Fixing typo * Fix trailing whitespace --- README.md | 85 ++++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 78 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 509df61a1..ddbd4c8b1 100644 --- a/README.md +++ b/README.md @@ -167,15 +167,27 @@ cd llama.cpp ### Build -Note: For Windows, CMake or Zig can be used. +In order to build llama.cpp you have three different options. -1. Use `make` +- Using `make`: + - On Linux or macOS: - ```bash - make - ``` + ```bash + make + ``` -1. Use CMake + - On Windows: + + 1. Download the latest Fortran version of [w64devkit](https://github.com/skeeto/w64devkit/releases). + 2. Extract `w64devkit` on your PC. + 3. Run `w64devkit.exe`. + 4. Use the `cd` command to reach the `llama.cpp` folder. + 5. From here you can run: + ```bash + make + ``` + +- Using `CMake`: ```bash mkdir build @@ -184,12 +196,71 @@ Note: For Windows, CMake or Zig can be used. cmake --build . --config Release ``` -1. Use Zig +- Using `Zig`: ```bash zig build -Drelease-fast ``` +### BLAS Build + +Building the program with BLAS support may lead to some performance improvements in prompt processing using batch sizes higher than 32 (the default is 512). BLAS doesn't affect the normal generation performance. There are currently three different implementations of it: + +- Accelerate Framework: + + This is only available on Macs and it's enabled by default. You can just build using the normal instructions. + +- OpenBLAS: + + This provides BLAS acceleration using only the CPU. Make sure to have OpenBLAS installed on your machine. + + - Using `make`: + - On Linux: + ```bash + make LLAMA_OPENBLAS=1 + ``` + Note: In order to build on Arch Linux with OpenBLAS support enabled you must edit the Makefile, adding `-lcblas` at the end of line 105. + + - On Windows: + + 1. Download the latest Fortran version of [w64devkit](https://github.com/skeeto/w64devkit/releases). + 2. Download the latest version of [OpenBLAS for Windows](https://github.com/xianyi/OpenBLAS/releases). + 3. Extract `w64devkit` on your PC. + 4. From the OpenBLAS zip that you just downloaded, copy `libopenblas.a`, located inside the `lib` folder, into `w64devkit\x86_64-w64-mingw32\lib`. + 5. From the same OpenBLAS zip, copy the content of the `include` folder into `w64devkit\x86_64-w64-mingw32\include`. + 6. Run `w64devkit.exe`. + 7. Use the `cd` command to reach the `llama.cpp` folder. + 8. From here you can run: + + ```bash + make LLAMA_OPENBLAS=1 + ``` + + - Using `CMake` on Linux: + + ```bash + mkdir build + cd build + cmake .. -DLLAMA_OPENBLAS=ON + cmake --build . --config Release + ``` + +- cuBLAS + + This provides BLAS acceleration using the CUDA cores of your Nvidia GPU.
Make sure to have the CUDA toolkit installed. You can download it from your Linux distro's package manager or from here: [CUDA Toolkit](https://developer.nvidia.com/cuda-downloads). + - Using `make`: + ```bash + make LLAMA_CUBLAS=1 + ``` + - Using `CMake`: + + ```bash + mkdir build + cd build + cmake .. -DLLAMA_CUBLAS=ON + cmake --build . --config Release + ``` + ### Prepare Data & Run ```bash From 87a6f846d3e929632c45916dd08f1e2a9c72d2a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81sgeir=20Bjarni=20Ingvarsson?= Date: Wed, 26 Apr 2023 20:08:43 +0000 Subject: [PATCH 3/6] Allow setting the rng seed after initialization. (#1184) The llama_set_state_data function restores the rng state to what it was at the time llama_copy_state_data was called. But users may want to restore the state and proceed with a different seed. --- llama.cpp | 7 +++++++ llama.h | 3 +++ 2 files changed, 10 insertions(+) diff --git a/llama.cpp b/llama.cpp index 25203c9e9..8334553a5 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2082,6 +2082,13 @@ int llama_get_kv_cache_token_count(struct llama_context * ctx) { #define LLAMA_MAX_RNG_STATE 64*1024 +void llama_set_rng_seed(struct llama_context * ctx, int seed) { + if (seed <= 0) { + seed = time(NULL); + } + ctx->rng.seed(seed); +} + // Returns the size of the state size_t llama_get_state_size(struct llama_context * ctx) { // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. diff --git a/llama.h b/llama.h index ab41798d8..24c48cce6 100644 --- a/llama.h +++ b/llama.h @@ -116,6 +116,9 @@ extern "C" { // Returns the number of tokens in the KV cache LLAMA_API int llama_get_kv_cache_token_count(struct llama_context * ctx); + // Sets the current rng seed. + LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, int seed); + // Returns the size in bytes of the state (rng, logits, embedding and kv_cache) LLAMA_API size_t llama_get_state_size(struct llama_context * ctx); From 574406dc7e350ddbffaeca33bf0392b7bfeb1436 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 26 Apr 2023 23:14:13 +0300 Subject: [PATCH 4/6] ggml : add Q5_0 and Q5_1 quantization (#1187) * ggml : add Q5_0 quantization (cuBLAS only) * ggml : fix Q5_0 qh -> uint32_t * ggml : fix q5_0 histogram stats * ggml : q5_0 scalar dot product * ggml : q5_0 ARM NEON dot * ggml : q5_0 more efficient ARM NEON using uint64_t masks * ggml : rename Q5_0 -> Q5_1 * ggml : adding Q5_0 mode * quantize : add Q5_0 and Q5_1 to map * ggml : AVX2 optimizations for Q5_0, Q5_1 (#1195) --------- Co-authored-by: Stephan Walter --- .gitignore | 1 + examples/quantize/quantize.cpp | 2 + ggml-cuda.cu | 85 +++++ ggml-cuda.h | 2 + ggml.c | 633 +++++++++++++++++++++++++++++++-- ggml.h | 8 +- llama.cpp | 8 + llama.h | 2 + 8 files changed, 711 insertions(+), 30 deletions(-) diff --git a/.gitignore b/.gitignore index e52d479ee..c7573bb3b 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ build-em/ build-debug/ build-release/ build-static/ +build-cublas/ build-no-accel/ build-sanitize-addr/ build-sanitize-thread/ diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index ec7f91aae..60966595e 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -10,6 +10,8 @@ static const std::map LLAMA_FTYPE_MAP = { {"q4_1", LLAMA_FTYPE_MOSTLY_Q4_1}, {"q4_2", LLAMA_FTYPE_MOSTLY_Q4_2}, {"q4_3", LLAMA_FTYPE_MOSTLY_Q4_3}, + {"q5_0", LLAMA_FTYPE_MOSTLY_Q5_0}, + {"q5_1", LLAMA_FTYPE_MOSTLY_Q5_1}, {"q8_0", LLAMA_FTYPE_MOSTLY_Q8_0}, }; 
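For reference, the Q5_0 format added by this patch stores, per block of 32 values, an fp16 scale `d`, the low 4 bits of each quant packed two per byte in `qs`, and the fifth bits collected into the 32-bit field `qh`; a value is reconstructed as `d * (q - 16)` (Q5_1 additionally stores a minimum `m` and reconstructs as `d * q + m`). The following minimal standalone C sketch is not part of the patch and uses made-up constants; it only illustrates how one pair of values round-trips through this packing:

```c
// Illustrative sketch of the Q5_0 packing (not part of the patch, toy values).
// Each 5-bit quant q reconstructs as d*(q - 16); the low nibbles of two
// neighbouring values share one byte of qs, their 5th bits go into qh.
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const float d  = 0.1f;              // block scale (fp16 in the real format)
    const float x0 = -0.7f, x1 = 1.2f;  // two neighbouring weights

    // quantize to 5 bits: q = round(x/d + 16), clamped to [0, 31]
    int q0 = (int)(x0/d + 16.5f); if (q0 < 0) q0 = 0; if (q0 > 31) q0 = 31;
    int q1 = (int)(x1/d + 16.5f); if (q1 < 0) q1 = 0; if (q1 > 31) q1 = 31;

    // low nibbles packed two per byte, 5th bits packed into the bitfield qh
    uint8_t  qs = (uint8_t)((q0 & 0x0F) | ((q1 & 0x0F) << 4));
    uint32_t qh = (uint32_t)(((q0 >> 4) & 1) << 0) | (uint32_t)(((q1 >> 4) & 1) << 1);

    // dequantize: re-attach the 5th bit, then scale and shift back
    const int v0 = (qs & 0x0F) | (int)(((qh >> 0) & 1) << 4);
    const int v1 = (qs >> 4)   | (int)(((qh >> 1) & 1) << 4);
    printf("%.2f -> %.2f, %.2f -> %.2f\n", x0, d*(v0 - 16), x1, d*(v1 - 16));
    return 0;
}
```

The ggml-cuda.cu and ggml.c hunks below implement exactly this layout (dequantization, quantization reference, dot products and histogram stats) for both Q5_0 and Q5_1.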
diff --git a/ggml-cuda.cu b/ggml-cuda.cu index f104ed5ac..b1bd29b10 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -37,6 +37,23 @@ typedef struct { } block_q4_3; static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding"); +#define QK5_0 32 +typedef struct { + __half d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; +static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); + +#define QK5_1 32 +typedef struct { + __half d; // delta + __half m; // min + uint32_t qh; // 5-th bit of quants + uint8_t qs[QK5_1 / 2]; // nibbles / quants +} block_q5_1; +static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); + #define QK8_0 32 typedef struct { float d; // delta @@ -138,6 +155,64 @@ static __global__ void dequantize_block_q4_3(const void * vx, float * y) { } } +static __global__ void dequantize_block_q5_0(const void * vx, float * y) { + const block_q5_0 * x = (const block_q5_0 *) vx; + + const int i = blockIdx.x; + + const float d = x[i].d; + + const uint8_t * pp = x[i].qs; + + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + for (int l = 0; l < QK5_0; l += 2) { + const uint8_t vi = pp[l/2]; + + const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; + + const int8_t vi0 = ((vi & 0xf) | vh0); + const int8_t vi1 = ((vi >> 4) | vh1); + + const float v0 = (vi0 - 16)*d; + const float v1 = (vi1 - 16)*d; + + y[i*QK5_0 + l + 0] = v0; + y[i*QK5_0 + l + 1] = v1; + } +} + +static __global__ void dequantize_block_q5_1(const void * vx, float * y) { + const block_q5_1 * x = (const block_q5_1 *) vx; + + const int i = blockIdx.x; + + const float d = x[i].d; + const float m = x[i].m; + + const uint8_t * pp = x[i].qs; + + const uint32_t qh = x[i].qh; + + for (int l = 0; l < QK5_1; l += 2) { + const uint8_t vi = pp[l/2]; + + const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; + + const int8_t vi0 = (vi & 0xf) | vh0; + const int8_t vi1 = (vi >> 4) | vh1; + + const float v0 = vi0*d + m; + const float v1 = vi1*d + m; + + y[i*QK5_1 + l + 0] = v0; + y[i*QK5_1 + l + 1] = v1; + } +} + static __global__ void dequantize_block_q8_0(const void * vx, float * y) { const block_q8_0 * x = (const block_q8_0 *) vx; @@ -174,6 +249,16 @@ void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t st dequantize_block_q4_3<<<nb, 1, 0, stream>>>(vx, y); } +void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { + const int nb = k / QK5_0; + dequantize_block_q5_0<<<nb, 1, 0, stream>>>(vx, y); +} + +void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { + const int nb = k / QK5_1; + dequantize_block_q5_1<<<nb, 1, 0, stream>>>(vx, y); +} + void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK8_0; dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y); diff --git a/ggml-cuda.h b/ggml-cuda.h index 4048ea491..ed9b44184 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -35,6 +35,8 @@ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t st void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t
stream); +void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream); +void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream); #ifdef __cplusplus diff --git a/ggml.c b/ggml.c index 064510eda..03b4bd439 100644 --- a/ggml.c +++ b/ggml.c @@ -328,6 +328,19 @@ static ggml_fp16_t table_exp_f16[1 << 16]; // precomputed f32 table for f16 (256 KB) static float table_f32_f16[1 << 16]; +#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s +#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s) +#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s) +#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s) +#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s) +#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s) +#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s) +#define B8(c,s ) B7(c,s, c), B7(c,s, s) + +// precomputed tables for expanding 8bits to 8 bytes (shl 4) +static const uint64_t table_b2b_u[1 << 8] = { B8(00, 10) }; +static const uint64_t table_b2b_i[1 << 8] = { B8(F0, 00) }; + // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32, // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON. // This is also true for POWER9. @@ -673,6 +686,23 @@ typedef struct { } block_q4_3; static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding"); +#define QK5_0 32 +typedef struct { + ggml_fp16_t d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; +static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); + +#define QK5_1 32 +typedef struct { + ggml_fp16_t d; // delta + ggml_fp16_t m; // min + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_1 / 2]; // nibbles / quants +} block_q5_1; +static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); + #define QK8_0 32 typedef struct { float d; // delta @@ -1288,6 +1318,103 @@ static void quantize_row_q4_3(const float * restrict x, void * restrict vy, int quantize_row_q4_3_reference(x, y, k); } +static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) { + assert(k % QK5_0 == 0); + const int nb = k / QK5_0; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int l = 0; l < QK5_0; l++) { + const float v = x[i*QK5_0 + l]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + + const float d = max / -16; + const float id = d ? 
1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + uint32_t qh = 0; + + for (int l = 0; l < QK5_0; l += 2) { + const float v0 = x[i*QK5_0 + l + 0]*id; + const float v1 = x[i*QK5_0 + l + 1]*id; + + const uint32_t vi0 = MIN(31, (int) (v0 + 16.5f)); + const uint32_t vi1 = MIN(31, (int) (v1 + 16.5f)); + + y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4); + + // get the 5-th bit and store it in qh at the right position + qh |= ((vi0 & 0x10) >> 4) << (l + 0); + qh |= ((vi1 & 0x10) >> 4) << (l + 1); + } + + memcpy(&y[i].qh, &qh, sizeof(y[i].qh)); + } +} + +static void quantize_row_q5_0(const float * restrict x, void * restrict vy, int k) { + assert(k % QK5_0 == 0); + + block_q5_0 * restrict y = vy; + + quantize_row_q5_0_reference(x, y, k); +} + +static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) { + assert(k % QK5_1 == 0); + const int nb = k / QK5_1; + + for (int i = 0; i < nb; i++) { + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int l = 0; l < QK5_1; l++) { + const float v = x[i*QK5_1 + l]; + if (v < min) min = v; + if (v > max) max = v; + } + + const float d = (max - min) / ((1 << 5) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + y[i].m = GGML_FP32_TO_FP16(min); + + uint32_t qh = 0; + + for (int l = 0; l < QK5_1; l += 2) { + const float v0 = (x[i*QK5_1 + l + 0] - min)*id; + const float v1 = (x[i*QK5_1 + l + 1] - min)*id; + + const uint32_t vi0 = (int) (v0 + 0.5f); + const uint32_t vi1 = (int) (v1 + 0.5f); + + y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4); + + // get the 5-th bit and store it in qh at the right position + qh |= ((vi0 & 0x10) >> 4) << (l + 0); + qh |= ((vi1 & 0x10) >> 4) << (l + 1); + } + + memcpy(&y[i].qh, &qh, sizeof(y[i].qh)); + } +} + +static void quantize_row_q5_1(const float * restrict x, void * restrict vy, int k) { + assert(k % QK5_1 == 0); + + block_q5_1 * restrict y = vy; + + quantize_row_q5_1_reference(x, y, k); +} + // reference implementation for deterministic creation of model files static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) { assert(k % QK8_0 == 0); @@ -1571,7 +1698,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in const uint8x8_t v8 = vld1_u8(pp + l/2); // Expand 4-bit qs to 8-bit bytes - const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f)); + const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F)); const uint8x8_t v1 = vshr_n_u8(v8, 4); // Convert to signed 8-bit integers @@ -1621,7 +1748,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in for (int l = 0; l < QK4_0; l += 2) { const uint8_t vi = pp[l/2]; - const int8_t vi0 = vi & 0xf; + const int8_t vi0 = vi & 0x0F; const int8_t vi1 = vi >> 4; const float v0 = (vi0 - 8)*d; @@ -1687,7 +1814,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in const uint8x8_t v8 = vld1_u8(pp + l/2); // Expand 4-bit qs to 8-bit bytes - const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f)); + const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F)); const uint8x8_t v1 = vshr_n_u8(v8, 4); // Interleave and combine @@ -1729,7 +1856,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in for (int l = 0; l < QK4_1; l += 2) { const uint8_t vi = pp[l/2]; - const int8_t vi0 = vi & 0xf; + const int8_t vi0 = vi & 0x0F; const int8_t vi1 = vi >> 4; const float v0 = vi0*d + m; @@ -1759,7 +1886,7 @@ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, in for 
(int l = 0; l < QK4_2; l += 2) { const uint8_t vi = pp[l/2]; - const int8_t vi0 = vi & 0xf; + const int8_t vi0 = vi & 0x0F; const int8_t vi1 = vi >> 4; const float v0 = (vi0 - 8)*d; @@ -1789,7 +1916,7 @@ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, in for (int l = 0; l < QK4_3; l += 2) { const uint8_t vi = pp[l/2]; - const int8_t vi0 = vi & 0xf; + const int8_t vi0 = vi & 0x0F; const int8_t vi1 = vi >> 4; const float v0 = vi0*d + m; @@ -1804,6 +1931,79 @@ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, in } } +static void dequantize_row_q5_0(const void * restrict vx, float * restrict y, int k) { + assert(k % QK5_0 == 0); + const int nb = k / QK5_0; + + const block_q5_0 * restrict x = vx; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict pp = x[i].qs; + + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + for (int l = 0; l < QK5_0; l += 2) { + const uint8_t vi = pp[l/2]; + + // extract the 5-th bit from qh + const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; + + const int8_t vi0 = (vi & 0x0F) | vh0; + const int8_t vi1 = (vi >> 4) | vh1; + + const float v0 = (vi0 - 16)*d; + const float v1 = (vi1 - 16)*d; + + y[i*QK5_0 + l + 0] = v0; + y[i*QK5_0 + l + 1] = v1; + + assert(!isnan(y[i*QK5_0 + l + 0])); + assert(!isnan(y[i*QK5_0 + l + 1])); + } + } +} + +static void dequantize_row_q5_1(const void * restrict vx, float * restrict y, int k) { + assert(k % QK5_1 == 0); + const int nb = k / QK5_1; + + const block_q5_1 * restrict x = vx; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + const float m = GGML_FP16_TO_FP32(x[i].m); + + const uint8_t * restrict pp = x[i].qs; + + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + for (int l = 0; l < QK5_1; l += 2) { + const uint8_t vi = pp[l/2]; + + // extract the 5-th bit from qh + const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; + + const uint8_t vi0 = (vi & 0x0F) | vh0; + const uint8_t vi1 = (vi >> 4) | vh1; + + const float v0 = vi0*d + m; + const float v1 = vi1*d + m; + + y[i*QK5_1 + l + 0] = v0; + y[i*QK5_1 + l + 1] = v1; + + assert(!isnan(y[i*QK5_1 + l + 0])); + assert(!isnan(y[i*QK5_1 + l + 1])); + } + } +} + static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, int k) { assert(k % QK8_0 == 0); const int nb = k / QK8_0; @@ -1825,6 +2025,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); +static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); +static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { @@ -1860,6 +2062,22 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { .vec_dot_q = ggml_vec_dot_q4_3_q8_1, 
.vec_dot_type = GGML_TYPE_Q8_1, }, + [GGML_TYPE_Q5_0] = { + .dequantize_row_q = dequantize_row_q5_0, + .quantize_row_q = quantize_row_q5_0, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_0_reference, + .quantize_row_q_dot = quantize_row_q8_0, + .vec_dot_q = ggml_vec_dot_q5_0_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + }, + [GGML_TYPE_Q5_1] = { + .dequantize_row_q = dequantize_row_q5_1, + .quantize_row_q = quantize_row_q5_1, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_1_reference, + .quantize_row_q_dot = quantize_row_q8_1, + .vec_dot_q = ggml_vec_dot_q5_1_q8_1, + .vec_dot_type = GGML_TYPE_Q8_1, + }, [GGML_TYPE_Q8_0] = { .dequantize_row_q = dequantize_row_q8_0, .quantize_row_q = quantize_row_q8_0, @@ -2496,7 +2714,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * const block_q8_0 * restrict y0 = &y[i + 0]; const block_q8_0 * restrict y1 = &y[i + 1]; - const uint8x16_t m4b = vdupq_n_u8(0xf); + const uint8x16_t m4b = vdupq_n_u8(0x0F); const int8x16_t s8b = vdupq_n_s8(0x8); const uint8x16_t v0_0 = vld1q_u8(x0->qs); @@ -2632,8 +2850,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * for (int j = 0; j < QK8_0/2; j++) { const uint8_t v0 = p0[j]; - const int i0 = (int8_t) (v0 & 0xf) - 8; - const int i1 = (int8_t) (v0 >> 4) - 8; + const int i0 = (int8_t) (v0 & 0x0F) - 8; + const int i1 = (int8_t) (v0 >> 4) - 8; const int i2 = p1[2*j + 0]; const int i3 = p1[2*j + 1]; @@ -2670,7 +2888,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1); - const uint8x16_t m4b = vdupq_n_u8(0xf); + const uint8x16_t m4b = vdupq_n_u8(0x0F); const uint8x16_t v0_0 = vld1q_u8(x0->qs); const uint8x16_t v0_1 = vld1q_u8(x1->qs); @@ -2767,8 +2985,8 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * for (int j = 0; j < QK8_1/2; j++) { const uint8_t v0 = p0[j]; - const float f0 = d0*(v0 & 0xf) + m0; - const float f1 = d0*(v0 >> 4) + m0; + const float f0 = d0*(v0 & 0x0F) + m0; + const float f1 = d0*(v0 >> 4) + m0; const float f2 = d1*p1[2*j + 0]; const float f3 = d1*p1[2*j + 1]; @@ -2803,7 +3021,7 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * const block_q8_0 * restrict y0 = &y[i + 0]; const block_q8_0 * restrict y1 = &y[i + 1]; - const uint8x16_t m4b = vdupq_n_u8(0xf); + const uint8x16_t m4b = vdupq_n_u8(0x0F); const int8x16_t s8b = vdupq_n_s8(0x8); const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs)); @@ -2914,11 +3132,11 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * const uint8_t v0 = x0[j]; const uint8_t v1 = x1[j]; - const int i0_0 = (int8_t) (v0 & 0xf) - 8; - const int i1_0 = (int8_t) (v0 >> 4) - 8; + const int i0_0 = (int8_t) (v0 & 0x0F) - 8; + const int i1_0 = (int8_t) (v0 >> 4) - 8; - const int i0_1 = (int8_t) (v1 & 0xf) - 8; - const int i1_1 = (int8_t) (v1 >> 4) - 8; + const int i0_1 = (int8_t) (v1 & 0x0F) - 8; + const int i1_1 = (int8_t) (v1 >> 4) - 8; const int i2_0 = y0[2*j + 0]; const int i3_0 = y0[2*j + 1]; @@ -2966,7 +3184,7 @@ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs)); // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0xf))); + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0x0F))); const int8x16_t v0_0h = 
vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); // interleave @@ -3045,10 +3263,10 @@ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * const uint8_t v0 = x0[j]; const uint8_t v1 = x1[j]; - const int x0_0 = v0 & 0xf; + const int x0_0 = v0 & 0x0F; const int x1_0 = v0 >> 4; - const int x0_1 = v1 & 0xf; + const int x0_1 = v1 & 0x0F; const int x1_1 = v1 >> 4; const int y0_0 = y0[2*j + 0]; @@ -3067,6 +3285,273 @@ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * #endif } +static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int nb = n / QK8_0; + + assert(n % QK8_0 == 0); + assert(nb % 2 == 0); + assert(QK8_0 == QK5_0); + + const block_q5_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv = vdupq_n_f32(0.0f); + + uint64_t tmp[4]; + + for (int i = 0; i < nb; ++i) { + const block_q5_0 * restrict x0 = &x[i]; + const block_q8_0 * restrict y0 = &y[i]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + const int8x16_t s16b = vdupq_n_s8(0x10); + + // extract the 5th bit + uint32_t qh; + memcpy(&qh, x0->qh, sizeof(qh)); + + tmp[0] = table_b2b_u[(qh >> 0) & 0xFF]; + tmp[1] = table_b2b_u[(qh >> 8) & 0xFF]; + tmp[2] = table_b2b_u[(qh >> 16) & 0xFF]; + tmp[3] = table_b2b_u[(qh >> 24) ]; + + const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0)); + const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2)); + + const uint8x16_t v0 = vld1q_u8(x0->qs); + + // 4-bit -> 8-bit + const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, m4b)); + const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4)); + + // interleave + const int8x16_t v0lz = vzip1q_s8(v0l, v0h); + const int8x16_t v0hz = vzip2q_s8(v0l, v0h); + + // add high bit and sub 16 + const int8x16_t v0lf = vsubq_s8(vorrq_s8(v0lz, qhl), s16b); + const int8x16_t v0hf = vsubq_s8(vorrq_s8(v0hz, qhh), s16b); + + // load y + const int8x16_t v1l = vld1q_s8(y0->qs); + const int8x16_t v1h = vld1q_s8(y0->qs + 16); + + const float x0d = GGML_FP16_TO_FP32(x0->d); + +#if defined(__ARM_FEATURE_DOTPROD) + sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), v0lf, v1l), + vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d); +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h)); + + const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); + const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); + + sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d); +#endif + } + + *s = vaddvq_f32(sumv); +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; i++) { + /* Compute combined scale for the block */ + const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d)); + + __m256i bx = bytes_from_nibbles_32(x[i].qs); + const __m256i bxhi = _mm256_set_epi64x( + table_b2b_i[x[i].qh[3]], table_b2b_i[x[i].qh[2]], + table_b2b_i[x[i].qh[1]], table_b2b_i[x[i].qh[0]]); + bx = _mm256_or_si256(bx, bxhi); + + __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_i8_pairs_float(bx, by); + + /* Multiply q with scale and accumulate */ + 
acc = _mm256_fmadd_ps(d, q, acc); + } + + *s = hsum_float_8(acc); +#else + // scalar + float sumf = 0.0; + for (int i = 0; i < nb; i++) { + const uint8_t * restrict x0 = x[i].qs; + const int8_t * restrict y0 = y[i].qs; + + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + const float d = GGML_FP16_TO_FP32(x[i].d); + + int sxy = 0; + + for (int j = 0; j < QK8_0/2; j++) { + const uint8_t v0 = x0[j]; + + const int x0_0h = ((qh & (1 << (2*j + 0))) >> (2*j + 0)) << 4; + const int x1_0h = ((qh & (1 << (2*j + 1))) >> (2*j + 1)) << 4; + + const int x0_0 = ((v0 & 0x0F) | x0_0h) - 16; + const int x1_0 = ((v0 >> 4) | x1_0h) - 16; + + const int y0_0 = y0[2*j + 0]; + const int y1_0 = y0[2*j + 1]; + + sxy += x0_0*y0_0 + x1_0*y1_0; + } + + sumf += (d*sxy)*y[i].d; + } + *s = sumf; +#endif +} + +static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int nb = n / QK8_1; + + assert(n % QK8_1 == 0); + assert(nb % 2 == 0); + assert(QK8_1 == QK5_1); + + const block_q5_1 * restrict x = vx; + const block_q8_1 * restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv = vdupq_n_f32(0.0f); + + float summs = 0.0f; + + uint64_t tmp[4]; + + for (int i = 0; i < nb; ++i) { + const block_q5_1 * restrict x0 = &x[i]; + const block_q8_1 * restrict y0 = &y[i]; + + summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1); + + // extract the 5th bit + uint32_t qh; + memcpy(&qh, x0->qh, sizeof(qh)); + + tmp[0] = table_b2b_u[(qh >> 0) & 0xFF]; + tmp[1] = table_b2b_u[(qh >> 8) & 0xFF]; + tmp[2] = table_b2b_u[(qh >> 16) & 0xFF]; + tmp[3] = table_b2b_u[(qh >> 24) ]; + + const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0)); + const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2)); + + const uint8x16_t v0 = vld1q_u8(x0->qs); + + // 4-bit -> 8-bit + const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, vdupq_n_u8(0x0F))); + const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4)); + + // interleave + const int8x16_t v0lz = vzip1q_s8(v0l, v0h); + const int8x16_t v0hz = vzip2q_s8(v0l, v0h); + + // add + const int8x16_t v0lf = vorrq_s8(v0lz, qhl); + const int8x16_t v0hf = vorrq_s8(v0hz, qhh); + + // load y + const int8x16_t v1l = vld1q_s8(y0->qs); + const int8x16_t v1h = vld1q_s8(y0->qs + 16); + + const float x0d = GGML_FP16_TO_FP32(x0->d); + +#if defined(__ARM_FEATURE_DOTPROD) + sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), v0lf, v1l), + vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d); +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h)); + + const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); + const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); + + sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d); +#endif + } + + *s = vaddvq_f32(sumv) + summs; +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + float summs = 0.0f; + + // Main loop + for (int i = 0; i < nb; i++) { + const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)); + + summs += GGML_FP16_TO_FP32(x[i].m) * (y[i].s0 + y[i].s1); + + __m256i bx = bytes_from_nibbles_32(x[i].qs); + const __m256i bxhi = _mm256_set_epi64x( + table_b2b_u[x[i].qh[3]], table_b2b_u[x[i].qh[2]], + table_b2b_u[x[i].qh[1]], 
table_b2b_u[x[i].qh[0]]); + bx = _mm256_or_si256(bx, bxhi); + + const __m256 dy = _mm256_broadcast_ss(&y[i].d); + const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_i8_pairs_float(bx, by); + + acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); + } + + *s = hsum_float_8(acc) + summs; +#else + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + const uint8_t * restrict x0 = x[i].qs; + const int8_t * restrict y0 = y[i].qs; + + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + const float d = GGML_FP16_TO_FP32(x[i].d); + const float m = GGML_FP16_TO_FP32(x[i].m); + + int sxy = 0; + + for (int j = 0; j < QK8_1/2; j++) { + const uint8_t v0 = x0[j]; + + const int x0_0h = ((qh & (1 << (2*j + 0))) >> (2*j + 0)) << 4; + const int x1_0h = ((qh & (1 << (2*j + 1))) >> (2*j + 1)) << 4; + + const int x0_0 = (v0 & 0x0F) | x0_0h; + const int x1_0 = (v0 >> 4) | x1_0h; + + const int y0_0 = y0[2*j + 0]; + const int y1_0 = y0[2*j + 1]; + + sxy += x0_0*y0_0 + x1_0*y1_0; + } + + sumf += (d*sxy)*y[i].d + m*(y[i].s0 + y[i].s1); + } + + *s = sumf; +#endif +} + static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const int nb = n / QK8_0; @@ -3409,13 +3894,15 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_1] = QK4_1, [GGML_TYPE_Q4_2] = QK4_2, [GGML_TYPE_Q4_3] = QK4_3, + [GGML_TYPE_Q5_0] = QK5_0, + [GGML_TYPE_Q5_1] = QK5_1, [GGML_TYPE_Q8_0] = QK8_0, [GGML_TYPE_Q8_1] = QK8_1, [GGML_TYPE_I8] = 1, [GGML_TYPE_I16] = 1, [GGML_TYPE_I32] = 1, }; -static_assert(GGML_TYPE_COUNT == 11, "GGML_BLCK_SIZE is outdated"); +static_assert(GGML_TYPE_COUNT == 13, "GGML_BLCK_SIZE is outdated"); static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_F32] = sizeof(float), @@ -3424,13 +3911,15 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_1] = sizeof(block_q4_1), [GGML_TYPE_Q4_2] = sizeof(block_q4_2), [GGML_TYPE_Q4_3] = sizeof(block_q4_3), + [GGML_TYPE_Q5_0] = sizeof(block_q5_0), + [GGML_TYPE_Q5_1] = sizeof(block_q5_1), [GGML_TYPE_Q8_0] = sizeof(block_q8_0), [GGML_TYPE_Q8_1] = sizeof(block_q8_1), [GGML_TYPE_I8] = sizeof(int8_t), [GGML_TYPE_I16] = sizeof(int16_t), [GGML_TYPE_I32] = sizeof(int32_t), }; -static_assert(GGML_TYPE_COUNT == 11, "GGML_TYPE_SIZE is outdated"); +static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_SIZE is outdated"); static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { @@ -3440,13 +3929,15 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_1] = "q4_1", [GGML_TYPE_Q4_2] = "q4_2", [GGML_TYPE_Q4_3] = "q4_3", + [GGML_TYPE_Q5_0] = "q5_0", + [GGML_TYPE_Q5_1] = "q5_1", [GGML_TYPE_Q8_0] = "q8_0", [GGML_TYPE_Q8_1] = "q8_1", [GGML_TYPE_I8] = "i8", [GGML_TYPE_I16] = "i16", [GGML_TYPE_I32] = "i32", }; -static_assert(GGML_TYPE_COUNT == 11, "GGML_TYPE_NAME is outdated"); +static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_NAME is outdated"); static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { [GGML_TYPE_F32] = false, @@ -3455,13 +3946,15 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_1] = true, [GGML_TYPE_Q4_2] = true, [GGML_TYPE_Q4_3] = true, + [GGML_TYPE_Q5_0] = true, + [GGML_TYPE_Q5_1] = true, [GGML_TYPE_Q8_0] = true, [GGML_TYPE_Q8_1] = true, [GGML_TYPE_I8] = false, [GGML_TYPE_I16] = false, [GGML_TYPE_I32] = false, }; -static_assert(GGML_TYPE_COUNT == 11, "GGML_IS_QUANTIZED is outdated"); +static_assert(GGML_TYPE_COUNT == 13, "GGML_IS_QUANTIZED is outdated"); static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { 
"NONE", @@ -6673,6 +7166,8 @@ static void ggml_compute_forward_add( case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: { ggml_compute_forward_add_q_f32(params, src0, src1, dst); @@ -8161,6 +8656,12 @@ static void ggml_compute_forward_mul_mat_q_f32( else if (type == GGML_TYPE_Q4_3) { dequantize_row_q_cuda = dequantize_row_q4_3_cuda; } + else if (type == GGML_TYPE_Q5_0) { + dequantize_row_q_cuda = dequantize_row_q5_0_cuda; + } + else if (type == GGML_TYPE_Q5_1) { + dequantize_row_q_cuda = dequantize_row_q5_1_cuda; + } else if (type == GGML_TYPE_Q8_0) { dequantize_row_q_cuda = dequantize_row_q8_0_cuda; } @@ -8319,6 +8820,8 @@ static void ggml_compute_forward_mul_mat( case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: { @@ -8549,6 +9052,8 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: { @@ -12261,7 +12766,7 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * for (int i = 0; i < nb; i++) { for (int l = 0; l < QK4_0; l += 2) { - const uint8_t vi0 = y[i].qs[l/2] & 0xF; + const uint8_t vi0 = y[i].qs[l/2] & 0x0F; const uint8_t vi1 = y[i].qs[l/2] >> 4; hist[vi0]++; @@ -12284,7 +12789,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * for (int i = 0; i < nb; i++) { for (int l = 0; l < QK4_1; l += 2) { - const uint8_t vi0 = y[i].qs[l/2] & 0xF; + const uint8_t vi0 = y[i].qs[l/2] & 0x0F; const uint8_t vi1 = y[i].qs[l/2] >> 4; hist[vi0]++; @@ -12307,7 +12812,7 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * for (int i = 0; i < nb; i++) { for (int l = 0; l < QK4_2; l += 2) { - const uint8_t vi0 = y[i].qs[l/2] & 0xF; + const uint8_t vi0 = y[i].qs[l/2] & 0x0F; const uint8_t vi1 = y[i].qs[l/2] >> 4; hist[vi0]++; @@ -12330,7 +12835,7 @@ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * for (int i = 0; i < nb; i++) { for (int l = 0; l < QK4_3; l += 2) { - const uint8_t vi0 = y[i].qs[l/2] & 0xF; + const uint8_t vi0 = y[i].qs[l/2] & 0x0F; const uint8_t vi1 = y[i].qs[l/2] >> 4; hist[vi0]++; @@ -12342,6 +12847,66 @@ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * return (n/QK4_3*sizeof(block_q4_3)); } +size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK5_0 == 0); + const int nb = k / QK5_0; + + for (int j = 0; j < n; j += k) { + block_q5_0 * restrict y = (block_q5_0 *)dst + j/QK5_0; + + quantize_row_q5_0_reference(src + j, y, k); + + for (int i = 0; i < nb; i++) { + uint32_t qh; + memcpy(&qh, &y[i].qh, sizeof(qh)); + + for (int l = 0; l < QK5_0; l += 2) { + const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; + + // cast to 16 bins + const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2; + const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2; + + hist[vi0]++; + hist[vi1]++; + } + } + } + + return (n/QK5_0*sizeof(block_q5_0)); +} + +size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK5_1 == 0); + const int nb = k / QK5_1; + + for (int j = 0; j < n; j += k) { + block_q5_1 * restrict y = (block_q5_1 *)dst + j/QK5_1; + + quantize_row_q5_1_reference(src + j, y, 
k); + + for (int i = 0; i < nb; i++) { + uint32_t qh; + memcpy(&qh, &y[i].qh, sizeof(qh)); + + for (int l = 0; l < QK5_1; l += 2) { + const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; + + // cast to 16 bins + const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2; + const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2; + + hist[vi0]++; + hist[vi1]++; + } + } + } + + return (n/QK5_1*sizeof(block_q5_1)); +} + size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) { assert(k % QK8_0 == 0); const int nb = k / QK8_0; @@ -12390,6 +12955,18 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i block_q4_3 * block = (block_q4_3*)dst + start / QK4_3; result = ggml_quantize_q4_3(src + start, block, n, n, hist); } break; + case GGML_TYPE_Q5_0: + { + GGML_ASSERT(start % QK5_0 == 0); + block_q5_0 * block = (block_q5_0*)dst + start / QK5_0; + result = ggml_quantize_q5_0(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q5_1: + { + GGML_ASSERT(start % QK5_1 == 0); + block_q5_1 * block = (block_q5_1*)dst + start / QK5_1; + result = ggml_quantize_q5_1(src + start, block, n, n, hist); + } break; case GGML_TYPE_Q8_0: { GGML_ASSERT(start % QK8_0 == 0); diff --git a/ggml.h b/ggml.h index 8300a0c62..d9d3d214e 100644 --- a/ggml.h +++ b/ggml.h @@ -222,8 +222,10 @@ extern "C" { GGML_TYPE_Q4_1 = 3, GGML_TYPE_Q4_2 = 4, GGML_TYPE_Q4_3 = 5, - GGML_TYPE_Q8_0 = 6, - GGML_TYPE_Q8_1 = 7, + GGML_TYPE_Q5_0 = 6, + GGML_TYPE_Q5_1 = 7, + GGML_TYPE_Q8_0 = 8, + GGML_TYPE_Q8_1 = 9, GGML_TYPE_I8, GGML_TYPE_I16, GGML_TYPE_I32, @@ -833,6 +835,8 @@ extern "C" { GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist); diff --git a/llama.cpp b/llama.cpp index 8334553a5..28a74b514 100644 --- a/llama.cpp +++ b/llama.cpp @@ -484,6 +484,8 @@ struct llama_file_loader { case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: break; default: { @@ -559,6 +561,8 @@ struct llama_file_saver { case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: break; default: LLAMA_ASSERT(false); @@ -850,6 +854,8 @@ static const char *llama_ftype_name(enum llama_ftype ftype) { return "mostly Q4_1, some F16"; case LLAMA_FTYPE_MOSTLY_Q4_2: return "mostly Q4_2"; case LLAMA_FTYPE_MOSTLY_Q4_3: return "mostly Q4_3"; + case LLAMA_FTYPE_MOSTLY_Q5_0: return "mostly Q5_0"; + case LLAMA_FTYPE_MOSTLY_Q5_1: return "mostly Q5_1"; case LLAMA_FTYPE_MOSTLY_Q8_0: return "mostly Q8_0"; default: return "unknown, may not work"; } @@ -1588,6 +1594,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q4_1: quantized_type = GGML_TYPE_Q4_1; break; case LLAMA_FTYPE_MOSTLY_Q4_2: quantized_type = 
GGML_TYPE_Q4_2; break; case LLAMA_FTYPE_MOSTLY_Q4_3: quantized_type = GGML_TYPE_Q4_3; break; + case LLAMA_FTYPE_MOSTLY_Q5_0: quantized_type = GGML_TYPE_Q5_0; break; + case LLAMA_FTYPE_MOSTLY_Q5_1: quantized_type = GGML_TYPE_Q5_1; break; case LLAMA_FTYPE_MOSTLY_Q8_0: quantized_type = GGML_TYPE_Q8_0; break; default: throw format("invalid output file type %d\n", ftype); }; diff --git a/llama.h b/llama.h index 24c48cce6..17dac0689 100644 --- a/llama.h +++ b/llama.h @@ -75,6 +75,8 @@ extern "C" { LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors }; LLAMA_API struct llama_context_params llama_context_default_params(); From f9be42add0bf3ce61814b7ede0e6d0dda9ff22c6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 26 Apr 2023 23:24:42 +0300 Subject: [PATCH 5/6] readme : add quantization info --- README.md | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index ddbd4c8b1..be0e49e47 100644 --- a/README.md +++ b/README.md @@ -7,31 +7,27 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++ -**Warnings** - -- `Q4_2` and `Q4_3` are still in development. Do not expect any kind of backward compatibility until they are finalized - **Hot topics:** +- [New quantization methods](https://github.com/ggerganov/llama.cpp#quantization) - [Added LoRA support](https://github.com/ggerganov/llama.cpp/pull/820) - [Add GPU support to ggml](https://github.com/ggerganov/llama.cpp/discussions/915) - [Roadmap Apr 2023](https://github.com/ggerganov/llama.cpp/discussions/784) ## Description -The main goal of llama.cpp is to run the llama model using 4-bit quantization on a MacBook. +The main goal of `llama.cpp` is to run the LLaMA model using 4-bit integer quantization on a MacBook - Plain C/C++ implementation without dependencies - Apple silicon first-class citizen - optimized via ARM NEON and Accelerate framework - AVX2 support for x86 architectures - Mixed F16 / F32 precision -- 4-bit quantization support +- 4-bit integer quantization support - Runs on the CPU -This was [hacked in an evening](https://github.com/ggerganov/llama.cpp/issues/33#issuecomment-1465108022) - I have no idea if it works correctly. -Please do not make conclusions about the models based on the results from this implementation. -For all I know, it can be completely wrong. This project is for educational purposes. -New features will probably be added mostly through community contributions. +The original implementation of `llama.cpp` was [hacked in an evening](https://github.com/ggerganov/llama.cpp/issues/33#issuecomment-1465108022). +Since then, the project has improved significantly thanks to many contributions. This project is for educational purposes and serves +as the main playground for developing new features for the [ggml](https://github.com/ggerganov/ggml) library. **Supported platforms:** @@ -294,6 +290,24 @@ As the models are currently fully loaded into memory, you will need adequate dis | 30B | 60 GB | 19.5 GB | | 65B | 120 GB | 38.5 GB | +### Quantization + +Several quantization methods are supported. They differ in the resulting model disk size and inference speed. 
+ +Model | F16 | Q4_0 | Q4_1 | Q4_2 | Q4_3 | Q5_0 | Q5_1 | Q8_0 +-- | -- | -- | -- | -- | -- | -- | -- | -- +7B (ppl) | 5.9565 | 6.2103 | 6.1286 | 6.1698 | 6.0617 | 6.0139 | 5.9934 | 5.9571 +7B (size) | 13.0G | 4.0G | 4.8G | 4.0G | 4.8G | 4.4G | 4.8G | 7.1G +7B (ms/tok @ 4th) | 128 | 56 | 61 | 84 | 91 | 91 | 95 | 75 +7B (ms/tok @ 8th) | 128 | 47 | 55 | 48 | 53 | 53 | 59 | 75 +7B (bpw) | 16.0 | 5.0 | 6.0 | 5.0 | 6.0 | 5.5 | 6.0 | 9.0 +-- | -- | -- | -- | -- | -- | -- | -- | -- +13B (ppl) | 5.2455 | 5.3748 | 5.3471 | 5.3433 | 5.3234 | 5.2768 | 5.2582 | 5.2458 +13B (size) | 25.0G | 7.6G | 9.1G | 7.6G | 9.1G | 8.4G | 9.1G | 14G +13B (ms/tok @ 4th) | 239 | 104 | 113 | 160 | 175 | 176 | 185 | 141 +13B (ms/tok @ 8th) | 240 | 85 | 99 | 97 | 114 | 108 | 117 | 147 +13B (bpw) | 16.0 | 5.0 | 6.0 | 5.0 | 6.0 | 5.5 | 6.0 | 9.0 + ### Interactive mode If you want a more ChatGPT-like experience, you can run in interactive mode by passing `-i` as a parameter. From 0b2da20538d01926b77ea237dd1c930c4d20b686 Mon Sep 17 00:00:00 2001 From: Stephan Walter Date: Wed, 26 Apr 2023 20:26:42 +0000 Subject: [PATCH 6/6] ggml : slightly faster AVX2 implementation for Q5 (#1197) --- ggml.c | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/ggml.c b/ggml.c index 03b4bd439..3422a9448 100644 --- a/ggml.c +++ b/ggml.c @@ -328,6 +328,7 @@ static ggml_fp16_t table_exp_f16[1 << 16]; // precomputed f32 table for f16 (256 KB) static float table_f32_f16[1 << 16]; +#if defined(__ARM_NEON) #define B1(c,s,n) 0x ## n ## c , 0x ## n ## s #define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s) #define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s) @@ -339,7 +340,7 @@ static float table_f32_f16[1 << 16]; // precomputed tables for expanding 8bits to 8 bytes (shl 4) static const uint64_t table_b2b_u[1 << 8] = { B8(00, 10) }; -static const uint64_t table_b2b_i[1 << 8] = { B8(F0, 00) }; +#endif // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32, // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON. @@ -490,6 +491,19 @@ static inline int hsum_i32_4(const __m128i a) { } #if __AVX2__ || __AVX512F__ +// spread 32 bits to 32 bytes { 0x00, 0xFF } +static inline __m256i bytes_from_bits_32(const uint8_t * x) { + uint32_t x32; + memcpy(&x32, x, sizeof(uint32_t)); + const __m256i shuf_mask = _mm256_set_epi64x( + 0x0303030303030303, 0x0202020202020202, + 0x0101010101010101, 0x0000000000000000); + __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask); + const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe); + bytes = _mm256_or_si256(bytes, bit_mask); + return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1)); +} + // Unpack 32 4-bit fields into 32 bytes // The output vector contains 32 bytes, each one in [ 0 .. 
15 ] interval static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) @@ -3367,9 +3381,8 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d)); __m256i bx = bytes_from_nibbles_32(x[i].qs); - const __m256i bxhi = _mm256_set_epi64x( - table_b2b_i[x[i].qh[3]], table_b2b_i[x[i].qh[2]], - table_b2b_i[x[i].qh[1]], table_b2b_i[x[i].qh[0]]); + __m256i bxhi = bytes_from_bits_32(x[i].qh); + bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0)); bx = _mm256_or_si256(bx, bxhi); __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); @@ -3501,9 +3514,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * summs += GGML_FP16_TO_FP32(x[i].m) * (y[i].s0 + y[i].s1); __m256i bx = bytes_from_nibbles_32(x[i].qs); - const __m256i bxhi = _mm256_set_epi64x( - table_b2b_u[x[i].qh[3]], table_b2b_u[x[i].qh[2]], - table_b2b_u[x[i].qh[1]], table_b2b_u[x[i].qh[0]]); + __m256i bxhi = bytes_from_bits_32(x[i].qh); + bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10)); bx = _mm256_or_si256(bx, bxhi); const __m256 dy = _mm256_broadcast_ss(&y[i].d);