From c8e0d7efeb7634ecc2e9832e879ab9fca4510e71 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Sun, 18 Feb 2024 00:17:07 +0000 Subject: [PATCH 01/11] flake.lock: Update MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Flake lock file updates: • Updated input 'nixpkgs': 'github:NixOS/nixpkgs/f8e2ebd66d097614d51a56a755450d4ae1632df1' (2024-02-07) → 'github:NixOS/nixpkgs/5863c27340ba4de8f83e7e3c023b9599c3cb3c80' (2024-02-16) --- flake.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flake.lock b/flake.lock index 239d0686c..47d6448b5 100644 --- a/flake.lock +++ b/flake.lock @@ -20,11 +20,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1707268954, - "narHash": "sha256-2en1kvde3cJVc3ZnTy8QeD2oKcseLFjYPLKhIGDanQ0=", + "lastModified": 1708118438, + "narHash": "sha256-kk9/0nuVgA220FcqH/D2xaN6uGyHp/zoxPNUmPCMmEE=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "f8e2ebd66d097614d51a56a755450d4ae1632df1", + "rev": "5863c27340ba4de8f83e7e3c023b9599c3cb3c80", "type": "github" }, "original": { From bd2d4e393b2b7d2a1b2e201058e26017c9728ead Mon Sep 17 00:00:00 2001 From: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Sun, 18 Feb 2024 18:16:55 +0200 Subject: [PATCH 02/11] 1.5 bit quantization (#5453) * iq1_s: WIP basics * iq1_s: CUDA is working * iq1_s: scalar CPU dot product * iq1_s: WIP AVX2 dot product - something is not right * Fix tests * Fix shadow warnings * Fix after merge with latest master * iq1_s: AVX2 finally works * iq1_s: ARM_NEON dot product. Works, but not very fast * iq1_s: better grid * iq1_s: use IQ2_XXS for attn_output At a cost of 0.04 extra bpw this gives a big improvement in PPL. * iq1_s: Metal basics Dequantize works, but not dot product * iq1_s: Metal works, but quite slow As usual, Apple Silicon does not like the code I write. * iq1_s: Tests * iq1_s: slightly faster dot product --------- Co-authored-by: Iwan Kawrakow --- examples/quantize/quantize.cpp | 6 +- ggml-backend.c | 2 +- ggml-cuda.cu | 224 ++++++++++- ggml-metal.m | 29 +- ggml-metal.metal | 337 +++++++++++++++++ ggml-quants.c | 657 +++++++++++++++++++++++++++++++-- ggml-quants.h | 14 +- ggml.c | 44 ++- ggml.h | 2 + llama.cpp | 16 +- llama.h | 1 + tests/test-backend-ops.cpp | 2 +- 12 files changed, 1286 insertions(+), 48 deletions(-) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 4a5c504e3..ea7ba50c9 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -23,6 +23,7 @@ static const std::vector QUANT_OPTIONS = { { "Q5_1", LLAMA_FTYPE_MOSTLY_Q5_1, " 4.70G, +0.0349 ppl @ LLaMA-v1-7B", }, { "IQ2_XXS",LLAMA_FTYPE_MOSTLY_IQ2_XXS," 2.06 bpw quantization", }, { "IQ2_XS", LLAMA_FTYPE_MOSTLY_IQ2_XS, " 2.31 bpw quantization", }, + { "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", }, { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", }, { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", }, { "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", }, @@ -287,9 +288,10 @@ int main(int argc, char ** argv) { } } - if ((params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || params.ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S) && imatrix_data.empty()) { + if ((params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || + params.ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S || params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) && imatrix_data.empty()) { fprintf(stderr, "\n===============================================================================================\n"); - fprintf(stderr, "Please do not use IQ2_XXS, IQ2_XS or Q2_K_S quantization without an importance matrix\n"); + fprintf(stderr, "Please do not use IQ1_S, IQ2_XXS, IQ2_XS or Q2_K_S quantization without an importance matrix\n"); fprintf(stderr, "===============================================================================================\n\n\n"); return 1; } diff --git a/ggml-backend.c b/ggml-backend.c index 66e8c293a..5076d9e5e 100644 --- a/ggml-backend.c +++ b/ggml-backend.c @@ -756,7 +756,7 @@ GGML_CALL static bool ggml_backend_cpu_graph_compute(ggml_backend_t backend, str GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { switch (op->op) { case GGML_OP_CPY: - return op->type != GGML_TYPE_IQ2_XXS && op->type != GGML_TYPE_IQ2_XS; // missing type_traits.from_float + return op->type != GGML_TYPE_IQ2_XXS && op->type != GGML_TYPE_IQ2_XS && op->type != GGML_TYPE_IQ1_S; // missing type_traits.from_float case GGML_OP_MUL_MAT: return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type; default: diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 5fd8a87e4..933ebbc4e 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -517,6 +517,15 @@ typedef struct { } block_iq3_xxs; static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding"); +#define QR1_S 8 +#define QI1_S (QK_K / (4*QR1_S)) +typedef struct { + half d; + uint8_t qs[QK_K/8]; + uint8_t scales[QK_K/16]; +} block_iq1_s; +static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding"); + #define WARP_SIZE 32 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses @@ -1681,6 +1690,137 @@ static const __device__ uint32_t iq3xxs_grid[256] = { 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, }; +static const __device__ uint64_t iq1s_grid[512] = { + 0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000, + 0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01, + 0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100, + 0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00, + 0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101, + 0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100, + 0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00, + 0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff, + 0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000, + 0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000, + 0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001, + 0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff, + 0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01, + 0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001, + 0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00, + 0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001, + 0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100, + 0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000, + 0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000, + 0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000, + 0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff, + 0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff, + 0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01, + 0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100, + 0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff, + 0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000, + 0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101, + 0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff, + 0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff, + 0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001, + 0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01, + 0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101, + 0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100, + 0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00, + 0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001, + 0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff, + 0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000, + 0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000, + 0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100, + 0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100, + 0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01, + 0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff, + 0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101, + 0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000, + 0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff, + 0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000, + 0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff, + 0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00, + 0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101, + 0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000, + 0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000, + 0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000, + 0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100, + 0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000, + 0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001, + 0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff, + 0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000, + 0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000, + 0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000, + 0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000, + 0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff, + 0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000, + 0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001, + 0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01, + 0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100, + 0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000, + 0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00, + 0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100, + 0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000, + 0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001, + 0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00, + 0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff, + 0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100, + 0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff, + 0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000, + 0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff, + 0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff, + 0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00, + 0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001, + 0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001, + 0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01, + 0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000, + 0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101, + 0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00, + 0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100, + 0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101, + 0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101, + 0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000, + 0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff, + 0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff, + 0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101, + 0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff, + 0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101, + 0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001, + 0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff, + 0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff, + 0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01, + 0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff, + 0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100, + 0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001, + 0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00, + 0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff, + 0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff, + 0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000, + 0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000, + 0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101, + 0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001, + 0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000, + 0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101, + 0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000, + 0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001, + 0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000, + 0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100, + 0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000, + 0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000, + 0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100, + 0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff, + 0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff, + 0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00, + 0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101, + 0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000, + 0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00, + 0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000, + 0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff, + 0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101, + 0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff, + 0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00, + 0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff, +}; + static const __device__ uint8_t ksigns_iq2xs[128] = { 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15, 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159, @@ -1823,6 +1963,29 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds } +template +static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_iq1_s * x = (const block_iq1_s *) vx; + + const int tid = threadIdx.x; +#if QK_K == 256 + const int il = tid/8; // 0...3 + const int ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const int i8 = 4*ib+il; + uint8_t h = x[i].scales[i8/2] >> 4*(i8%2); + const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[i8] | ((h & 8) << 5))); + const float d = (float)x[i].d * (2*(h & 7) + 1); + for (int j = 0; j < 8; ++j) y[j] = d * grid[j]; +#else + assert(false); +#endif + +} + + static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); @@ -4522,6 +4685,49 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1( #endif } +static __device__ __forceinline__ float vec_dot_iq1_s_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { +#if QK_K == 256 + const block_iq1_s * bq1 = (const block_iq1_s *) vbq; + + const int ib32 = iqs; + int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0; + const uint8_t h1 = bq1->scales[2*ib32+0]; + const uint8_t h2 = bq1->scales[2*ib32+1]; +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + const int * q8 = (const int *)bq8_1[ib32].qs; + const int * grid1 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5))); + const int * grid2 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1))); + const int * grid3 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5))); + const int * grid4 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1))); + for (int j = 0; j < 2; ++j) { + sumi1 = __dp4a(q8[j+0], grid1[j], sumi1); + sumi2 = __dp4a(q8[j+2], grid2[j], sumi2); + sumi3 = __dp4a(q8[j+4], grid3[j], sumi3); + sumi4 = __dp4a(q8[j+6], grid4[j], sumi4); + } +#else + const int8_t * q8 = bq8_1[ib32].qs; + const int8_t * grid1 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5))); + const int8_t * grid2 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1))); + const int8_t * grid3 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5))); + const int8_t * grid4 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1))); + for (int j = 0; j < 8; ++j) { + sumi1 += q8[j+ 0] * grid1[j]; + sumi2 += q8[j+ 8] * grid2[j]; + sumi3 += q8[j+16] * grid3[j]; + sumi4 += q8[j+24] * grid4[j]; + } +#endif + const float d = (float)bq1->d * __low2float(bq8_1[ib32].ds); + return d * (sumi1 * (2*(h1 & 7) + 1) + sumi2 * (2*((h1 >> 4) & 7) + 1) + + sumi3 * (2*(h2 & 7) + 1) + sumi4 * (2*((h2 >> 4) & 7) + 1)); +#else + assert(false); + return 0.f; +#endif +} + template static __device__ __forceinline__ void mul_mat_q( @@ -6561,6 +6767,12 @@ static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k, dequantize_block_iq3_xxs<<>>(vx, y); } +template +static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_iq1_s<<>>(vx, y); +} + template static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; @@ -6600,6 +6812,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq2_xs_cuda; case GGML_TYPE_IQ3_XXS: return dequantize_row_iq3_xxs_cuda; + case GGML_TYPE_IQ1_S: + return dequantize_row_iq1_s_cuda; case GGML_TYPE_F32: return convert_unary_cuda; default: @@ -6635,6 +6849,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq2_xs_cuda; case GGML_TYPE_IQ3_XXS: return dequantize_row_iq3_xxs_cuda; + case GGML_TYPE_IQ1_S: + return dequantize_row_iq1_s_cuda; case GGML_TYPE_F16: return convert_unary_cuda; default: @@ -8378,6 +8594,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array= CC_RDNA2 ? 128 : 64; default: GGML_ASSERT(false); @@ -8401,6 +8618,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array= CC_VOLTA ? 128 : 64; case GGML_TYPE_Q6_K: return 64; @@ -8498,6 +8716,10 @@ static void ggml_cuda_op_mul_mat_vec_q( mul_mat_vec_q_cuda (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; + case GGML_TYPE_IQ1_S: + mul_mat_vec_q_cuda + (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + break; default: GGML_ASSERT(false); break; @@ -11214,7 +11436,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons return false; } ggml_type a_type = a->type; - if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS) { + if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ1_S) { if (b->ne[1] == 1 && ggml_nrows(b) > 1) { return false; } diff --git a/ggml-metal.m b/ggml-metal.m index c0848a293..f3c1fff8f 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -61,6 +61,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, GGML_METAL_KERNEL_TYPE_RMS_NORM, GGML_METAL_KERNEL_TYPE_GROUP_NORM, @@ -83,6 +84,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, @@ -101,6 +103,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, @@ -116,6 +119,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, @@ -131,6 +135,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, GGML_METAL_KERNEL_TYPE_ROPE_F32, GGML_METAL_KERNEL_TYPE_ROPE_F16, GGML_METAL_KERNEL_TYPE_ALIBI_F32, @@ -433,6 +438,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); @@ -455,6 +461,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); @@ -473,6 +480,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); @@ -488,6 +496,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); @@ -503,6 +512,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); @@ -1318,6 +1328,7 @@ static bool ggml_metal_graph_compute( case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break; case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break; case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break; + case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break; default: GGML_ASSERT(false && "MUL MAT-MAT not implemented"); } @@ -1452,6 +1463,12 @@ static bool ggml_metal_graph_compute( nth1 = 16; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline; } break; + case GGML_TYPE_IQ1_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline; + } break; default: { GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t); @@ -1486,7 +1503,7 @@ static bool ggml_metal_graph_compute( if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || - src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) { + src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S) { // || src0t == GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { @@ -1592,6 +1609,7 @@ static bool ggml_metal_graph_compute( case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break; case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break; case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break; + case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break; default: GGML_ASSERT(false && "MUL_MAT_ID not implemented"); } @@ -1729,6 +1747,12 @@ static bool ggml_metal_graph_compute( nth1 = 16; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline; } break; + case GGML_TYPE_IQ1_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline; + } break; default: { GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t); @@ -1779,7 +1803,7 @@ static bool ggml_metal_graph_compute( if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 || src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 || - src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) { + src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S) { // || src2t == GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) { @@ -1833,6 +1857,7 @@ static bool ggml_metal_graph_compute( case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break; case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break; case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break; + case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break; case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break; default: GGML_ASSERT(false && "not implemented"); } diff --git a/ggml-metal.metal b/ggml-metal.metal index 09ebcc9e3..a00962111 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2525,6 +2525,13 @@ typedef struct { } block_iq3_xxs; // 98 bytes / block for QK_K = 256, so 3.0625 bpw +typedef struct { + half d; + uint8_t qs[QK_K/8]; + uint8_t scales[QK_K/16]; +} block_iq1_s; + + //====================================== dot products ========================= void kernel_mul_mv_q2_K_f32_impl( @@ -3782,6 +3789,137 @@ constexpr constant static uint32_t iq3xxs_grid[256] = { 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, }; +#define NGRID_IQ1S 512 +constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = { + 0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000, + 0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01, + 0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100, + 0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00, + 0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101, + 0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100, + 0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00, + 0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff, + 0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000, + 0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000, + 0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001, + 0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff, + 0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01, + 0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001, + 0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00, + 0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001, + 0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100, + 0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000, + 0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000, + 0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000, + 0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff, + 0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff, + 0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01, + 0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100, + 0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff, + 0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000, + 0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101, + 0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff, + 0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff, + 0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001, + 0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01, + 0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101, + 0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100, + 0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00, + 0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001, + 0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff, + 0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000, + 0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000, + 0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100, + 0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100, + 0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01, + 0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff, + 0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101, + 0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000, + 0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff, + 0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000, + 0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff, + 0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00, + 0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101, + 0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000, + 0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000, + 0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000, + 0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100, + 0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000, + 0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001, + 0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff, + 0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000, + 0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000, + 0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000, + 0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000, + 0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff, + 0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000, + 0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001, + 0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01, + 0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100, + 0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000, + 0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00, + 0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100, + 0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000, + 0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001, + 0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00, + 0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff, + 0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100, + 0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff, + 0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000, + 0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff, + 0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff, + 0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00, + 0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001, + 0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001, + 0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01, + 0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000, + 0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101, + 0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00, + 0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100, + 0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101, + 0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101, + 0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000, + 0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff, + 0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff, + 0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101, + 0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff, + 0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101, + 0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001, + 0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff, + 0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff, + 0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01, + 0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff, + 0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100, + 0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001, + 0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00, + 0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff, + 0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff, + 0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000, + 0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000, + 0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101, + 0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001, + 0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000, + 0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101, + 0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000, + 0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001, + 0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000, + 0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100, + 0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000, + 0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000, + 0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100, + 0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff, + 0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff, + 0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00, + 0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101, + 0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000, + 0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00, + 0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000, + 0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff, + 0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101, + 0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff, + 0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00, + 0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff, +}; constexpr constant static uint8_t ksigns_iq2xs[128] = { 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15, @@ -4208,6 +4346,123 @@ kernel void kernel_mul_mv_iq3_xxs_f32( kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } +void kernel_mul_mv_iq1_s_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[16]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + +#if QK_K == 256 + const int ix = tiisg/2; + const int il = tiisg%2; + + device const float * y4 = y + 32 * ix + 16 * il; + + for (int ib32 = ix; ib32 < nb32; ib32 += 16) { + + for (int i = 0; i < 16; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq1_s * xr = x + ibl; + device const uint8_t * qs = xr->qs + 4 * ib + 2 * il; + device const uint8_t * sc = xr->scales + 2 * ib + il; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5))); + constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1))); + + float2 sum = {0}; + for (int j = 0; j < 8; ++j) { + sum[0] += yl[j+ 0] * grid1[j]; + sum[1] += yl[j+ 8] * grid2[j]; + } + sumf[row] += (float)dh[0] * (sum[0] * (2*(sc[0] & 7) + 1) + sum[1] * (2*((sc[0] >> 4) & 7) + 1)); + + dh += nb*sizeof(block_iq1_s)/2; + qs += nb*sizeof(block_iq1_s); + sc += nb*sizeof(block_iq1_s); + } + + y4 += 16 * 32; + } +#else + // TODO +#endif + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +[[host_name("kernel_mul_mv_iq1_s_f32")]] +kernel void kernel_mul_mv_iq1_s_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); +} + //============================= templates and their specializations ============================= @@ -4553,6 +4808,22 @@ void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x } } +template +void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + device const uint8_t * qs = xb->qs + 2*il; + device const uint8_t * sc = xb->scales + il; + const float dl1 = d * (2*(sc[0] & 7) + 1); + const float dl2 = d * (2*((sc[0] >> 4) & 7) + 1); + constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5))); + constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1))); + for (int i = 0; i < 8; ++i) { + reg[i/4+0][i%4] = dl1 * grid1[i]; + reg[i/4+2][i%4] = dl2 * grid2[i]; + } +} + template kernel void kernel_get_rows( device const void * src0, @@ -5095,6 +5366,7 @@ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows; // // matrix-matrix multiplication @@ -5134,6 +5406,7 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm; // // indirect matrix-matrix multiplication @@ -5185,6 +5458,7 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; // // matrix-vector multiplication @@ -6152,3 +6426,66 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32( tiisg, sgitg); } + +[[host_name("kernel_mul_mv_id_iq1_s_f32")]] +kernel void kernel_mul_mv_id_iq1_s_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_iq1_s_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} diff --git a/ggml-quants.c b/ggml-quants.c index f44377f45..48f5294e1 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -3480,6 +3480,139 @@ static const uint32_t iq3xxs_grid[256] = { 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, }; +#define NGRID_IQ2XXS 512 +static const uint64_t iq1s_grid[NGRID_IQ2XXS] = { + 0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000, + 0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01, + 0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100, + 0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00, + 0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101, + 0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100, + 0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00, + 0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff, + 0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000, + 0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000, + 0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001, + 0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff, + 0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01, + 0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001, + 0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00, + 0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001, + 0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100, + 0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000, + 0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000, + 0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000, + 0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff, + 0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff, + 0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01, + 0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100, + 0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff, + 0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000, + 0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101, + 0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff, + 0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff, + 0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001, + 0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01, + 0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101, + 0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100, + 0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00, + 0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001, + 0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff, + 0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000, + 0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000, + 0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100, + 0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100, + 0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01, + 0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff, + 0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101, + 0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000, + 0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff, + 0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000, + 0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff, + 0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00, + 0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101, + 0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000, + 0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000, + 0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000, + 0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100, + 0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000, + 0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001, + 0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff, + 0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000, + 0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000, + 0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000, + 0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000, + 0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff, + 0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000, + 0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001, + 0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01, + 0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100, + 0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000, + 0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00, + 0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100, + 0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000, + 0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001, + 0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00, + 0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff, + 0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100, + 0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff, + 0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000, + 0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff, + 0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff, + 0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00, + 0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001, + 0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001, + 0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01, + 0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000, + 0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101, + 0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00, + 0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100, + 0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101, + 0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101, + 0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000, + 0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff, + 0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff, + 0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101, + 0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff, + 0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101, + 0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001, + 0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff, + 0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff, + 0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01, + 0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff, + 0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100, + 0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001, + 0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00, + 0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff, + 0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff, + 0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000, + 0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000, + 0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101, + 0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001, + 0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000, + 0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101, + 0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000, + 0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001, + 0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000, + 0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100, + 0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000, + 0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000, + 0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100, + 0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff, + 0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff, + 0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00, + 0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101, + 0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000, + 0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00, + 0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000, + 0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff, + 0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101, + 0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff, + 0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00, + 0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff, + +}; + static const uint8_t ksigns_iq2xs[128] = { 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15, 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159, @@ -3578,6 +3711,49 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y } } +// ====================== 1.5625 bpw (de)-quantization + +void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + float db[4]; + uint16_t idx[4]; + //const int8_t * grid[4]; + + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d); + const uint8_t * sc = x[i].scales; + const uint8_t * qs = x[i].qs; + + for (int i8 = 0; i8 < QK_K/8; i8 += 4) { + idx[0] = qs[0] | ((sc[0] & 0x08) << 5); + idx[1] = qs[1] | ((sc[0] & 0x80) << 1); + idx[2] = qs[2] | ((sc[1] & 0x08) << 5); + idx[3] = qs[3] | ((sc[1] & 0x80) << 1); + //grid[0] = (const int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5))); + //grid[1] = (const int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1))); + //grid[2] = (const int8_t *)(iq1s_grid + (qs[2] | ((sc[1] & 0x08) << 5))); + //grid[3] = (const int8_t *)(iq1s_grid + (qs[3] | ((sc[1] & 0x80) << 1))); + db[0] = d * (2*(sc[0] & 7) + 1); + db[1] = d * (2*((sc[0] >> 4) & 7) + 1); + db[2] = d * (2*(sc[1] & 7) + 1); + db[3] = d * (2*((sc[1] >> 4) & 7) + 1); + for (int l = 0; l < 4; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]); + for (int j = 0; j < 8; ++j) { + //y[j] = db[l] * grid[l][j]; + y[j] = db[l] * grid[j]; + } + y += 8; + } + qs += 4; + sc += 2; + } + } +} + //===================================== Q8_K ============================================== void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) { @@ -3679,7 +3855,7 @@ static inline __m128i get_scale_shuffle(int i) { } #endif -void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { +void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) { const int qk = QK8_0; const int nb = n / qk; @@ -3690,8 +3866,8 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r assert(nrc == 1); #endif UNUSED(nrc); - UNUSED(bx); - UNUSED(by); + UNUSED(bbx); + UNUSED(bby); UNUSED(bs); const block_q4_0 * restrict x = vx; @@ -4046,7 +4222,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r #endif } -void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { +void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) { const int qk = QK8_1; const int nb = n / qk; @@ -4057,8 +4233,8 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r assert(nrc == 1); #endif UNUSED(nrc); - UNUSED(bx); - UNUSED(by); + UNUSED(bbx); + UNUSED(bby); UNUSED(bs); const block_q4_1 * restrict x = vx; @@ -4264,7 +4440,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r #endif } -void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { +void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) { const int qk = QK8_0; const int nb = n / qk; @@ -4272,8 +4448,8 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r assert(qk == QK5_0); assert(nrc == 1); UNUSED(nrc); - UNUSED(bx); - UNUSED(by); + UNUSED(bbx); + UNUSED(bby); UNUSED(bs); const block_q5_0 * restrict x = vx; @@ -4555,7 +4731,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r #endif } -void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { +void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) { const int qk = QK8_1; const int nb = n / qk; @@ -4563,8 +4739,8 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r assert(qk == QK5_1); assert(nrc == 1); UNUSED(nrc); - UNUSED(bx); - UNUSED(by); + UNUSED(bbx); + UNUSED(bby); UNUSED(bs); const block_q5_1 * restrict x = vx; @@ -4859,7 +5035,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r #endif } -void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { +void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) { const int qk = QK8_0; const int nb = n / qk; @@ -4870,8 +5046,8 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r assert(nrc == 1); #endif UNUSED(nrc); - UNUSED(bx); - UNUSED(by); + UNUSED(bbx); + UNUSED(bby); UNUSED(bs); const block_q8_0 * restrict x = vx; @@ -9107,6 +9283,178 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void #endif } +#ifdef __AVX2__ +static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { + const __m256i ax = _mm256_sign_epi8(x, x); + const __m256i sy = _mm256_sign_epi8(y, x); + return _mm256_maddubs_epi16(ax, sy); +} +#endif + +void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_s * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined __ARM_NEON + + const uint8x16_t m8 = vdupq_n_u8(0x08); + const uint8x16_t m7 = vdupq_n_u8(0x07); + const uint8x16_t m1 = vdupq_n_u8(0x01); + const int32x4_t vzero = vdupq_n_s32(0); + + uint16_t gindex[8]; + uint16x8x2_t vindex; + int8x16x4_t q1b; + int8x16x4_t q8b; + uint16x8x4_t scales; + int32x4x2_t sumi; + int32x4x2_t dotq; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * sc = x[i].scales; + + sumi.val[0] = sumi.val[1] = vzero; + + for (int i128 = 0; i128 < QK_K/128; ++i128) { + const uint8x16_t ql = vld1q_u8(qs); qs += 16; + const uint8x8_t tm1 = vld1_u8 (sc); sc += 8; + const uint8x8_t tm2 = vshr_n_u8(tm1, 4); + const uint8x16_t qh = vcombine_u8(vzip1_u8(tm1, tm2), vzip2_u8(tm1, tm2)); + const uint8x16_t hbit = vandq_u8(qh, m8); + vindex.val[0] = vorrq_u16(vmovl_u8(vget_low_u8 (ql)), vshlq_n_u16(vmovl_u8(vget_low_u8 (hbit)), 5)); + vindex.val[1] = vorrq_u16(vmovl_u8(vget_high_u8(ql)), vshlq_n_u16(vmovl_u8(vget_high_u8(hbit)), 5)); + const uint8x16_t scales8 = vorrq_u8(vshlq_n_u8(vandq_u8(qh, m7), 1), m1); + scales.val[0] = vmovl_u8(vget_low_u8 (scales8)); + scales.val[1] = vmovl_u8(vget_high_u8 (scales8)); + + for (int l = 0; l < 2; ++l) { + vst1q_u16(gindex+0, vindex.val[l]); + q1b.val[0] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[0])), vld1_s8((const void *)(iq1s_grid+gindex[1]))); + q1b.val[1] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[2])), vld1_s8((const void *)(iq1s_grid+gindex[3]))); + q1b.val[2] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[4])), vld1_s8((const void *)(iq1s_grid+gindex[5]))); + q1b.val[3] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[6])), vld1_s8((const void *)(iq1s_grid+gindex[7]))); + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + + dotq.val[0] = vpaddq_s32(ggml_vdotq_s32(vzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(vzero, q1b.val[1], q8b.val[1])); + dotq.val[1] = vpaddq_s32(ggml_vdotq_s32(vzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(vzero, q1b.val[3], q8b.val[3])); + + sumi.val[0] = vmlaq_s32(sumi.val[0], dotq.val[0], vreinterpretq_s32_u32(vmovl_u16(vget_low_u16 (scales.val[l])))); + sumi.val[1] = vmlaq_s32(sumi.val[1], dotq.val[1], vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales.val[l])))); + } + } + + sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * vaddvq_s32(vaddq_s32(sumi.val[0], sumi.val[1])); + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m128i m8 = _mm_set1_epi8(0x08); + const __m128i m7 = _mm_set1_epi8(0x07); + const __m128i m1 = _mm_set1_epi8(0x01); + const __m128i shuffle_h = _mm_set_epi8(15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0); + const __m128i shuffle_s[4] = { + _mm_set_epi32(0x03030303, 0x02020202, 0x01010101, 0x00000000), + _mm_set_epi32(0x07070707, 0x06060606, 0x05050505, 0x04040404), + _mm_set_epi32(0x0b0b0b0b, 0x0a0a0a0a, 0x09090909, 0x08080808), + _mm_set_epi32(0x0f0f0f0f, 0x0e0e0e0e, 0x0d0d0d0d, 0x0c0c0c0c) + }; + + uint64_t aux64; + + __m256i v_gindex; + const uint16_t * gindex = (const uint16_t *)&v_gindex; + + __m256 accum = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * sc = x[i].scales; + + __m256i sumi = _mm256_setzero_si256(); + for (int i128 = 0; i128 < QK_K/128; ++i128) { + const __m128i ql = _mm_loadu_si128((const __m128i*)qs); qs += 16; + memcpy(&aux64, sc, 8); sc += 8; + const __m128i qh = _mm_shuffle_epi8(_mm_set_epi64x(aux64 >> 4, aux64), shuffle_h); + const __m256i hbit = _mm256_cvtepu8_epi16(_mm_and_si128(qh, m8)); + v_gindex = _mm256_or_si256(_mm256_cvtepu8_epi16(ql), _mm256_slli_epi16(hbit, 5)); + const __m128i scales = _mm_or_si128(_mm_slli_epi16(_mm_and_si128(qh, m7), 1), m1); + + for (int i32 = 0; i32 < 4; ++i32) { + const __m256i q8b = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q1b = _mm256_set_epi64x(iq1s_grid[gindex[4*i32+3]], iq1s_grid[gindex[4*i32+2]], + iq1s_grid[gindex[4*i32+1]], iq1s_grid[gindex[4*i32+0]]); + const __m256i dot = mul_add_epi8(q1b, q8b); + const __m256i s16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, shuffle_s[i32])); + const __m256i p = _mm256_madd_epi16(s16, dot); + sumi = _mm256_add_epi32(sumi, p); + } + + } + + accum = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)), _mm256_cvtepi32_ps(sumi), accum); + + } + + *s = hsum_float_8(accum); + +#else + + int db[4]; + uint16_t idx[4]; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * sc = x[i].scales; + + int sumi = 0; + for (int i32 = 0; i32 < QK_K/32; ++i32) { + idx[0] = qs[0] | ((sc[0] & 0x08) << 5); + idx[1] = qs[1] | ((sc[0] & 0x80) << 1); + idx[2] = qs[2] | ((sc[1] & 0x08) << 5); + idx[3] = qs[3] | ((sc[1] & 0x80) << 1); + db[0] = (2*(sc[0] & 7) + 1); + db[1] = (2*((sc[0] >> 4) & 7) + 1); + db[2] = (2*(sc[1] & 7) + 1); + db[3] = (2*((sc[1] >> 4) & 7) + 1); + for (int l = 0; l < 4; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]); + int suml = 0; + for (int j = 0; j < 8; ++j) suml += q8[j] * grid[j]; + sumi += db[l] * suml; + q8 += 8; + } + qs += 4; + sc += 2; + } + + sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * sumi; + } + + *s = sumf; + +#endif + +} + // ================================ IQ2 quantization ============================================= typedef struct { @@ -9115,14 +9463,22 @@ typedef struct { uint16_t * neighbours; } iq2_entry_t; -static iq2_entry_t iq2_data[2] = { +static iq2_entry_t iq2_data[3] = { + {NULL, NULL, NULL}, {NULL, NULL, NULL}, {NULL, NULL, NULL}, }; -static inline int iq2_data_index(int grid_size) { - GGML_ASSERT(grid_size == 256 || grid_size == 512); - return grid_size == 256 ? 0 : 1; +static inline int iq2_data_index(enum ggml_type type) { + GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S); + return type == GGML_TYPE_IQ2_XXS ? 0 : + type == GGML_TYPE_IQ2_XS ? 1 : 2; +} + +static inline int iq2_grid_size(enum ggml_type type) { + GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S); + return type == GGML_TYPE_IQ2_XXS ? 256 : + type == GGML_TYPE_IQ2_XS ? 512 : 512; } static int iq2_compare_func(const void * left, const void * right) { @@ -9131,12 +9487,13 @@ static int iq2_compare_func(const void * left, const void * right) { return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0; } -void iq2xs_init_impl(int grid_size) { - const int gindex = iq2_data_index(grid_size); +void iq2xs_init_impl(enum ggml_type type) { + const int gindex = iq2_data_index(type); + const int grid_size = iq2_grid_size(type); if (iq2_data[gindex].grid) { return; } - static const uint16_t kgrid_256[256] = { + static const uint16_t kgrid_2bit_256[256] = { 0, 2, 5, 8, 10, 17, 20, 32, 34, 40, 42, 65, 68, 80, 88, 97, 100, 128, 130, 138, 162, 257, 260, 272, 277, 320, 388, 408, 512, 514, 546, 642, 1025, 1028, 1040, 1057, 1060, 1088, 1090, 1096, 1120, 1153, 1156, 1168, 1188, 1280, 1282, 1288, @@ -9154,7 +9511,7 @@ void iq2xs_init_impl(int grid_size) { 33888, 34048, 34118, 34196, 34313, 34368, 34400, 34818, 35076, 35345, 36868, 36880, 36900, 36928, 37025, 37142, 37248, 37445, 37888, 37922, 37956, 38225, 39041, 39200, 40962, 41040, 41093, 41225, 41472, 42008, 43088, 43268, }; - static const uint16_t kgrid_512[512] = { + static const uint16_t kgrid_2bit_512[512] = { 0, 2, 5, 8, 10, 17, 20, 22, 25, 32, 34, 37, 40, 65, 68, 70, 73, 80, 82, 85, 88, 97, 100, 128, 130, 133, 136, 145, 148, 153, 160, 257, 260, 262, 265, 272, 274, 277, 280, 282, 289, 292, 320, 322, 325, 328, 337, 340, @@ -9188,9 +9545,45 @@ void iq2xs_init_impl(int grid_size) { 40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048, 42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690, }; + static const uint16_t kgrid_1bit_512[512] = { + 10, 33, 41, 85, 132, 134, 160, 162, 277, 337, 340, 345, 357, 405, 516, 545, + 553, 598, 641, 650, 681, 1042, 1044, 1097, 1169, 1176, 1320, 1345, 1365, 1378, 1434, 1444, + 1545, 1617, 1642, 1685, 2053, 2080, 2089, 2133, 2176, 2182, 2208, 2214, 2306, 2384, 2393, 2440, + 2453, 2581, 2664, 2690, 2721, 4117, 4161, 4182, 4184, 4261, 4357, 4369, 4372, 4377, 4390, 4422, + 4432, 4437, 4449, 4457, 4485, 4497, 4505, 4629, 4677, 4696, 4774, 5205, 5217, 5225, 5386, 5397, + 5409, 5445, 5457, 5460, 5461, 5462, 5465, 5472, 5477, 5525, 5545, 5650, 5668, 5717, 5729, 5769, + 5777, 6212, 6234, 6244, 6293, 6424, 6482, 6485, 6502, 6505, 6529, 6538, 6565, 6656, 6682, 6788, + 6806, 6820, 8218, 8224, 8226, 8232, 8277, 8326, 8354, 8469, 8521, 8530, 8549, 8596, 8737, 8794, + 9221, 9253, 9348, 9369, 9380, 9474, 9557, 9633, 9732, 9753, 9793, 9830, 9862, 9880, 10240, 10272, + 10282, 10321, 10406, 10517, 10530, 10566, 10585, 10645, 10896, 16466, 16468, 16473, 16485, 16646, 16660, 16665, + 16725, 16793, 16806, 16914, 16969, 16977, 16996, 17028, 17057, 17408, 17416, 17434, 17493, 17512, 17578, 17685, + 17696, 17733, 17745, 17748, 17749, 17750, 17753, 17765, 17794, 17813, 17946, 17984, 18005, 18072, 18453, 18529, + 18569, 18722, 18756, 18762, 18773, 18794, 18833, 18853, 18945, 19026, 19033, 19077, 20489, 20497, 20500, 20517, + 20565, 20586, 20610, 20633, 20757, 20769, 20776, 20805, 20817, 20820, 20821, 20822, 20825, 20837, 20864, 20872, + 20885, 20896, 21002, 21029, 21077, 21146, 21510, 21525, 21573, 21585, 21588, 21589, 21590, 21593, 21605, 21653, + 21665, 21765, 21777, 21780, 21781, 21782, 21785, 21797, 21825, 21828, 21829, 21830, 21833, 21840, 21841, 21842, + 21844, 21846, 21848, 21849, 21850, 21857, 21860, 21861, 21862, 21865, 21893, 21905, 21908, 21909, 21910, 21913, + 21925, 22024, 22037, 22085, 22097, 22100, 22101, 22102, 22105, 22117, 22165, 22545, 22566, 22568, 22594, 22608, + 22613, 22676, 22697, 22793, 22805, 22853, 22865, 22868, 22869, 22870, 22873, 22885, 22933, 22946, 23046, 23072, + 23125, 23209, 24597, 24640, 24665, 24673, 24725, 24833, 24840, 24869, 24917, 24934, 24965, 25001, 25108, 25110, + 25152, 25184, 25192, 25234, 25616, 25618, 25625, 25685, 25704, 25738, 25744, 25770, 25877, 25897, 25925, 25937, + 25940, 25941, 25942, 25945, 25957, 25986, 26005, 26186, 26197, 26276, 26632, 26634, 26725, 26757, 26770, 26885, + 26965, 26976, 26986, 27032, 27153, 27174, 27200, 27208, 27240, 27269, 27282, 27290, 32778, 32800, 32802, 32808, + 32810, 32853, 32904, 32922, 32930, 32932, 33105, 33110, 33112, 33125, 33157, 33280, 33288, 33301, 33312, 33320, + 33424, 33797, 33829, 33858, 34068, 34133, 34146, 34176, 34217, 34306, 34342, 34441, 34454, 34468, 34832, 34918, + 34965, 34984, 35094, 35137, 35161, 35208, 35232, 35332, 35338, 35368, 35429, 36932, 36934, 36953, 37009, 37125, + 37136, 37138, 37145, 37157, 37205, 37220, 37258, 37290, 37444, 37446, 37465, 37478, 37525, 37905, 37968, 37973, + 38040, 38054, 38145, 38154, 38165, 38180, 38186, 38213, 38225, 38228, 38229, 38230, 38233, 38245, 38293, 38485, + 38504, 38530, 38938, 38985, 38993, 39012, 39040, 39173, 39192, 39253, 39265, 39301, 39316, 39322, 39442, 39497, + 39504, 39590, 40970, 40984, 40992, 41002, 41045, 41120, 41128, 41237, 41289, 41297, 41317, 41364, 41366, 41514, + 41557, 41633, 41989, 42021, 42056, 42068, 42074, 42113, 42242, 42265, 42274, 42325, 42340, 42402, 42501, 42512, + 42533, 42624, 42632, 42666, 43040, 43093, 43106, 43168, 43176, 43264, 43286, 43345, 43429, 43590, 43618, 43680, + }; + const int kmap_size = 43692; - const int nwant = 2; - const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512; + const int nwant = type == GGML_TYPE_IQ1_S ? 3 : 2; + const uint16_t * kgrid = type == GGML_TYPE_IQ2_XXS ? kgrid_2bit_256 : + type == GGML_TYPE_IQ2_XS ? kgrid_2bit_512 : kgrid_1bit_512; uint64_t * kgrid_q2xs; int * kmap_q2xs; uint16_t * kneighbors_q2xs; @@ -9286,9 +9679,9 @@ void iq2xs_init_impl(int grid_size) { free(dist2); } -void iq2xs_free_impl(int grid_size) { - GGML_ASSERT(grid_size == 256 || grid_size == 512 || grid_size == 1024); - const int gindex = iq2_data_index(grid_size); +void iq2xs_free_impl(enum ggml_type type) { + GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S); + const int gindex = iq2_data_index(type); if (iq2_data[gindex].grid) { free(iq2_data[gindex].grid); iq2_data[gindex].grid = NULL; free(iq2_data[gindex].map); iq2_data[gindex].map = NULL; @@ -9322,7 +9715,7 @@ static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const u static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) { - const int gindex = iq2_data_index(256); + const int gindex = iq2_data_index(GGML_TYPE_IQ2_XXS); const uint64_t * kgrid_q2xs = iq2_data[gindex].grid; const int * kmap_q2xs = iq2_data[gindex].map; @@ -9495,7 +9888,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) { - const int gindex = iq2_data_index(512); + const int gindex = iq2_data_index(GGML_TYPE_IQ2_XS); const uint64_t * kgrid_q2xs = iq2_data[gindex].grid; const int * kmap_q2xs = iq2_data[gindex].map; @@ -10132,3 +10525,207 @@ void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * re assert(k % QK_K == 0); quantize_row_iq3_xxs_impl(x, y, k, NULL); } + +// =================================== 1.5 bpw =================================================== + +static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid, + const float * restrict xval, const float * restrict weight, float * scale, int8_t * restrict L, int ngrid) { + int num_neighbors = neighbours[0]; + GGML_ASSERT(num_neighbors > 0); + float best_score = 0; + int grid_index = -1; + for (int j = 1; j <= num_neighbors; ++j) { + const int8_t * pg = (const int8_t *)(grid + neighbours[j]); + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 8; ++i) { + float q = (pg[i] - 3)/2; + float w = weight[i]; + sumqx += w*q*xval[i]; + sumq2 += w*q*q; + } + if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) { + *scale = sumqx/sumq2; best_score = *scale * sumqx; + grid_index = neighbours[j]; + } + } + if (grid_index < 0) { + for (int i = 0; i < ngrid; ++i) { + const int8_t * grid_i = (const int8_t *)(grid + i); + float sumqx = 0, sumq2 = 0; + for (int j = 0; j < 8; ++j) { + float w = weight[j]; + float q = (grid_i[j] - 3)/2; + sumqx += w*q*xval[j]; + sumq2 += w*q*q; + } + if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) { + *scale = sumqx/sumq2; best_score = *scale*sumqx; + grid_index = i; + } + } + } + if (grid_index < 0) { + printf("Oops, did not find grid point\n"); + printf("Have %d neighbours\n", num_neighbors); + for (int j = 1; j <= num_neighbors; ++j) { + const int8_t * pg = (const int8_t *)(grid + neighbours[j]); + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 8; ++i) { + float q = (pg[i] - 3)/2; + float w = weight[i]; + sumqx += w*q*xval[i]; + sumq2 += w*q*q; + } + printf(" neighbour %d: sumqx = %g sumq2 = %g\n", j, (double)sumqx, (double)sumq2); + } + } + GGML_ASSERT(grid_index >= 0); + //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + *scale *= 1.05f; // This is a fudge factor. Don't ask me why it improves the result. + //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + const int8_t * pg = (const int8_t *)(grid + grid_index); + for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2; + return grid_index; +} + +static int iq1_sort_helper(const void * left, const void * right) { + const float * l = left; + const float * r = right; + return *l < *r ? -1 : *l > *r ? 1 : 0; +} + +static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) { + + const int gindex = iq2_data_index(GGML_TYPE_IQ1_S); + + const uint64_t * kgrid_q2xs = iq2_data[gindex].grid; + const int * kmap_q2xs = iq2_data[gindex].map; + const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours; + + GGML_ASSERT(quant_weights && "missing quantization weights"); + GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(n%QK_K == 0); + + const int nbl = n/256; + + block_iq1_s * y = vy; + + float scales[QK_K/8]; + float weight[8]; + int8_t L[8]; + float sumx[9]; + float sumw[9]; + float pairs[16]; + int * idx = (int *)(pairs + 1); + uint8_t hbit[QK_K/8]; + + for (int ibl = 0; ibl < nbl; ++ibl) { + + y[ibl].d = GGML_FP32_TO_FP16(0.f); + memset(y[ibl].qs, 0, QK_K/8); + memset(y[ibl].scales, 0, QK_K/16); + + float max_scale = 0; + + const float * xbl = x + QK_K*ibl; + float sumx2 = 0; + for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + float sigma2 = sumx2/QK_K; + + for (int ib = 0; ib < QK_K/8; ++ib) { + const float * xb = xbl + 8*ib; + const float * qw = quant_weights + QK_K*ibl + 8*ib; + for (int i = 0; i < 8; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + float max = fabsf(xb[0]); + for (int i = 1; i < 8; ++i) max = MAX(max, fabsf(xb[i])); + if (!max) { + scales[ib] = 0; + memset(L, 1, 8); + continue; + } + // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem. + // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two + // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights + // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and + // Wi = sum[weight[j], j = 0...i], and use these to quckly get get the optimum scale + // for each possible and score for each split. + for (int j = 0; j < 8; ++j) { + pairs[2*j] = xb[j]; + idx[2*j] = j; + } + qsort(pairs, 8, 2*sizeof(float), iq1_sort_helper); + { + sumx[0] = sumw[0] = 0; + for (int j = 0; j < 8; ++j) { + int i = idx[2*j]; + sumx[j+1] = sumx[j] + weight[i]*xb[i]; + sumw[j+1] = sumw[j] + weight[i]; + } + } + float best_score = 0, scale = max; + int besti1 = 0, besti2 = 0; + for (int i1 = 0; i1 <= 8; ++i1) { + for (int i2 = i1; i2 <= 8; ++i2) { + float sumqx = -(sumx[i1] - sumx[0]) + (sumx[8] - sumx[i2]); + float sumq2 = (sumw[i1] - sumw[0]) + (sumw[8] - sumw[i2]); + if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) { + scale = sumqx/sumq2; best_score = scale*sumqx; + besti1 = i1; besti2 = i2; + } + } + } + for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0; + for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1; + for (int j = besti2; j < 8; ++j) L[idx[2*j]] = 2; + if (scale < 0) { + for (int j = 0; j < 8; ++j) L[j] = 2 - L[j]; + scale = -scale; + } + // Now we check if the solution found above corresponds to a grid point and, if not, use a neighbouring + // grid point that minimizes SSD. + uint16_t u = 0; + for (int j = 0; j < 8; ++j) u |= (L[j] << 2*j); + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq1_find_best_neighbour(neighbours, kgrid_q2xs, xb, weight, &scale, L, NGRID_IQ2XXS); + GGML_ASSERT(grid_index >= 0); + } + y[ibl].qs[ib] = grid_index & 255; + hbit[ib] = grid_index >> 8; + GGML_ASSERT(scale >= 0); + scales[ib] = scale; + max_scale = MAX(max_scale, scale); + } + + if (!max_scale) { + memset(y[ibl].qs, 0, QK_K/8); + continue; + } + + float d = max_scale/15; + y[ibl].d = GGML_FP32_TO_FP16(d*1.085f); // 1.085f is another fudge factor. Don't ask me why it is needed. + float id = 1/d; + for (int ib = 0; ib < QK_K/8; ++ib) { + int l = nearest_int(0.5f*(id*scales[ib]-1)); + l = MAX(0, MIN(7, l)); + if (hbit[ib]) l |= 8; + y[ibl].scales[ib/2] |= (l << 4*(ib%2)); + } + } +} + +size_t quantize_iq1_s(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) { + (void)hist; + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int row = 0; row < nrow; ++row) { + quantize_row_iq1_s_impl(src, qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += nblock*sizeof(block_iq1_s); + } + return nrow * nblock * sizeof(block_iq1_s); +} diff --git a/ggml-quants.h b/ggml-quants.h index 68f09b1e1..ad381cfab 100644 --- a/ggml-quants.h +++ b/ggml-quants.h @@ -191,6 +191,13 @@ typedef struct { } block_iq3_xxs; static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding"); +typedef struct { + ggml_fp16_t d; + uint8_t qs[QK_K/8]; + uint8_t scales[QK_K/16]; +} block_iq1_s; +static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding"); + #ifdef __cplusplus extern "C" { #endif @@ -243,6 +250,7 @@ void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRI void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); +void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); // Dot product void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -259,6 +267,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); // // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") @@ -266,6 +275,7 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const size_t quantize_iq2_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_iq2_xs (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_iq3_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); +size_t quantize_iq1_s (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_q2_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_q3_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_q4_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); @@ -276,8 +286,8 @@ size_t quantize_q4_1 (const float * src, void * dst, int nrows, int n_per_row, size_t quantize_q5_0 (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_q5_1 (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); -void iq2xs_init_impl(int grid_size); -void iq2xs_free_impl(int grid_size); +void iq2xs_init_impl(enum ggml_type type); +void iq2xs_free_impl(enum ggml_type type); void iq3xs_init_impl(int grid_size); void iq3xs_free_impl(int grid_size); diff --git a/ggml.c b/ggml.c index e94024c62..aefcda6d4 100644 --- a/ggml.c +++ b/ggml.c @@ -673,6 +673,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, }, + [GGML_TYPE_IQ1_S] = { + .type_name = "iq1_s", + .blck_size = QK_K, + .type_size = sizeof(block_iq1_s), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq1_s, + .from_float = NULL, + .from_float_reference = NULL, + .vec_dot = ggml_vec_dot_iq1_s_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, [GGML_TYPE_Q8_K] = { .type_name = "q8_K", .blck_size = QK_K, @@ -2267,6 +2279,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break; case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break; case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break; + case GGML_FTYPE_MOSTLY_IQ1_S: wtype = GGML_TYPE_IQ1_S; break; case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break; case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break; } @@ -7677,6 +7690,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ1_S: { ggml_compute_forward_add_q_f32(params, src0, src1, dst); } break; @@ -7944,6 +7958,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ1_S: { ggml_compute_forward_add1_q_f32(params, src0, src1, dst); } break; @@ -8064,6 +8079,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ1_S: default: { GGML_ASSERT(false); @@ -10830,6 +10846,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ1_S: { ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst); } break; @@ -11010,6 +11027,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ1_S: default: { GGML_ASSERT(false); @@ -11207,6 +11225,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ1_S: { ggml_compute_forward_get_rows_q(params, src0, src1, dst); } break; @@ -11880,6 +11899,7 @@ static void ggml_compute_forward_alibi( case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ1_S: case GGML_TYPE_Q8_K: case GGML_TYPE_I8: case GGML_TYPE_I16: @@ -11957,6 +11977,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ1_S: case GGML_TYPE_Q8_K: case GGML_TYPE_I8: case GGML_TYPE_I16: @@ -19136,8 +19157,9 @@ void ggml_quantize_init(enum ggml_type type) { ggml_critical_section_start(); switch (type) { - case GGML_TYPE_IQ2_XXS: iq2xs_init_impl(256); break; - case GGML_TYPE_IQ2_XS: iq2xs_init_impl(512); break; + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ1_S: iq2xs_init_impl(type); break; case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break; default: // nothing break; @@ -19149,8 +19171,10 @@ void ggml_quantize_init(enum ggml_type type) { void ggml_quantize_free(void) { ggml_critical_section_start(); - iq2xs_free_impl(256); - iq2xs_free_impl(512); + iq2xs_free_impl(GGML_TYPE_IQ2_XXS); + iq2xs_free_impl(GGML_TYPE_IQ2_XS); + iq2xs_free_impl(GGML_TYPE_IQ1_S); + iq3xs_free_impl(256); ggml_critical_section_end(); } @@ -19285,7 +19309,8 @@ size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * bool ggml_quantize_requires_imatrix(enum ggml_type type) { return type == GGML_TYPE_IQ2_XXS || - type == GGML_TYPE_IQ2_XS; + type == GGML_TYPE_IQ2_XS || + type == GGML_TYPE_IQ1_S; } size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, @@ -19410,6 +19435,15 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i result = quantize_iq3_xxs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix); GGML_ASSERT(result == row_size * nrows); } break; + case GGML_TYPE_IQ1_S: + { + GGML_ASSERT(start % QK_K == 0); + GGML_ASSERT(start % n_per_row == 0); + size_t start_row = start / n_per_row; + size_t row_size = ggml_row_size(type, n_per_row); + result = quantize_iq1_s(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix); + GGML_ASSERT(result == row_size * nrows); + } break; case GGML_TYPE_F16: { size_t elemsize = sizeof(ggml_fp16_t); diff --git a/ggml.h b/ggml.h index 6c1956772..004d09c70 100644 --- a/ggml.h +++ b/ggml.h @@ -354,6 +354,7 @@ extern "C" { GGML_TYPE_IQ2_XXS = 16, GGML_TYPE_IQ2_XS = 17, GGML_TYPE_IQ3_XXS = 18, + GGML_TYPE_IQ1_S = 19, GGML_TYPE_I8, GGML_TYPE_I16, GGML_TYPE_I32, @@ -391,6 +392,7 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors GGML_FTYPE_MOSTLY_IQ2_XS = 16, // except 1d tensors GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ1_S = 18, // except 1d tensors }; // available tensor operations: diff --git a/llama.cpp b/llama.cpp index 6ac9caa95..5cfebb3b1 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2526,6 +2526,7 @@ struct llama_model_loader { case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break; case GGML_TYPE_IQ2_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS; break; case GGML_TYPE_IQ3_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS; break; + case GGML_TYPE_IQ1_S: ftype = LLAMA_FTYPE_MOSTLY_IQ1_S; break; default: { LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); @@ -2875,6 +2876,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw"; case LLAMA_FTYPE_MOSTLY_Q3_K_XS:return "Q3_K - Extra small"; case LLAMA_FTYPE_MOSTLY_IQ3_XXS:return "IQ3_XXS - 3.0625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ1_S :return "IQ1_S - 1.5625 bpw"; default: return "unknown, may not work"; } @@ -10312,20 +10314,20 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty if (arch == LLM_ARCH_FALCON || nx % QK_K != 0) { new_type = GGML_TYPE_Q8_0; } - else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS) { + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) { new_type = GGML_TYPE_Q5_K; } else if (new_type != GGML_TYPE_Q8_0) { new_type = GGML_TYPE_Q6_K; } } else if (name == "token_embd.weight") { - if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS) { + if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) { new_type = GGML_TYPE_Q2_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { new_type = GGML_TYPE_Q4_K; } - } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS) { + } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) { if (name.find("attn_v.weight") != std::string::npos) { if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_Q4_K; else new_type = GGML_TYPE_Q2_K; @@ -10335,6 +10337,9 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty if (qs.i_ffn_down < qs.n_ffn_down/8) new_type = GGML_TYPE_Q2_K; ++qs.i_ffn_down; } + else if (name.find("attn_output.weight") != std::string::npos) { + if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) new_type = GGML_TYPE_IQ2_XXS; + } } else if (name.find("attn_v.weight") != std::string::npos) { if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) { new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K; @@ -10468,7 +10473,7 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K || new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K || new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS || - new_type == GGML_TYPE_IQ3_XXS) { + new_type == GGML_TYPE_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) { int nx = tensor->ne[0]; int ny = tensor->ne[1]; if (nx % QK_K != 0) { @@ -10483,6 +10488,7 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ1_S: case GGML_TYPE_Q2_K: new_type = GGML_TYPE_Q4_0; break; case GGML_TYPE_Q3_K: new_type = GGML_TYPE_Q4_1; break; case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; @@ -10525,6 +10531,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_IQ2_XXS: quantized_type = GGML_TYPE_IQ2_XXS; break; case LLAMA_FTYPE_MOSTLY_IQ2_XS: quantized_type = GGML_TYPE_IQ2_XS; break; case LLAMA_FTYPE_MOSTLY_IQ3_XXS: quantized_type = GGML_TYPE_IQ3_XXS; break; + case LLAMA_FTYPE_MOSTLY_IQ1_S: quantized_type = GGML_TYPE_IQ1_S ; break; default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); } @@ -10698,6 +10705,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } if ((new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_XS || + new_type == GGML_TYPE_IQ1_S || (new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0)) && !imatrix) { LLAMA_LOG_ERROR("\n\n============================================================\n"); LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name); diff --git a/llama.h b/llama.h index f4ec6ea63..5a97abcc9 100644 --- a/llama.h +++ b/llama.h @@ -100,6 +100,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_Q2_K_S = 21, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q3_K_XS = 22, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ3_XXS = 23, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ1_S = 24, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 30a7d1f5a..ef37c5af2 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1917,7 +1917,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, - GGML_TYPE_IQ3_XXS, + GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, }; // unary ops From fc0c8d286a533363a9a663510b62af85ffad58b3 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Sun, 18 Feb 2024 17:19:23 +0100 Subject: [PATCH 03/11] llava : update surgery script to not remove tensors (#5536) This commit updates the surgery script to not remove the tensors from the model file. For this to work the `--skip-unknown` flag is added as an argument to the convert.py script in README.md. The motivation for this change is that the surgery script currently removes the projector tensors from the model file. If the model was checked out from a repository, the model file will have been updated and have to be checked out again to reset this effect. If this can be avoided I think it would be preferable. I did not perform this change for BakLLaVA models as I am not sure how that part works. --- examples/llava/README.md | 2 +- examples/llava/llava-surgery.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/examples/llava/README.md b/examples/llava/README.md index 57eb42932..e42db6e5a 100644 --- a/examples/llava/README.md +++ b/examples/llava/README.md @@ -53,7 +53,7 @@ python ./examples/llava/convert-image-encoder-to-gguf.py -m ../clip-vit-large-pa 5. Use `convert.py` to convert the LLaMA part of LLaVA to GGUF: ```sh -python ./convert.py ../llava-v1.5-7b +python ./convert.py ../llava-v1.5-7b --skip-unknown ``` Now both the LLaMA part and the image encoder is in the `llava-v1.5-7b` directory. diff --git a/examples/llava/llava-surgery.py b/examples/llava/llava-surgery.py index 0a61efdfe..8b7a62fba 100644 --- a/examples/llava/llava-surgery.py +++ b/examples/llava/llava-surgery.py @@ -19,10 +19,6 @@ mm_tensors = [k for k, v in checkpoint.items() if k.startswith("model.mm_project projector = {name: checkpoint[name].float() for name in mm_tensors} torch.save(projector, f"{args.model}/llava.projector") -# remove these tensors from the checkpoint and save it again -for name in mm_tensors: - del checkpoint[name] - # BakLLaVA models contain CLIP tensors in it clip_tensors = [k for k, v in checkpoint.items() if k.startswith("model.vision_tower")] if len(clip_tensors) > 0: @@ -39,7 +35,7 @@ if len(clip_tensors) > 0: f.write("{}\n") -torch.save(checkpoint, path) + torch.save(checkpoint, path) print("Done!") print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.") From 5d3de51f972055702a1859186fe7acb8f0b43dc4 Mon Sep 17 00:00:00 2001 From: Herman Semenov Date: Sun, 18 Feb 2024 16:20:12 +0000 Subject: [PATCH 04/11] ggml, common, examples, tests : fixed type arguments in printf (#5528) --- common/common.cpp | 4 +- examples/batched-bench/batched-bench.cpp | 2 +- examples/batched/batched.cpp | 2 +- .../convert-llama2c-to-ggml.cpp | 38 +++++++++---------- examples/perplexity/perplexity.cpp | 2 +- .../train-text-from-scratch.cpp | 14 +++---- ggml.c | 4 +- tests/test-grammar-parser.cpp | 20 +++++----- tests/test-llama-grammar.cpp | 4 +- 9 files changed, 45 insertions(+), 45 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 3a92d3797..9ffc3951f 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1741,7 +1741,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "rope_freq_base: %f # default: 10000.0\n", params.rope_freq_base); fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale); - fprintf(stream, "seed: %d # default: -1 (random seed)\n", params.seed); + fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed); fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); @@ -1750,7 +1750,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l dump_vector_float_yaml(stream, "tensor_split", tensor_split_vector); fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z); - fprintf(stream, "threads: %d # default: %d\n", params.n_threads, std::thread::hardware_concurrency()); + fprintf(stream, "threads: %d # default: %u\n", params.n_threads, std::thread::hardware_concurrency()); fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p); diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 55dfd9784..b4b8a38e1 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -159,7 +159,7 @@ int main(int argc, char ** argv) { } LOG_TEE("\n"); - LOG_TEE("%s: n_kv_max = %d, is_pp_shared = %d, n_gpu_layers = %d, mmq = %d, n_threads = %d, n_threads_batch = %d\n", __func__, n_kv_max, is_pp_shared, n_gpu_layers, mmq, ctx_params.n_threads, ctx_params.n_threads_batch); + LOG_TEE("%s: n_kv_max = %d, is_pp_shared = %d, n_gpu_layers = %d, mmq = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, is_pp_shared, n_gpu_layers, mmq, ctx_params.n_threads, ctx_params.n_threads_batch); LOG_TEE("\n"); LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s"); diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index eab636692..9be7eb56b 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -92,7 +92,7 @@ int main(int argc, char ** argv) { const int n_ctx = llama_n_ctx(ctx); - LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_batch = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req); + LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_batch = %u, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req); // make sure the KV cache is big enough to hold all the prompt and generated tokens if (n_kv_req > n_ctx) { diff --git a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp index 4d41e1779..8209dcb64 100644 --- a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +++ b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp @@ -325,14 +325,14 @@ struct train_params { }; static void print_params(struct my_llama_hparams * params) { - printf("%s: n_vocab: %d\n", __func__, params->n_vocab); - printf("%s: n_ctx: %d\n", __func__, params->n_ctx); - printf("%s: n_embd: %d\n", __func__, params->n_embd); - printf("%s: n_mult: %d\n", __func__, params->n_mult); - printf("%s: n_head: %d\n", __func__, params->n_head); - printf("%s: n_ff: %d\n", __func__, params->n_ff); - printf("%s: n_layer: %d\n", __func__, params->n_layer); - printf("%s: n_rot: %d\n", __func__, params->n_rot); + printf("%s: n_vocab: %u\n", __func__, params->n_vocab); + printf("%s: n_ctx: %u\n", __func__, params->n_ctx); + printf("%s: n_embd: %u\n", __func__, params->n_embd); + printf("%s: n_mult: %u\n", __func__, params->n_mult); + printf("%s: n_head: %u\n", __func__, params->n_head); + printf("%s: n_ff: %u\n", __func__, params->n_ff); + printf("%s: n_layer: %u\n", __func__, params->n_layer); + printf("%s: n_rot: %u\n", __func__, params->n_rot); } static void init_model(struct my_llama_model * model) { @@ -350,25 +350,25 @@ static void init_model(struct my_llama_model * model) { model->train_tokens = 0; model->tok_embeddings = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab); - printf("[%s:GG] Allocating [%d] x [%d] = [%d] float space for model->tok_embeddings\n",__func__,n_embd , n_vocab, n_embd * n_vocab); + printf("[%s:GG] Allocating [%u] x [%u] = [%u] float space for model->tok_embeddings\n",__func__,n_embd , n_vocab, n_embd * n_vocab); model->norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); - printf("[%s:GG] Allocating [%d] float space for model->norm\n",__func__,n_embd); + printf("[%s:GG] Allocating [%u] float space for model->norm\n",__func__,n_embd); model->output = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab); - printf("[%s:GG] Allocating [%d] x[%d] = [%d] float space for model->output\n",__func__,n_embd, n_vocab, n_embd * n_vocab); + printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for model->output\n",__func__,n_embd, n_vocab, n_embd * n_vocab); // printing the per-layer allocations here so we dont print in the for loop. - printf("[%s:GG] Allocating [%d] x[%d] = [%d] float space for layer.wq for [%d] layers\n",__func__, n_embd, n_embd, n_embd * n_embd, n_layer); - printf("[%s:GG] Allocating [%d] x[%d] = [%d] float space for layer.wk for [%d] layers\n",__func__, n_embd, n_embd, n_embd * n_embd, n_layer); - printf("[%s:GG] Allocating [%d] x[%d] = [%d] float space for layer.wv for [%d] layers\n",__func__, n_embd, n_embd, n_embd * n_embd, n_layer); - printf("[%s:GG] Allocating [%d] x[%d] = [%d] float space for layer.wo for [%d] layers\n",__func__, n_embd, n_embd, n_embd * n_embd, n_layer); + printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.wq for [%u] layers\n",__func__, n_embd, n_embd, n_embd * n_embd, n_layer); + printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.wk for [%u] layers\n",__func__, n_embd, n_embd, n_embd * n_embd, n_layer); + printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.wv for [%u] layers\n",__func__, n_embd, n_embd, n_embd * n_embd, n_layer); + printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.wo for [%u] layers\n",__func__, n_embd, n_embd, n_embd * n_embd, n_layer); - printf("[%s:GG] Allocating [%d] float space for layer.ffn_norm for [%d] layers\n",__func__,n_embd, n_layer); + printf("[%s:GG] Allocating [%u] float space for layer.ffn_norm for [%u] layers\n",__func__,n_embd, n_layer); - printf("[%s:GG] Allocating [%d] x[%d] = [%d] float space for layer.w1 for [%d] layers\n",__func__, n_ff, n_embd, n_embd * n_ff, n_layer); - printf("[%s:GG] Allocating [%d] x[%d] = [%d] float space for layer.w2 for [%d] layers\n",__func__, n_embd, n_ff, n_ff * n_embd, n_layer); - printf("[%s:GG] Allocating [%d] x[%d] = [%d] float space for layer.w3 for [%d] layers\n",__func__, n_ff, n_embd, n_embd * n_ff, n_layer); + printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.w1 for [%u] layers\n",__func__, n_ff, n_embd, n_embd * n_ff, n_layer); + printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.w2 for [%u] layers\n",__func__, n_embd, n_ff, n_ff * n_embd, n_layer); + printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.w3 for [%u] layers\n",__func__, n_ff, n_embd, n_embd * n_ff, n_layer); ggml_set_name(model->tok_embeddings, "tok_embeddings.weight"); ggml_set_name(model->norm, "norm.weight"); diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 67d2d3293..74dcc642a 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -1623,7 +1623,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) { uint32_t n_ctx; in.read((char *)&n_ctx, sizeof(n_ctx)); if (n_ctx > llama_n_ctx(ctx)) { - fprintf(stderr, "%s: %s has been computed with %d, while the current context is %d. Increase it with -c and retry\n", + fprintf(stderr, "%s: %s has been computed with %u, while the current context is %d. Increase it with -c and retry\n", __func__, params.logits_file.c_str(), n_ctx, params.n_ctx); } diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index bfdf124d7..e78ab185d 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -111,13 +111,13 @@ static const char * LLM_TENSOR_FFN_DOWN = "blk.%d.ffn_down"; static const char * LLM_TENSOR_FFN_UP = "blk.%d.ffn_up"; static void print_params(struct my_llama_hparams * params) { - printf("%s: n_vocab: %d\n", __func__, params->n_vocab); - printf("%s: n_ctx: %d\n", __func__, params->n_ctx); - printf("%s: n_embd: %d\n", __func__, params->n_embd); - printf("%s: n_head: %d\n", __func__, params->n_head); - printf("%s: n_ff: %d\n", __func__, params->n_ff); - printf("%s: n_layer: %d\n", __func__, params->n_layer); - printf("%s: n_rot: %d\n", __func__, params->n_rot); + printf("%s: n_vocab: %u\n", __func__, params->n_vocab); + printf("%s: n_ctx: %u\n", __func__, params->n_ctx); + printf("%s: n_embd: %u\n", __func__, params->n_embd); + printf("%s: n_head: %u\n", __func__, params->n_head); + printf("%s: n_ff: %u\n", __func__, params->n_ff); + printf("%s: n_layer: %u\n", __func__, params->n_layer); + printf("%s: n_rot: %u\n", __func__, params->n_rot); } static void set_param_model(struct my_llama_model * model) { diff --git a/ggml.c b/ggml.c index aefcda6d4..8224652a9 100644 --- a/ggml.c +++ b/ggml.c @@ -17909,7 +17909,7 @@ struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context * ptr += ggml_nbytes(tensor); - fprintf(stderr, "%s: loaded leaf %d: '%16s', %9zu bytes\n", __func__, i, tensor->name, ggml_nbytes(tensor)); + fprintf(stderr, "%s: loaded leaf %u: '%16s', %9zu bytes\n", __func__, i, tensor->name, ggml_nbytes(tensor)); } } @@ -18012,7 +18012,7 @@ struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context * result->nodes[i] = tensor; - fprintf(stderr, "%s: loaded node %d: '%16s', %9zu bytes\n", __func__, i, tensor->name, ggml_nbytes(tensor)); + fprintf(stderr, "%s: loaded node %u: '%16s', %9zu bytes\n", __func__, i, tensor->name, ggml_nbytes(tensor)); } } } diff --git a/tests/test-grammar-parser.cpp b/tests/test-grammar-parser.cpp index a0b5b043d..91939e276 100644 --- a/tests/test-grammar-parser.cpp +++ b/tests/test-grammar-parser.cpp @@ -38,8 +38,8 @@ term ::= [0-9]+)"""; // pretty print error message before asserting if (expected_pair.first != key || expected_pair.second != value) { - fprintf(stderr, "expected_pair: %s, %d\n", expected_pair.first.c_str(), expected_pair.second); - fprintf(stderr, "actual_pair: %s, %d\n", key.c_str(), value); + fprintf(stderr, "expected_pair: %s, %u\n", expected_pair.first.c_str(), expected_pair.second); + fprintf(stderr, "actual_pair: %s, %u\n", key.c_str(), value); fprintf(stderr, "expected_pair != actual_pair\n"); } @@ -96,9 +96,9 @@ term ::= [0-9]+)"""; // pretty print error message before asserting if (expected_element.type != element.type || expected_element.value != element.value) { - fprintf(stderr, "index: %d\n", index); - fprintf(stderr, "expected_element: %d, %d\n", expected_element.type, expected_element.value); - fprintf(stderr, "actual_element: %d, %d\n", element.type, element.value); + fprintf(stderr, "index: %u\n", index); + fprintf(stderr, "expected_element: %d, %u\n", expected_element.type, expected_element.value); + fprintf(stderr, "actual_element: %d, %u\n", element.type, element.value); fprintf(stderr, "expected_element != actual_element\n"); } @@ -144,8 +144,8 @@ term ::= [0-9]+)"""; // pretty print error message before asserting if (expected_pair.first != key || expected_pair.second != value) { - fprintf(stderr, "expected_pair: %s, %d\n", expected_pair.first.c_str(), expected_pair.second); - fprintf(stderr, "actual_pair: %s, %d\n", key.c_str(), value); + fprintf(stderr, "expected_pair: %s, %u\n", expected_pair.first.c_str(), expected_pair.second); + fprintf(stderr, "actual_pair: %s, %u\n", key.c_str(), value); fprintf(stderr, "expected_pair != actual_pair\n"); } @@ -235,9 +235,9 @@ term ::= [0-9]+)"""; // pretty print error message before asserting if (expected_element.type != element.type || expected_element.value != element.value) { - fprintf(stderr, "index: %d\n", index); - fprintf(stderr, "expected_element: %d, %d\n", expected_element.type, expected_element.value); - fprintf(stderr, "actual_element: %d, %d\n", element.type, element.value); + fprintf(stderr, "index: %u\n", index); + fprintf(stderr, "expected_element: %d, %u\n", expected_element.type, expected_element.value); + fprintf(stderr, "actual_element: %d, %u\n", element.type, element.value); fprintf(stderr, "expected_element != actual_element\n"); } diff --git a/tests/test-llama-grammar.cpp b/tests/test-llama-grammar.cpp index 16ebe753f..27ca4d265 100644 --- a/tests/test-llama-grammar.cpp +++ b/tests/test-llama-grammar.cpp @@ -180,8 +180,8 @@ int main() if (expected_element.type != element->type || expected_element.value != element->value) { fprintf(stderr, "index: %d\n", index); - fprintf(stderr, "expected_element: %d, %d\n", expected_element.type, expected_element.value); - fprintf(stderr, "actual_element: %d, %d\n", element->type, element->value); + fprintf(stderr, "expected_element: %d, %u\n", expected_element.type, expected_element.value); + fprintf(stderr, "actual_element: %d, %u\n", element->type, element->value); fprintf(stderr, "expected_element != actual_element\n"); } From 1dcc3fde004787e6fc4d84c9de0bb34cd2901a3e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 18 Feb 2024 18:21:52 +0200 Subject: [PATCH 05/11] common : fix ub (#5530) --- common/common.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/common/common.cpp b/common/common.cpp index 9ffc3951f..489462b5a 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1801,7 +1801,8 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) { if (cs_curr[j] < 0) { continue; } if (seqs.find(cs_curr[j]) == seqs.end()) { if (seqs.size() + 1 >= sizeof(slot_chars)) { break; } - seqs[cs_curr[j]] = seqs.size(); + const size_t sz = seqs.size(); + seqs[cs_curr[j]] = sz; } } if (seqs.size() + 1 >= sizeof(slot_chars)) { break; } From 66c1968f7a2e895675425e875b6589f1233a1b52 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Sun, 18 Feb 2024 08:23:16 -0800 Subject: [PATCH 06/11] server : graceful server shutdown (#5244) This updates the server queue to support graceful shutdown of the server on signals. --- examples/server/server.cpp | 23 ++++++++++++++++++++++- examples/server/utils.hpp | 20 +++++++++++++++++--- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index a0b46970b..7800c6e7e 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -28,6 +28,7 @@ #include #include #include +#include using json = nlohmann::json; @@ -2511,6 +2512,9 @@ static void append_to_generated_text_from_generated_token_probs(llama_server_con } } +std::function shutdown_handler; +inline void signal_handler(int signal) { shutdown_handler(signal); } + int main(int argc, char **argv) { #if SERVER_VERBOSE != 1 @@ -3128,8 +3132,25 @@ int main(int argc, char **argv) std::placeholders::_2, std::placeholders::_3 )); - llama.queue_tasks.start_loop(); + shutdown_handler = [&](int) { + llama.queue_tasks.terminate(); + }; + +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) + struct sigaction sigint_action; + sigint_action.sa_handler = signal_handler; + sigemptyset (&sigint_action.sa_mask); + sigint_action.sa_flags = 0; + sigaction(SIGINT, &sigint_action, NULL); +#elif defined (_WIN32) + auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { + return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false; + }; + SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); +#endif + llama.queue_tasks.start_loop(); + svr.stop(); t.join(); llama_backend_free(); diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 548548962..0ee670dba 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -220,6 +220,7 @@ inline std::string format_chatml(std::vector messages) struct llama_server_queue { int id = 0; std::mutex mutex_tasks; + bool running; // queues std::vector queue_tasks; std::vector queue_tasks_deferred; @@ -278,9 +279,18 @@ struct llama_server_queue { queue_tasks_deferred.clear(); } - // Start the main loop. This call is blocking - [[noreturn]] + // end the start_loop routine + void terminate() { + { + std::unique_lock lock(mutex_tasks); + running = false; + } + condition_tasks.notify_all(); + } + + // Start the main loop. void start_loop() { + running = true; while (true) { // new task arrived LOG_VERBOSE("have new task", {}); @@ -324,8 +334,12 @@ struct llama_server_queue { { std::unique_lock lock(mutex_tasks); if (queue_tasks.empty()) { + if (!running) { + LOG_VERBOSE("ending start_loop", {}); + return; + } condition_tasks.wait(lock, [&]{ - return !queue_tasks.empty(); + return (!queue_tasks.empty() || !running); }); } } From 36376abe05a12a8cb3af548a4af9b8d0e2e69597 Mon Sep 17 00:00:00 2001 From: Pierrick Hymbert Date: Sun, 18 Feb 2024 17:30:09 +0100 Subject: [PATCH 07/11] server : --n-predict option document and cap to max value (#5549) * server: document --n-predict * server: ensure client request cannot override n_predict if set * server: fix print usage LF in new --n-predict option --- examples/server/README.md | 1 + examples/server/server.cpp | 15 ++++++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/examples/server/README.md b/examples/server/README.md index 249368749..fe5cd8d5d 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -39,6 +39,7 @@ see https://github.com/ggerganov/llama.cpp/issues/1437 - `--mmproj MMPROJ_FILE`: Path to a multimodal projector file for LLaVA. - `--grp-attn-n`: Set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w` - `--grp-attn-w`: Set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n` +- `-n, --n-predict`: Set the maximum tokens to predict (default: -1) ## Build diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 7800c6e7e..7aa706e95 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -159,6 +159,7 @@ struct llama_client_slot int32_t n_decoded = 0; int32_t n_remaining = -1; int32_t i_batch = -1; + int32_t n_predict = -1; int32_t num_prompt_tokens = 0; int32_t num_prompt_tokens_processed = 0; @@ -410,6 +411,7 @@ struct llama_server_context slot.id = i; slot.n_ctx = n_ctx_slot; + slot.n_predict = params.n_predict; LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, n_ctx_slot); @@ -546,6 +548,15 @@ struct llama_server_context slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar); slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); + if (slot->n_predict > 0 && slot->params.n_predict > slot->n_predict) { + // Might be better to reject the request with a 400 ? + LOG_WARNING("Max tokens to predict exceeds server configuration", { + {"params.n_predict", slot->params.n_predict}, + {"slot.n_predict", slot->n_predict}, + }); + slot->params.n_predict = slot->n_predict; + } + // infill if (data.count("input_prefix") != 0) { @@ -1053,6 +1064,7 @@ struct llama_server_context return json { {"n_ctx", slot.n_ctx}, + {"n_predict", slot.n_predict}, {"model", params.model_alias}, {"seed", slot.params.seed}, {"temperature", slot.sparams.temp}, @@ -1915,13 +1927,14 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA.\n"); printf(" --log-disable disables logging to a file.\n"); printf("\n"); + printf(" -n, --n-predict maximum tokens to predict (default: %d)\n", params.n_predict); printf(" --override-kv KEY=TYPE:VALUE\n"); printf(" advanced option to override model metadata by key. may be specified multiple times.\n"); printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n"); printf(" -gan N, --grp-attn-n N set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`"); printf(" -gaw N, --grp-attn-w N set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`"); printf(" --chat-template FORMAT_NAME"); - printf(" set chat template, possible valus is: llama2, chatml (default %s)", sparams.chat_template.c_str()); + printf(" set chat template, possible value is: llama2, chatml (default %s)", sparams.chat_template.c_str()); printf("\n"); } From e75c6279d1c8e7abb82a331f5de7124eed402de2 Mon Sep 17 00:00:00 2001 From: Pierrick Hymbert Date: Sun, 18 Feb 2024 17:31:28 +0100 Subject: [PATCH 08/11] server : enhanced health endpoint (#5548) * server: enrich health endpoint with available slots, return 503 if not slots are available * server: document new status no slot available in the README.md --- examples/server/README.md | 1 + examples/server/server.cpp | 31 +++++++++++++++++++++++++++++-- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/examples/server/README.md b/examples/server/README.md index fe5cd8d5d..5e3ae833b 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -136,6 +136,7 @@ node index.js - `{"status": "loading model"}` if the model is still being loaded. - `{"status": "error"}` if the model failed to load. - `{"status": "ok"}` if the model is successfully loaded and the server is ready for further requests mentioned below. + - `{"status": "no slot available", "slots_idle": 0, "slots_processing": 32}` if no slot are currently available - **POST** `/completion`: Given a `prompt`, it returns the predicted completion. diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 7aa706e95..8145af867 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2578,8 +2578,35 @@ int main(int argc, char **argv) server_state current_state = state.load(); switch(current_state) { case SERVER_STATE_READY: - res.set_content(R"({"status": "ok"})", "application/json"); - res.status = 200; // HTTP OK + if (llama.all_slots_are_idle) { + res.set_content(R"({"status": "ok"})", "application/json"); + res.status = 200; // HTTP OK + } else { + int available_slots = 0; + int processing_slots = 0; + for (llama_client_slot & slot : llama.slots) { + if (slot.available()) { + available_slots++; + } else { + processing_slots++; + } + } + if (available_slots > 0) { + json health = { + {"status", "ok"}, + {"slots_idle", available_slots}, + {"slots_processing", processing_slots}}; + res.set_content(health.dump(), "application/json"); + res.status = 200; // HTTP OK + } else { + json health = { + {"status", "no slot available"}, + {"slots_idle", available_slots}, + {"slots_processing", processing_slots}}; + res.set_content(health.dump(), "application/json"); + res.status = 503; // HTTP Service Unavailable + } + } break; case SERVER_STATE_LOADING_MODEL: res.set_content(R"({"status": "loading model"})", "application/json"); From f3f28c5395cd25b371617981b341616dbdd31e85 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 18 Feb 2024 19:17:00 +0200 Subject: [PATCH 09/11] cmake : fix GGML_USE_SYCL typo (#5555) --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5ea4d4f19..0c29b5d09 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -526,7 +526,7 @@ if (LLAMA_SYCL) message(STATUS "SYCL found") - add_compile_definitions(GML_USE_SYCL) + add_compile_definitions(GGML_USE_SYCL) if (LLAMA_SYCL_F16) add_compile_definitions(GGML_SYCL_F16) From 689a091bbe0537ee9abff3e15a1d74f5f3561165 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 18 Feb 2024 19:38:06 +0200 Subject: [PATCH 10/11] sampling : do not set min_keep to n_probs (#5564) --- common/sampling.cpp | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 53013138a..611c327bb 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -121,7 +121,7 @@ static void sampler_queue( struct llama_context * ctx_main, const llama_sampling_params & params, llama_token_data_array & cur_p, - size_t & min_keep) { + size_t min_keep) { const float temp = params.temp; const float dynatemp_range = params.dynatemp_range; const float dynatemp_exponent = params.dynatemp_exponent; @@ -248,10 +248,7 @@ static llama_token llama_sampling_sample_impl( llama_sample_temp(ctx_main, &cur_p, temp); id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu); } else { - // temperature sampling - size_t min_keep = std::max(1, params.n_probs); - - sampler_queue(ctx_main, params, cur_p, min_keep); + sampler_queue(ctx_main, params, cur_p, 1); id = llama_sample_token(ctx_main, &cur_p); From c145f8a132b2fe1d1e65987faddbd9a40bef7a12 Mon Sep 17 00:00:00 2001 From: Pierrick Hymbert Date: Sun, 18 Feb 2024 18:39:57 +0100 Subject: [PATCH 11/11] server : slots monitoring endpoint (#5550) --- examples/server/README.md | 64 ++++++++++++++++++++++++++++++++++++++ examples/server/server.cpp | 32 +++++++++++++++++++ 2 files changed, 96 insertions(+) diff --git a/examples/server/README.md b/examples/server/README.md index 5e3ae833b..ac5133d24 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -40,6 +40,7 @@ see https://github.com/ggerganov/llama.cpp/issues/1437 - `--grp-attn-n`: Set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w` - `--grp-attn-w`: Set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n` - `-n, --n-predict`: Set the maximum tokens to predict (default: -1) +- `--slots-endpoint-disable`: To disable slots state monitoring endpoint. Slots state may contain user data, prompts included. ## Build @@ -381,6 +382,69 @@ Notice that each `probs` is an array of length `n_probs`. }' ``` +- **GET** `/slots`: Returns the current slots processing state. Can be disabled with `--slots-endpoint-disable`. + +### Result JSON + +```json +[ + { + "dynatemp_exponent": 1.0, + "dynatemp_range": 0.0, + "frequency_penalty": 0.0, + "grammar": "", + "id": 0, + "ignore_eos": false, + "logit_bias": [], + "min_p": 0.05000000074505806, + "mirostat": 0, + "mirostat_eta": 0.10000000149011612, + "mirostat_tau": 5.0, + "model": "llama-2-7b-32k-instruct.Q2_K.gguf", + "n_ctx": 2048, + "n_keep": 0, + "n_predict": 100000, + "n_probs": 0, + "next_token": { + "has_next_token": true, + "n_remain": -1, + "num_tokens_predicted": 0, + "stopped_eos": false, + "stopped_limit": false, + "stopped_word": false, + "stopping_word": "" + }, + "penalize_nl": true, + "penalty_prompt_tokens": [], + "presence_penalty": 0.0, + "prompt": "Say hello to llama.cpp", + "repeat_last_n": 64, + "repeat_penalty": 1.100000023841858, + "samplers": [ + "top_k", + "tfs_z", + "typical_p", + "top_p", + "min_p", + "temperature" + ], + "seed": 42, + "state": 1, + "stop": [ + "\n" + ], + "stream": false, + "task_id": 0, + "temperature": 0.0, + "tfs_z": 1.0, + "top_k": 40, + "top_p": 0.949999988079071, + "typical_p": 1.0, + "use_penalty_prompt_tokens": false + } +] +``` + ## More examples ### Change system prompt on runtime diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 8145af867..4f2e9c898 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -41,6 +41,7 @@ struct server_params int32_t port = 8080; int32_t read_timeout = 600; int32_t write_timeout = 600; + bool slots_endpoint = true; }; bool server_verbose = false; @@ -1926,6 +1927,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf(" set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n"); printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA.\n"); printf(" --log-disable disables logging to a file.\n"); + printf(" --slots-endpoint-disable disables slots monitoring endpoint.\n"); printf("\n"); printf(" -n, --n-predict maximum tokens to predict (default: %d)\n", params.n_predict); printf(" --override-kv KEY=TYPE:VALUE\n"); @@ -2374,6 +2376,10 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, log_set_target(stdout); LOG_INFO("logging to file is disabled.", {}); } + else if (arg == "--slots-endpoint-disable") + { + sparams.slots_endpoint = false; + } else if (arg == "--chat-template") { if (++i >= argc) @@ -2619,6 +2625,32 @@ int main(int argc, char **argv) } }); + if (sparams.slots_endpoint) { + svr.Get("/slots", [&](const httplib::Request&, httplib::Response& res) { + json slots; + for (llama_client_slot & slot : llama.slots) { + json slot_data = llama.get_formated_generation(slot); + slot_data["id"] = slot.id; + slot_data["task_id"] = slot.task_id; + slot_data["state"] = slot.state; + slot_data["prompt"] = slot.prompt; + slot_data["next_token"] = { + {"has_next_token", slot.has_next_token}, + {"n_remain", slot.n_remaining}, + {"num_tokens_predicted", slot.n_decoded}, + {"stopped_eos", slot.stopped_eos}, + {"stopped_word", slot.stopped_word}, + {"stopped_limit", slot.stopped_limit}, + {"stopping_word", slot.stopping_word}, + }; + + slots.push_back(slot_data); + } + res.set_content(slots.dump(), "application/json"); + res.status = 200; // HTTP OK + }); + } + svr.set_logger(log_server_request); svr.set_exception_handler([](const httplib::Request &, httplib::Response &res, std::exception_ptr ep)