diff --git a/Makefile b/Makefile index 068f6ed02..4f26c0463 100644 --- a/Makefile +++ b/Makefile @@ -381,8 +381,13 @@ ifdef LLAMA_BLIS endif # LLAMA_BLIS ifdef LLAMA_CUBLAS - MK_CPPFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include -I/usr/local/cuda/targets/aarch64-linux/include - MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib -L/usr/local/cuda/targets/aarch64-linux/lib -L/usr/lib/wsl/lib + ifneq ('', '$(wildcard /opt/cuda)') + CUDA_PATH ?= /opt/cuda + else + CUDA_PATH ?= /usr/local/cuda + endif + MK_CPPFLAGS += -DGGML_USE_CUBLAS -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include + MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib OBJS += ggml-cuda.o MK_NVCCFLAGS += -use_fast_math ifdef LLAMA_FATAL_WARNINGS diff --git a/README.md b/README.md index d0af5d0b9..507a2888b 100644 --- a/README.md +++ b/README.md @@ -159,6 +159,7 @@ Unless otherwise noted these projects are open-source with permissive licensing: - [withcatai/catai](https://github.com/withcatai/catai) - [Mobile-Artificial-Intelligence/maid](https://github.com/Mobile-Artificial-Intelligence/maid) (MIT) - [Msty](https://msty.app) (proprietary) +- [LLMFarm](https://github.com/guinmoon/LLMFarm?tab=readme-ov-file) (MIT) --- diff --git a/common/common.cpp b/common/common.cpp index ec596f5a0..18289755c 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -335,6 +335,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { break; } params.yarn_beta_slow = std::stof(argv[i]); + } else if (arg == "--defrag-thold" || arg == "-dt") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.defrag_thold = std::stof(argv[i]); } else if (arg == "--samplers") { if (++i >= argc) { invalid_param = true; @@ -1004,6 +1010,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n"); printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow); printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast); + printf(" -dt N, --defrag-thold N\n"); + printf(" KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold); printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); printf(" --no-penalize-nl do not penalize newline token\n"); printf(" --temp N temperature (default: %.1f)\n", (double)sparams.temp); @@ -1285,6 +1293,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.yarn_beta_fast = params.yarn_beta_fast; cparams.yarn_beta_slow = params.yarn_beta_slow; cparams.yarn_orig_ctx = params.yarn_orig_ctx; + cparams.defrag_thold = params.defrag_thold; cparams.offload_kqv = !params.no_kv_offload; cparams.type_k = kv_cache_type_from_str(params.cache_type_k); diff --git a/common/common.h b/common/common.h index 3e21579b0..25003df26 100644 --- a/common/common.h +++ b/common/common.h @@ -75,6 +75,7 @@ struct gpt_params { float yarn_beta_fast = 32.0f; // YaRN low correction dim float yarn_beta_slow = 1.0f; // YaRN high correction dim int32_t yarn_orig_ctx = 0; // YaRN original context length + float defrag_thold = -1.0f; // KV cache defragmentation threshold int32_t rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED; diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index 47de67a93..2cbc9e1fa 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -182,7 +182,7 @@ int main(int argc, char ** argv) { llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); - llama_kv_cache_defrag (ctx); + //llama_kv_cache_defrag (ctx); llama_kv_cache_update (ctx); n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; @@ -213,7 +213,7 @@ int main(int argc, char ** argv) { llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); - llama_kv_cache_defrag (ctx); + //llama_kv_cache_defrag (ctx); llama_kv_cache_update (ctx); n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index ab7e72aaf..7662ec80c 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -23,18 +23,21 @@ static const std::vector QUANT_OPTIONS = { { "Q5_1", LLAMA_FTYPE_MOSTLY_Q5_1, " 4.70G, +0.0349 ppl @ LLaMA-v1-7B", }, { "IQ2_XXS",LLAMA_FTYPE_MOSTLY_IQ2_XXS," 2.06 bpw quantization", }, { "IQ2_XS", LLAMA_FTYPE_MOSTLY_IQ2_XS, " 2.31 bpw quantization", }, + { "IQ2_S", LLAMA_FTYPE_MOSTLY_IQ2_S, " 2.5 bpw quantization", }, + { "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", }, { "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", }, { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", }, { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", }, { "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", }, { "IQ3_S", LLAMA_FTYPE_MOSTLY_IQ3_S, " 3.44 bpw quantization", }, - { "IQ3_M", LLAMA_FTYPE_MOSTLY_IQ3_M, " 3.66 bpw quantization mix", }, + { "IQ3_M", LLAMA_FTYPE_MOSTLY_IQ3_M, " 3.66 bpw quantization mix", }, { "Q3_K", LLAMA_FTYPE_MOSTLY_Q3_K_M, "alias for Q3_K_M" }, - { "Q3_K_XS",LLAMA_FTYPE_MOSTLY_Q3_K_XS,"3-bit extra small quantization" , }, + { "IQ3_XS", LLAMA_FTYPE_MOSTLY_IQ3_XS, " 3.3 bpw quantization" , }, { "Q3_K_S", LLAMA_FTYPE_MOSTLY_Q3_K_S, " 2.75G, +0.5551 ppl @ LLaMA-v1-7B", }, { "Q3_K_M", LLAMA_FTYPE_MOSTLY_Q3_K_M, " 3.07G, +0.2496 ppl @ LLaMA-v1-7B", }, { "Q3_K_L", LLAMA_FTYPE_MOSTLY_Q3_K_L, " 3.35G, +0.1764 ppl @ LLaMA-v1-7B", }, - { "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL, " 4.25 bpw non-linear quantization", }, + { "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL, " 4.50 bpw non-linear quantization", }, + { "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", }, { "Q4_K", LLAMA_FTYPE_MOSTLY_Q4_K_M, "alias for Q4_K_M", }, { "Q4_K_S", LLAMA_FTYPE_MOSTLY_Q4_K_S, " 3.59G, +0.0992 ppl @ LLaMA-v1-7B", }, { "Q4_K_M", LLAMA_FTYPE_MOSTLY_Q4_K_M, " 3.80G, +0.0532 ppl @ LLaMA-v1-7B", }, @@ -292,6 +295,7 @@ int main(int argc, char ** argv) { } if ((params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || + params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || params.ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S || params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) && imatrix_data.empty()) { fprintf(stderr, "\n===============================================================================================\n"); fprintf(stderr, "Please do not use IQ1_S, IQ2_XXS, IQ2_XS or Q2_K_S quantization without an importance matrix\n"); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 2f4acf54a..fff05cbde 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1336,6 +1336,10 @@ struct llama_server_context split_multiprompt_task(task_id, task); } } else { + // an empty prompt can make slot become buggy + if (task.data.contains("prompt") && task.data["prompt"].is_string() && task.data["prompt"].get().empty()) { + task.data["prompt"] = " "; // add a space so that we have one token + } queue_tasks.post(task); } } diff --git a/ggml-cuda.cu b/ggml-cuda.cu index fb6d4f7d2..dfd28df62 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -523,6 +523,17 @@ typedef struct { } block_iq2_xs; static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding"); +// 2.5625 bpw quants +#define QR2_S 8 +#define QI2_S (QK_K / (4*QR2_S)) +typedef struct { + half d; + uint8_t qs[QK_K/4]; + uint8_t qh[QK_K/32]; + uint8_t scales[QK_K/32]; +} block_iq2_s; +static_assert(sizeof(block_iq2_s) == sizeof(ggml_fp16_t) + QK_K/4 + QK_K/16, "wrong iq2_s block size/padding"); + #define QR3_XXS 8 #define QI3_XXS (QK_K / (4*QR3_XXS)) typedef struct { @@ -560,6 +571,18 @@ typedef struct { } block_iq4_nl; static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding"); +// QR4_XS = 8 is very slightly faster than QR4_XS = 4 +#define QR4_XS 8 +#define QI4_XS (QK_K / (4*QR4_XS)) +typedef struct { + half d; + uint16_t scales_h; + uint8_t scales_l[QK_K/64]; + uint8_t qs[QK_K/2]; +} block_iq4_xs; +static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); + + #define WARP_SIZE 32 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses @@ -685,18 +708,20 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { return a; } -//static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { -//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL -//#pragma unroll -// for (int mask = 16; mask > 0; mask >>= 1) { -// a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32)); -// } -// return a; -//#else -// (void) a; -// NO_DEVICE_CODE; -//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL -//} +#ifdef GGML_CUDA_F16 +static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32)); + } + return a; +#else + (void) a; + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +} +#endif // GGML_CUDA_F16 static __device__ __forceinline__ float warp_reduce_max(float x) { #pragma unroll @@ -1689,6 +1714,265 @@ static const __device__ uint64_t iq2xs_grid[512] = { 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b, }; +static const __device__ uint64_t iq2s_grid[1024] = { + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b, + 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919, + 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b, + 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919, + 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x08080808192b192b, + 0x08080808192b2b19, 0x080808082b080808, 0x080808082b08082b, 0x080808082b081919, + 0x080808082b082b08, 0x080808082b190819, 0x080808082b191908, 0x080808082b2b0808, + 0x080808082b2b1919, 0x080808082b2b2b2b, 0x0808081908080819, 0x0808081908081908, + 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, 0x080808190819082b, + 0x0808081908191919, 0x0808081908192b08, 0x08080819082b0819, 0x08080819082b1908, + 0x0808081919080808, 0x080808191908082b, 0x0808081919081919, 0x0808081919082b08, + 0x0808081919190819, 0x0808081919191908, 0x080808191919192b, 0x0808081919192b19, + 0x08080819192b0808, 0x08080819192b1919, 0x08080819192b2b08, 0x080808192b080819, + 0x080808192b081908, 0x080808192b190808, 0x080808192b19082b, 0x080808192b191919, + 0x080808192b2b0819, 0x080808192b2b1908, 0x0808082b08080808, 0x0808082b0808082b, + 0x0808082b08081919, 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, + 0x0808082b082b0808, 0x0808082b082b2b2b, 0x0808082b19080819, 0x0808082b19081908, + 0x0808082b1908192b, 0x0808082b19082b19, 0x0808082b19190808, 0x0808082b19191919, + 0x0808082b2b080808, 0x0808082b2b081919, 0x0808082b2b082b2b, 0x0808082b2b191908, + 0x0808082b2b2b082b, 0x0808190808080819, 0x0808190808081908, 0x080819080808192b, + 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, 0x0808190808191919, + 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, 0x08081908082b192b, + 0x08081908082b2b19, 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, + 0x0808190819082b08, 0x0808190819082b2b, 0x0808190819190819, 0x0808190819191908, + 0x080819081919192b, 0x0808190819192b19, 0x08081908192b0808, 0x08081908192b082b, + 0x08081908192b1919, 0x080819082b080819, 0x080819082b081908, 0x080819082b08192b, + 0x080819082b082b19, 0x080819082b190808, 0x080819082b191919, 0x080819082b192b08, + 0x080819082b2b0819, 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, + 0x0808191908081919, 0x0808191908082b08, 0x0808191908082b2b, 0x0808191908190819, + 0x0808191908191908, 0x080819190819192b, 0x0808191908192b19, 0x08081919082b0808, + 0x08081919082b1919, 0x08081919082b2b08, 0x0808191919080819, 0x0808191919081908, + 0x080819191908192b, 0x0808191919082b19, 0x0808191919190808, 0x080819191919082b, + 0x0808191919191919, 0x0808191919192b08, 0x08081919192b0819, 0x08081919192b1908, + 0x080819192b080808, 0x080819192b08082b, 0x080819192b081919, 0x080819192b082b08, + 0x080819192b190819, 0x080819192b191908, 0x080819192b2b0808, 0x0808192b08080819, + 0x0808192b08081908, 0x0808192b0808192b, 0x0808192b08082b19, 0x0808192b08190808, + 0x0808192b08191919, 0x0808192b19080808, 0x0808192b19081919, 0x0808192b19082b08, + 0x0808192b19190819, 0x0808192b19191908, 0x0808192b192b0808, 0x0808192b2b080819, + 0x0808192b2b081908, 0x0808192b2b190808, 0x08082b0808080808, 0x08082b080808082b, + 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808190819, 0x08082b0808191908, + 0x08082b080819192b, 0x08082b0808192b19, 0x08082b08082b0808, 0x08082b08082b1919, + 0x08082b08082b2b2b, 0x08082b0819080819, 0x08082b0819081908, 0x08082b081908192b, + 0x08082b0819082b19, 0x08082b0819190808, 0x08082b081919082b, 0x08082b0819191919, + 0x08082b0819192b08, 0x08082b08192b0819, 0x08082b08192b1908, 0x08082b082b080808, + 0x08082b082b081919, 0x08082b082b191908, 0x08082b082b2b2b2b, 0x08082b1908080819, + 0x08082b1908081908, 0x08082b1908190808, 0x08082b190819082b, 0x08082b1908191919, + 0x08082b1908192b08, 0x08082b19082b0819, 0x08082b1919080808, 0x08082b1919081919, + 0x08082b1919082b08, 0x08082b1919190819, 0x08082b1919191908, 0x08082b19192b0808, + 0x08082b192b080819, 0x08082b192b190808, 0x08082b2b08080808, 0x08082b2b08190819, + 0x08082b2b08191908, 0x08082b2b082b082b, 0x08082b2b082b2b08, 0x08082b2b082b2b2b, + 0x08082b2b19190808, 0x08082b2b2b192b19, 0x0819080808080819, 0x0819080808081908, + 0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, + 0x0819080808191919, 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, + 0x08190808082b192b, 0x0819080819080808, 0x081908081908082b, 0x0819080819081919, + 0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x081908081919192b, + 0x0819080819192b19, 0x08190808192b0808, 0x08190808192b082b, 0x08190808192b1919, + 0x08190808192b2b08, 0x081908082b080819, 0x081908082b081908, 0x081908082b08192b, + 0x081908082b190808, 0x081908082b191919, 0x081908082b192b08, 0x081908082b2b0819, + 0x081908082b2b1908, 0x0819081908080808, 0x081908190808082b, 0x0819081908081919, + 0x0819081908082b08, 0x0819081908082b2b, 0x0819081908190819, 0x0819081908191908, + 0x081908190819192b, 0x0819081908192b19, 0x08190819082b0808, 0x08190819082b082b, + 0x08190819082b1919, 0x08190819082b2b08, 0x0819081919080819, 0x0819081919081908, + 0x081908191908192b, 0x0819081919082b19, 0x0819081919190808, 0x081908191919082b, + 0x0819081919191919, 0x0819081919192b08, 0x08190819192b0819, 0x08190819192b1908, + 0x081908192b080808, 0x081908192b08082b, 0x081908192b081919, 0x081908192b082b08, + 0x081908192b190819, 0x081908192b191908, 0x0819082b08080819, 0x0819082b08081908, + 0x0819082b08082b19, 0x0819082b08190808, 0x0819082b08191919, 0x0819082b082b0819, + 0x0819082b082b1908, 0x0819082b19080808, 0x0819082b19081919, 0x0819082b19190819, + 0x0819082b19191908, 0x0819082b2b080819, 0x0819082b2b081908, 0x0819082b2b190808, + 0x0819190808080808, 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08, + 0x0819190808190819, 0x0819190808191908, 0x081919080819192b, 0x0819190808192b19, + 0x08191908082b0808, 0x08191908082b1919, 0x08191908082b2b08, 0x0819190819080819, + 0x0819190819081908, 0x081919081908192b, 0x0819190819082b19, 0x0819190819190808, + 0x081919081919082b, 0x0819190819191919, 0x0819190819192b08, 0x08191908192b0819, + 0x08191908192b1908, 0x081919082b080808, 0x081919082b08082b, 0x081919082b081919, + 0x081919082b082b08, 0x081919082b190819, 0x081919082b191908, 0x081919082b2b0808, + 0x0819191908080819, 0x0819191908081908, 0x081919190808192b, 0x0819191908082b19, + 0x0819191908190808, 0x081919190819082b, 0x0819191908191919, 0x0819191908192b08, + 0x08191919082b0819, 0x08191919082b1908, 0x0819191919080808, 0x081919191908082b, + 0x0819191919081919, 0x0819191919082b08, 0x0819191919190819, 0x0819191919191908, + 0x08191919192b0808, 0x081919192b080819, 0x081919192b081908, 0x081919192b190808, + 0x0819192b08080808, 0x0819192b08081919, 0x0819192b08082b08, 0x0819192b08190819, + 0x0819192b08191908, 0x0819192b082b0808, 0x0819192b19080819, 0x0819192b19081908, + 0x0819192b19190808, 0x0819192b2b080808, 0x0819192b2b2b2b2b, 0x08192b0808080819, + 0x08192b0808081908, 0x08192b080808192b, 0x08192b0808082b19, 0x08192b0808190808, + 0x08192b0808191919, 0x08192b0808192b08, 0x08192b08082b0819, 0x08192b0819080808, + 0x08192b081908082b, 0x08192b0819081919, 0x08192b0819082b08, 0x08192b0819190819, + 0x08192b0819191908, 0x08192b08192b0808, 0x08192b082b080819, 0x08192b082b081908, + 0x08192b1908080808, 0x08192b190808082b, 0x08192b1908081919, 0x08192b1908082b08, + 0x08192b1908190819, 0x08192b1908191908, 0x08192b19082b0808, 0x08192b1919080819, + 0x08192b1919081908, 0x08192b1919190808, 0x08192b19192b2b19, 0x08192b192b2b082b, + 0x08192b2b08081908, 0x08192b2b08190808, 0x08192b2b19080808, 0x08192b2b1919192b, + 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, 0x082b080808082b08, + 0x082b080808190819, 0x082b080808191908, 0x082b08080819192b, 0x082b080808192b19, + 0x082b0808082b0808, 0x082b0808082b1919, 0x082b0808082b2b2b, 0x082b080819080819, + 0x082b080819081908, 0x082b080819190808, 0x082b08081919082b, 0x082b080819191919, + 0x082b0808192b1908, 0x082b08082b080808, 0x082b08082b082b2b, 0x082b08082b191908, + 0x082b08082b2b2b2b, 0x082b081908080819, 0x082b081908081908, 0x082b081908190808, + 0x082b08190819082b, 0x082b081908191919, 0x082b0819082b0819, 0x082b081919080808, + 0x082b08191908082b, 0x082b081919081919, 0x082b081919190819, 0x082b081919191908, + 0x082b0819192b0808, 0x082b08192b080819, 0x082b08192b081908, 0x082b08192b190808, + 0x082b082b08080808, 0x082b082b08082b2b, 0x082b082b082b082b, 0x082b082b082b2b08, + 0x082b082b082b2b2b, 0x082b082b19081908, 0x082b082b19190808, 0x082b082b2b082b08, + 0x082b082b2b082b2b, 0x082b082b2b2b2b08, 0x082b190808080819, 0x082b190808081908, + 0x082b19080808192b, 0x082b190808082b19, 0x082b190808190808, 0x082b190808191919, + 0x082b190808192b08, 0x082b1908082b0819, 0x082b1908082b1908, 0x082b190819080808, + 0x082b19081908082b, 0x082b190819081919, 0x082b190819082b08, 0x082b190819190819, + 0x082b190819191908, 0x082b1908192b0808, 0x082b19082b080819, 0x082b19082b081908, + 0x082b19082b190808, 0x082b191908080808, 0x082b191908081919, 0x082b191908082b08, + 0x082b191908190819, 0x082b191908191908, 0x082b1919082b0808, 0x082b191919080819, + 0x082b191919081908, 0x082b191919190808, 0x082b1919192b192b, 0x082b19192b080808, + 0x082b192b08080819, 0x082b192b08081908, 0x082b192b08190808, 0x082b192b19080808, + 0x082b192b19192b19, 0x082b2b0808080808, 0x082b2b0808081919, 0x082b2b0808190819, + 0x082b2b0808191908, 0x082b2b0819080819, 0x082b2b0819081908, 0x082b2b0819190808, + 0x082b2b082b082b2b, 0x082b2b082b2b2b2b, 0x082b2b1908080819, 0x082b2b1908081908, + 0x082b2b1908190808, 0x082b2b192b191919, 0x082b2b2b08082b2b, 0x082b2b2b082b082b, + 0x082b2b2b192b1908, 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, + 0x1908080808081908, 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, + 0x190808080819082b, 0x1908080808191919, 0x1908080808192b08, 0x1908080808192b2b, + 0x19080808082b0819, 0x19080808082b1908, 0x19080808082b192b, 0x1908080819080808, + 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, 0x1908080819082b2b, + 0x1908080819190819, 0x1908080819191908, 0x190808081919192b, 0x1908080819192b19, + 0x19080808192b0808, 0x19080808192b082b, 0x19080808192b1919, 0x190808082b080819, + 0x190808082b081908, 0x190808082b190808, 0x190808082b191919, 0x190808082b192b08, + 0x190808082b2b0819, 0x190808082b2b1908, 0x1908081908080808, 0x190808190808082b, + 0x1908081908081919, 0x1908081908082b08, 0x1908081908190819, 0x1908081908191908, + 0x190808190819192b, 0x1908081908192b19, 0x19080819082b0808, 0x19080819082b082b, + 0x19080819082b1919, 0x1908081919080819, 0x1908081919081908, 0x190808191908192b, + 0x1908081919082b19, 0x1908081919190808, 0x190808191919082b, 0x1908081919191919, + 0x1908081919192b08, 0x19080819192b0819, 0x19080819192b1908, 0x190808192b080808, + 0x190808192b08082b, 0x190808192b081919, 0x190808192b082b08, 0x190808192b190819, + 0x190808192b191908, 0x190808192b2b0808, 0x1908082b08080819, 0x1908082b08081908, + 0x1908082b08190808, 0x1908082b0819082b, 0x1908082b08191919, 0x1908082b08192b08, + 0x1908082b082b1908, 0x1908082b19080808, 0x1908082b19081919, 0x1908082b19082b08, + 0x1908082b19190819, 0x1908082b19191908, 0x1908082b192b0808, 0x1908082b2b080819, + 0x1908082b2b081908, 0x1908190808080808, 0x190819080808082b, 0x1908190808081919, + 0x1908190808082b08, 0x1908190808082b2b, 0x1908190808190819, 0x1908190808191908, + 0x190819080819192b, 0x1908190808192b19, 0x19081908082b0808, 0x19081908082b082b, + 0x19081908082b1919, 0x19081908082b2b08, 0x1908190819080819, 0x1908190819081908, + 0x190819081908192b, 0x1908190819082b19, 0x1908190819190808, 0x190819081919082b, + 0x1908190819191919, 0x1908190819192b08, 0x19081908192b0819, 0x19081908192b1908, + 0x190819082b080808, 0x190819082b08082b, 0x190819082b081919, 0x190819082b082b08, + 0x190819082b190819, 0x190819082b191908, 0x190819082b2b0808, 0x1908191908080819, + 0x1908191908081908, 0x190819190808192b, 0x1908191908082b19, 0x1908191908190808, + 0x190819190819082b, 0x1908191908191919, 0x1908191908192b08, 0x19081919082b0819, + 0x19081919082b1908, 0x1908191919080808, 0x190819191908082b, 0x1908191919081919, + 0x1908191919082b08, 0x1908191919190819, 0x1908191919191908, 0x19081919192b0808, + 0x19081919192b2b2b, 0x190819192b080819, 0x190819192b081908, 0x190819192b190808, + 0x1908192b08080808, 0x1908192b0808082b, 0x1908192b08081919, 0x1908192b08082b08, + 0x1908192b08190819, 0x1908192b08191908, 0x1908192b082b0808, 0x1908192b19080819, + 0x1908192b19081908, 0x1908192b19190808, 0x1908192b2b080808, 0x1908192b2b2b1919, + 0x19082b0808080819, 0x19082b0808081908, 0x19082b0808082b19, 0x19082b0808190808, + 0x19082b080819082b, 0x19082b0808191919, 0x19082b0808192b08, 0x19082b08082b0819, + 0x19082b08082b1908, 0x19082b0819080808, 0x19082b081908082b, 0x19082b0819081919, + 0x19082b0819082b08, 0x19082b0819190819, 0x19082b0819191908, 0x19082b08192b0808, + 0x19082b082b081908, 0x19082b082b190808, 0x19082b1908080808, 0x19082b190808082b, + 0x19082b1908081919, 0x19082b1908082b08, 0x19082b1908190819, 0x19082b1908191908, + 0x19082b19082b0808, 0x19082b1919080819, 0x19082b1919081908, 0x19082b1919190808, + 0x19082b192b080808, 0x19082b192b19192b, 0x19082b2b08080819, 0x19082b2b08081908, + 0x19082b2b08190808, 0x19082b2b19080808, 0x1919080808080808, 0x191908080808082b, + 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, 0x1919080808191908, + 0x191908080819192b, 0x1919080808192b19, 0x19190808082b0808, 0x19190808082b082b, + 0x19190808082b1919, 0x19190808082b2b08, 0x1919080819080819, 0x1919080819081908, + 0x191908081908192b, 0x1919080819082b19, 0x1919080819190808, 0x191908081919082b, + 0x1919080819191919, 0x1919080819192b08, 0x19190808192b0819, 0x19190808192b1908, + 0x191908082b080808, 0x191908082b08082b, 0x191908082b081919, 0x191908082b082b08, + 0x191908082b190819, 0x191908082b191908, 0x1919081908080819, 0x1919081908081908, + 0x191908190808192b, 0x1919081908082b19, 0x1919081908190808, 0x191908190819082b, + 0x1919081908191919, 0x1919081908192b08, 0x19190819082b0819, 0x19190819082b1908, + 0x1919081919080808, 0x191908191908082b, 0x1919081919081919, 0x1919081919082b08, + 0x1919081919190819, 0x1919081919191908, 0x19190819192b0808, 0x191908192b080819, + 0x191908192b081908, 0x191908192b190808, 0x1919082b08080808, 0x1919082b08081919, + 0x1919082b08082b08, 0x1919082b08190819, 0x1919082b08191908, 0x1919082b082b0808, + 0x1919082b19080819, 0x1919082b19081908, 0x1919082b19190808, 0x1919082b192b2b19, + 0x1919082b2b080808, 0x1919190808080819, 0x1919190808081908, 0x191919080808192b, + 0x1919190808082b19, 0x1919190808190808, 0x191919080819082b, 0x1919190808191919, + 0x1919190808192b08, 0x19191908082b0819, 0x19191908082b1908, 0x1919190819080808, + 0x191919081908082b, 0x1919190819081919, 0x1919190819082b08, 0x1919190819190819, + 0x1919190819191908, 0x19191908192b0808, 0x191919082b080819, 0x191919082b081908, + 0x191919082b190808, 0x1919191908080808, 0x191919190808082b, 0x1919191908081919, + 0x1919191908082b08, 0x1919191908190819, 0x1919191908191908, 0x19191919082b0808, + 0x1919191919080819, 0x1919191919081908, 0x1919191919190808, 0x191919192b080808, + 0x1919192b08080819, 0x1919192b08081908, 0x1919192b08190808, 0x1919192b082b192b, + 0x1919192b19080808, 0x19192b0808080808, 0x19192b080808082b, 0x19192b0808081919, + 0x19192b0808082b08, 0x19192b0808190819, 0x19192b0808191908, 0x19192b08082b0808, + 0x19192b0819080819, 0x19192b0819081908, 0x19192b0819190808, 0x19192b0819192b2b, + 0x19192b082b080808, 0x19192b1908080819, 0x19192b1908081908, 0x19192b1908190808, + 0x19192b1919080808, 0x19192b2b08080808, 0x19192b2b08192b19, 0x19192b2b2b081919, + 0x19192b2b2b2b2b08, 0x192b080808080819, 0x192b080808081908, 0x192b08080808192b, + 0x192b080808190808, 0x192b08080819082b, 0x192b080808191919, 0x192b080808192b08, + 0x192b0808082b0819, 0x192b0808082b1908, 0x192b080819080808, 0x192b080819081919, + 0x192b080819082b08, 0x192b080819190819, 0x192b080819191908, 0x192b0808192b0808, + 0x192b08082b081908, 0x192b08082b190808, 0x192b081908080808, 0x192b08190808082b, + 0x192b081908081919, 0x192b081908082b08, 0x192b081908190819, 0x192b081908191908, + 0x192b0819082b0808, 0x192b081919080819, 0x192b081919081908, 0x192b081919190808, + 0x192b08192b080808, 0x192b08192b192b19, 0x192b082b08081908, 0x192b082b08190808, + 0x192b082b19080808, 0x192b082b1919192b, 0x192b082b2b2b0819, 0x192b190808080808, + 0x192b190808081919, 0x192b190808082b08, 0x192b190808190819, 0x192b190808191908, + 0x192b1908082b0808, 0x192b190819080819, 0x192b190819081908, 0x192b190819190808, + 0x192b19082b080808, 0x192b191908080819, 0x192b191908081908, 0x192b191908190808, + 0x192b191919080808, 0x192b191919082b2b, 0x192b1919192b2b08, 0x192b19192b19082b, + 0x192b192b08080808, 0x192b192b2b191908, 0x192b2b0808080819, 0x192b2b0808081908, + 0x192b2b0808190808, 0x192b2b08192b1919, 0x192b2b082b192b08, 0x192b2b1908080808, + 0x192b2b19082b2b2b, 0x192b2b2b1908082b, 0x192b2b2b2b2b0819, 0x2b08080808080808, + 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819, + 0x2b08080808191908, 0x2b08080808192b19, 0x2b080808082b0808, 0x2b080808082b1919, + 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808081919082b, + 0x2b08080819191919, 0x2b08080819192b08, 0x2b080808192b0819, 0x2b0808082b080808, + 0x2b0808082b081919, 0x2b0808082b190819, 0x2b0808082b191908, 0x2b08081908080819, + 0x2b08081908081908, 0x2b08081908082b19, 0x2b08081908190808, 0x2b0808190819082b, + 0x2b08081908191919, 0x2b08081908192b08, 0x2b080819082b0819, 0x2b080819082b1908, + 0x2b08081919080808, 0x2b0808191908082b, 0x2b08081919081919, 0x2b08081919082b08, + 0x2b08081919190819, 0x2b08081919191908, 0x2b0808192b080819, 0x2b0808192b081908, + 0x2b0808192b190808, 0x2b0808192b2b2b19, 0x2b08082b08080808, 0x2b08082b08081919, + 0x2b08082b08082b2b, 0x2b08082b08190819, 0x2b08082b08191908, 0x2b08082b19080819, + 0x2b08082b19081908, 0x2b08082b19190808, 0x2b08190808080819, 0x2b08190808081908, + 0x2b0819080808192b, 0x2b08190808082b19, 0x2b08190808190808, 0x2b0819080819082b, + 0x2b08190808191919, 0x2b08190808192b08, 0x2b081908082b0819, 0x2b08190819080808, + 0x2b0819081908082b, 0x2b08190819081919, 0x2b08190819082b08, 0x2b08190819190819, + 0x2b08190819191908, 0x2b081908192b0808, 0x2b0819082b080819, 0x2b0819082b081908, + 0x2b0819082b190808, 0x2b08191908080808, 0x2b0819190808082b, 0x2b08191908081919, + 0x2b08191908082b08, 0x2b08191908190819, 0x2b08191908191908, 0x2b081919082b0808, + 0x2b08191919080819, 0x2b08191919081908, 0x2b08191919190808, 0x2b0819192b080808, + 0x2b0819192b082b2b, 0x2b08192b08080819, 0x2b08192b08081908, 0x2b08192b08190808, + 0x2b08192b082b2b19, 0x2b08192b19080808, 0x2b082b0808080808, 0x2b082b0808081919, + 0x2b082b0808190819, 0x2b082b0808191908, 0x2b082b0819080819, 0x2b082b0819081908, + 0x2b082b0819190808, 0x2b082b082b2b082b, 0x2b082b1908080819, 0x2b082b1908081908, + 0x2b082b1919080808, 0x2b082b19192b1919, 0x2b082b2b082b082b, 0x2b082b2b19192b08, + 0x2b082b2b19192b2b, 0x2b082b2b2b08082b, 0x2b082b2b2b2b082b, 0x2b19080808080819, + 0x2b19080808081908, 0x2b19080808082b19, 0x2b19080808190808, 0x2b1908080819082b, + 0x2b19080808191919, 0x2b19080808192b08, 0x2b190808082b1908, 0x2b19080819080808, + 0x2b1908081908082b, 0x2b19080819081919, 0x2b19080819082b08, 0x2b19080819190819, + 0x2b19080819191908, 0x2b190808192b0808, 0x2b1908082b080819, 0x2b1908082b081908, + 0x2b1908082b190808, 0x2b19081908080808, 0x2b19081908081919, 0x2b19081908190819, + 0x2b19081908191908, 0x2b19081919080819, 0x2b19081919081908, 0x2b19081919190808, + 0x2b19081919192b2b, 0x2b19082b08080819, 0x2b19082b08081908, 0x2b19082b08190808, + 0x2b19082b19080808, 0x2b19082b2b2b192b, 0x2b19190808080808, 0x2b1919080808082b, + 0x2b19190808081919, 0x2b19190808082b08, 0x2b19190808190819, 0x2b19190808191908, + 0x2b191908082b0808, 0x2b19190819080819, 0x2b19190819081908, 0x2b19190819190808, + 0x2b1919082b080808, 0x2b1919082b19192b, 0x2b19191908080819, 0x2b19191908081908, + 0x2b19191908190808, 0x2b19191919080808, 0x2b1919192b192b08, 0x2b1919192b2b0819, + 0x2b19192b08080808, 0x2b19192b1908192b, 0x2b19192b192b1908, 0x2b192b0808080819, + 0x2b192b0808081908, 0x2b192b0808190808, 0x2b192b08082b192b, 0x2b192b0819080808, + 0x2b192b082b2b2b19, 0x2b192b1908080808, 0x2b192b1919082b19, 0x2b192b191919082b, + 0x2b192b2b2b190808, 0x2b2b080808080808, 0x2b2b080808081919, 0x2b2b080808082b2b, + 0x2b2b080808191908, 0x2b2b0808082b082b, 0x2b2b0808082b2b2b, 0x2b2b080819080819, + 0x2b2b080819081908, 0x2b2b080819190808, 0x2b2b08082b2b082b, 0x2b2b08082b2b2b2b, + 0x2b2b081919080808, 0x2b2b0819192b1919, 0x2b2b082b0808082b, 0x2b2b082b08082b2b, + 0x2b2b082b082b082b, 0x2b2b082b082b2b08, 0x2b2b082b082b2b2b, 0x2b2b082b2b08082b, + 0x2b2b082b2b082b08, 0x2b2b082b2b082b2b, 0x2b2b082b2b2b2b08, 0x2b2b190808080819, + 0x2b2b190808081908, 0x2b2b190808190808, 0x2b2b190819080808, 0x2b2b19082b082b19, + 0x2b2b19082b2b1908, 0x2b2b191908080808, 0x2b2b191908192b19, 0x2b2b192b19190819, + 0x2b2b2b0808082b2b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b082b, 0x2b2b2b1919191908, + 0x2b2b2b192b08192b, 0x2b2b2b2b08082b08, 0x2b2b2b2b08082b2b, 0x2b2b2b2b082b0808, + 0x2b2b2b2b082b082b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b2b082b08, 0x2b2b2b2b2b2b2b2b, +}; + static const __device__ uint32_t iq3xxs_grid[256] = { 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414, 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14, @@ -2037,6 +2321,27 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst } +template +static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_iq2_s * x = (const block_iq2_s *) vx; + + const int tid = threadIdx.x; +#if QK_K == 256 + const int il = tid/8; // 0...3 + const int ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300))); + const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f; + const uint8_t signs = x[i].qs[QK_K/8+4*ib+il]; + for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); +#else + assert(false); +#endif + +} + template static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) { @@ -2134,6 +2439,25 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst } +template +static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_iq4_xs * x = (const block_iq4_xs *)vx; + + const int tid = threadIdx.x; + const int il = tid/8; // 0...3 + const int ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 4*il; + const uint8_t * q4 = x[i].qs + 16*ib + 4*il; + const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32); + for (int j = 0; j < 4; ++j) { + y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf]; + y[j+16] = d * kvalues_iq4nl[q4[j] >> 4]; + } + +} + static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); @@ -2230,10 +2554,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, #endif // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } + tmp = warp_reduce_sum(tmp); if (threadIdx.x == 0) { dst[row] = tmp; @@ -2334,10 +2655,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, #endif // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } + tmp = warp_reduce_sum(tmp); if (threadIdx.x == 0) { dst[row] = tmp; @@ -2470,10 +2788,7 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, #endif // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } + tmp = warp_reduce_sum(tmp); if (tid == 0) { dst[row] = tmp; @@ -2586,10 +2901,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, #endif // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } + tmp = warp_reduce_sum(tmp); if (threadIdx.x == 0) { dst[row] = tmp; @@ -2696,10 +3008,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, #endif // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } + tmp = warp_reduce_sum(tmp); if (tid == 0) { dst[row] = tmp; @@ -2734,11 +3043,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest float amax = fabsf(xi); float sum = xi; -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32)); - sum += __shfl_xor_sync(0xffffffff, sum, mask, 32); - } + amax = warp_reduce_max(amax); + sum = warp_reduce_sum(sum); const float d = amax / 127; const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); @@ -4800,6 +5106,54 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1( #endif } +// TODO +static __device__ __forceinline__ float vec_dot_iq2_s_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics +#if QK_K == 256 + const block_iq2_s * bq2 = (const block_iq2_s *) vbq; + + const int ib32 = iqs; + const int8_t * q8 = bq8_1[ib32].qs; + const uint8_t * signs = bq2->qs + QK_K/8 + 4*ib32; + const uint8_t ls1 = bq2->scales[ib32] & 0xf; + const uint8_t ls2 = bq2->scales[ib32] >> 4; + int sumi1 = 0; + for (int l = 0; l < 2; ++l) { + const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300))); + const uint32_t signs0 = __vcmpeq4(((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201); + const uint32_t signs1 = __vcmpeq4(((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201); + const int grid_l = __vsub4(grid[0] ^ signs0, signs0); + const int grid_h = __vsub4(grid[1] ^ signs1, signs1); + sumi1 = __dp4a(grid_l, *((const int *)q8 + 0), sumi1); + sumi1 = __dp4a(grid_h, *((const int *)q8 + 1), sumi1); + q8 += 8; + } + int sumi2 = 0; + for (int l = 2; l < 4; ++l) { + const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300))); + const uint32_t signs0 = __vcmpeq4(((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201); + const uint32_t signs1 = __vcmpeq4(((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201); + const int grid_l = __vsub4(grid[0] ^ signs0, signs0); + const int grid_h = __vsub4(grid[1] ^ signs1, signs1); + sumi2 = __dp4a(grid_l, *((const int *)q8 + 0), sumi2); + sumi2 = __dp4a(grid_h, *((const int *)q8 + 1), sumi2); + q8 += 8; + } + const float d = (float)bq2->d * __low2float(bq8_1[ib32].ds) * 0.25f; + return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2); +#else + (void) ksigns64; + assert(false); + return 0.f; +#endif +#else + (void) ksigns64; + assert(false); + return 0.f; +#endif +} + static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics @@ -4963,6 +5317,76 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1( return d * (sumi1 + sumi2); } +static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + +#if QK_K == 256 +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + + const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq; + const uint8_t * values = (const uint8_t *)kvalues_iq4nl; + + //// iqs is 0...7 + //const int ib64 = iqs/2; + //const int il = iqs%2; + //const int32_t * q8_1 = (const int *)bq8_1[2*ib64+0].qs + 2*il; + //const int32_t * q8_2 = (const int *)bq8_1[2*ib64+1].qs + 2*il; + //const uint32_t * q4_1 = (const uint32_t *)bq4->qs + 8*ib64 + 2*il; + //const uint32_t * q4_2 = q4_1 + 4; + //const int8_t ls1 = (bq4->scales_l[ib64] & 0xf) | (((bq4->scales_h >> (4*ib64+0)) & 3) << 4); + //const int8_t ls2 = (bq4->scales_l[ib64] >> 4) | (((bq4->scales_h >> (4*ib64+2)) & 3) << 4); + //const float d1 = (float)bq4->d * (ls1 - 32) * __low2float(bq8_1[2*ib64+0].ds); + //const float d2 = (float)bq4->d * (ls2 - 32) * __low2float(bq8_1[2*ib64+1].ds); + //int v1, v2; + //int sumi1 = 0, sumi2 = 0; + //for (int j = 0; j < 2; ++j) { + // get_int_from_table_16(q4_1[j], values, v1, v2); + // sumi1 = __dp4a(v2, q8_1[j+4], __dp4a(v1, q8_1[j+0], sumi1)); + // get_int_from_table_16(q4_2[j], values, v1, v2); + // sumi2 = __dp4a(v2, q8_2[j+4], __dp4a(v1, q8_2[j+0], sumi2)); + //} + //return d1 * sumi1 + d2 * sumi2; + + // iqs is 0...7 + const int ib32 = iqs; + const int32_t * q8 = (const int *)bq8_1[ib32].qs; + const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32; + const int8_t ls = ((bq4->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((bq4->scales_h >> 2*ib32) & 3) << 4); + const float d = (float)bq4->d * (ls - 32) * __low2float(bq8_1[ib32].ds); + int v1, v2; + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < 4; ++j) { + get_int_from_table_16(q4[j], values, v1, v2); + sumi1 = __dp4a(v1, q8[j+0], sumi1); + sumi2 = __dp4a(v2, q8[j+4], sumi2); + } + return d * (sumi1 + sumi2); + + //// iqs is 0...15 + //const int ib32 = iqs/2; + //const int il = iqs%2; + //const int32_t * q8 = (const int *)bq8_1[ib32].qs + 2*il; + //const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32 + 2*il; + //const int8_t ls = ((bq4->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((bq4->scales_h >> 2*ib32) & 3) << 4); + //const float d = (float)bq4->d * (ls - 32) * __low2float(bq8_1[ib32].ds); + //int v1, v2; + //int sumi1 = 0, sumi2 = 0; + //for (int j = 0; j < 2; ++j) { + // get_int_from_table_16(q4[j], values, v1, v2); + // sumi1 = __dp4a(v1, q8[j+0], sumi1); + // sumi2 = __dp4a(v2, q8[j+4], sumi2); + //} + //return d * (sumi1 + sumi2); +#else + assert(false); + return 0.f; +#endif +#else + assert(false); + return 0.f; +#endif +} + template static __device__ __forceinline__ void mul_mat_q( @@ -5883,10 +6307,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons } // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } + tmp = warp_reduce_sum(tmp); if (tid == 0) { #ifdef GGML_CUDA_F16 @@ -5936,10 +6357,7 @@ static __global__ void mul_mat_p021_f16_f32( const int idst = channel*nrows_dst + row_dst; // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } + tmp = warp_reduce_sum(tmp); if (threadIdx.x == 0) { dst[idst] = tmp; @@ -5982,10 +6400,7 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous } // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } + tmp = warp_reduce_sum(tmp); if (threadIdx.x == 0) { dst[idst] = tmp; @@ -6996,6 +7411,12 @@ static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k, dequantize_block_iq2_xs<<>>(vx, y); } +template +static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_iq2_s<<>>(vx, y); +} + template static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; @@ -7020,6 +7441,12 @@ static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, dequantize_block_iq4_nl<<>>(vx, y); } +template +static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = (k + QK_K - 1) / QK_K; + dequantize_block_iq4_xs<<>>(vx, y); +} + template static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; @@ -7057,12 +7484,16 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq2_xxs_cuda; case GGML_TYPE_IQ2_XS: return dequantize_row_iq2_xs_cuda; + case GGML_TYPE_IQ2_S: + return dequantize_row_iq2_s_cuda; case GGML_TYPE_IQ3_XXS: return dequantize_row_iq3_xxs_cuda; case GGML_TYPE_IQ1_S: return dequantize_row_iq1_s_cuda; case GGML_TYPE_IQ4_NL: return dequantize_row_iq4_nl_cuda; + case GGML_TYPE_IQ4_XS: + return dequantize_row_iq4_xs_cuda; case GGML_TYPE_IQ3_S: return dequantize_row_iq3_s_cuda; case GGML_TYPE_F32: @@ -7098,12 +7529,16 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq2_xxs_cuda; case GGML_TYPE_IQ2_XS: return dequantize_row_iq2_xs_cuda; + case GGML_TYPE_IQ2_S: + return dequantize_row_iq2_s_cuda; case GGML_TYPE_IQ3_XXS: return dequantize_row_iq3_xxs_cuda; case GGML_TYPE_IQ1_S: return dequantize_row_iq1_s_cuda; case GGML_TYPE_IQ4_NL: return dequantize_row_iq4_nl_cuda; + case GGML_TYPE_IQ4_XS: + return dequantize_row_iq4_xs_cuda; case GGML_TYPE_IQ3_S: return dequantize_row_iq3_s_cuda; case GGML_TYPE_F16: @@ -8079,8 +8514,8 @@ static void * ggml_cuda_pool_malloc_leg(int device, size_t size, size_t * actual *actual_size = look_ahead_size; g_cuda_pool_size[device] += look_ahead_size; #ifdef DEBUG_CUDA_MALLOC - fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz, - (uint32_t)(max_size/1024/1024), (uint32_t)(g_cuda_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024)); + fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, device, nnz, + (uint32_t)(max_size/1024/1024), (uint32_t)(g_cuda_pool_size[device]/1024/1024), (uint32_t)(size/1024/1024)); #endif return ptr; } @@ -8166,7 +8601,7 @@ static void * ggml_cuda_pool_malloc_vmm(int device, size_t size, size_t * actual g_cuda_pool_used[device] += size; #ifdef DEBUG_CUDA_MALLOC - printf("cuda pool[%d]: allocated %llu bytes at %llx [%s]\n", id, (unsigned long long) size, ptr); + printf("cuda pool[%d]: allocated %llu bytes at %llx\n", device, (unsigned long long) size, ptr); #endif return ptr; @@ -8176,7 +8611,7 @@ static void ggml_cuda_pool_free_vmm(int device, void * ptr, size_t size) { scoped_spin_lock lock(g_cuda_pool_lock); #ifdef DEBUG_CUDA_MALLOC - printf("cuda pool[%d]: freed %llu bytes at %llx\n", id, (unsigned long long) size, ptr); + printf("cuda pool[%d]: freed %llu bytes at %llx\n", device, (unsigned long long) size, ptr); #endif g_cuda_pool_used[device] -= size; @@ -8848,9 +9283,11 @@ static int64_t get_row_rounding(ggml_type type, const std::array= CC_RDNA2 ? 128 : 64; default: @@ -8874,9 +9311,11 @@ static int64_t get_row_rounding(ggml_type type, const std::array= CC_VOLTA ? 128 : 64; case GGML_TYPE_Q6_K: @@ -8971,6 +9410,10 @@ static void ggml_cuda_op_mul_mat_vec_q( mul_mat_vec_q_cuda (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; + case GGML_TYPE_IQ2_S: + mul_mat_vec_q_cuda + (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + break; case GGML_TYPE_IQ3_XXS: mul_mat_vec_q_cuda (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); @@ -8983,6 +9426,10 @@ static void ggml_cuda_op_mul_mat_vec_q( mul_mat_vec_q_cuda (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; + case GGML_TYPE_IQ4_XS: + mul_mat_vec_q_cuda + (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + break; case GGML_TYPE_IQ3_S: mul_mat_vec_q_cuda (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); @@ -11710,7 +12157,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons } ggml_type a_type = a->type; if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS || - a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S) { + a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S || + a_type == GGML_TYPE_IQ2_S || a_type == GGML_TYPE_IQ4_XS) { if (b->ne[1] == 1 && ggml_nrows(b) > 1) { return false; } diff --git a/ggml-metal.m b/ggml-metal.m index 3d6b01263..9eba2f5d2 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -62,8 +62,10 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, GGML_METAL_KERNEL_TYPE_RMS_NORM, GGML_METAL_KERNEL_TYPE_GROUP_NORM, @@ -87,8 +89,10 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, @@ -108,8 +112,10 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, @@ -126,8 +132,10 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, @@ -144,8 +152,10 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, GGML_METAL_KERNEL_TYPE_ROPE_F32, GGML_METAL_KERNEL_TYPE_ROPE_F16, GGML_METAL_KERNEL_TYPE_ALIBI_F32, @@ -458,8 +468,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); @@ -483,8 +495,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); @@ -504,8 +518,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); @@ -522,8 +538,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); @@ -540,8 +558,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); @@ -1358,8 +1378,10 @@ static bool ggml_metal_graph_compute( case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break; case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break; case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break; + case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break; case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; + case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break; default: GGML_ASSERT(false && "MUL MAT-MAT not implemented"); } @@ -1500,6 +1522,12 @@ static bool ggml_metal_graph_compute( nth1 = 16; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline; } break; + case GGML_TYPE_IQ2_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline; + } break; case GGML_TYPE_IQ1_S: { nth0 = 4; @@ -1512,6 +1540,12 @@ static bool ggml_metal_graph_compute( nth1 = 16; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline; } break; + case GGML_TYPE_IQ4_XS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline; + } break; default: { GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t); @@ -1544,9 +1578,9 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&r2 length:sizeof(r2) atIndex:17]; [encoder setBytes:&r3 length:sizeof(r3) atIndex:18]; - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || - src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || - src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S) { // || src0t == GGML_TYPE_Q4_K) { + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || + src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || + src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ2_S) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { @@ -1559,7 +1593,7 @@ static bool ggml_metal_graph_compute( [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - else if (src0t == GGML_TYPE_IQ4_NL) { + else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) { const int mem_size = 32*sizeof(float); [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; @@ -1658,8 +1692,10 @@ static bool ggml_metal_graph_compute( case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break; case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break; case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break; + case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break; case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; + case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break; default: GGML_ASSERT(false && "MUL_MAT_ID not implemented"); } @@ -1803,6 +1839,12 @@ static bool ggml_metal_graph_compute( nth1 = 16; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline; } break; + case GGML_TYPE_IQ2_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline; + } break; case GGML_TYPE_IQ1_S: { nth0 = 4; @@ -1815,6 +1857,12 @@ static bool ggml_metal_graph_compute( nth1 = 16; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline; } break; + case GGML_TYPE_IQ4_XS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; + } break; default: { GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t); @@ -1863,9 +1911,9 @@ static bool ggml_metal_graph_compute( [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j]; } - if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 || - src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 || - src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S) { // || src2t == GGML_TYPE_Q4_K) { + if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 || + src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 || + src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ2_S) { [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) { @@ -1878,7 +1926,7 @@ static bool ggml_metal_graph_compute( [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - else if (src2t == GGML_TYPE_IQ4_NL) { + else if (src2t == GGML_TYPE_IQ4_NL || src2t == GGML_TYPE_IQ4_XS) { const int mem_size = 32*sizeof(float); [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; @@ -1925,8 +1973,10 @@ static bool ggml_metal_graph_compute( case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break; case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break; case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break; + case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break; case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break; + case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break; case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break; default: GGML_ASSERT(false && "not implemented"); } diff --git a/ggml-metal.metal b/ggml-metal.metal index b3bf40539..689411903 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2519,6 +2519,14 @@ typedef struct { } block_iq2_xs; // 74 bytes / block for QK_K = 256, so 2.3125 bpw +// 2.5625 bpw quants +typedef struct { + half d; + uint8_t qs[QK_K/4]; + uint8_t qh[QK_K/32]; + uint8_t scales[QK_K/32]; +} block_iq2_s; + typedef struct { half d; uint8_t qs[3*QK_K/8]; @@ -2552,6 +2560,13 @@ typedef struct { uint8_t qs[QK4_NL/2]; } block_iq4_nl; +typedef struct { + half d; + uint16_t scales_h; + uint8_t scales_l[QK_K/64]; + uint8_t qs[QK_K/2]; +} block_iq4_xs; + //====================================== dot products ========================= void kernel_mul_mv_q2_K_f32_impl( @@ -3774,6 +3789,265 @@ constexpr constant static uint64_t iq2xs_grid[512] = { 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b, }; +constexpr constant static uint64_t iq2s_grid[1024] = { + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b, + 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919, + 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b, + 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919, + 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x08080808192b192b, + 0x08080808192b2b19, 0x080808082b080808, 0x080808082b08082b, 0x080808082b081919, + 0x080808082b082b08, 0x080808082b190819, 0x080808082b191908, 0x080808082b2b0808, + 0x080808082b2b1919, 0x080808082b2b2b2b, 0x0808081908080819, 0x0808081908081908, + 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, 0x080808190819082b, + 0x0808081908191919, 0x0808081908192b08, 0x08080819082b0819, 0x08080819082b1908, + 0x0808081919080808, 0x080808191908082b, 0x0808081919081919, 0x0808081919082b08, + 0x0808081919190819, 0x0808081919191908, 0x080808191919192b, 0x0808081919192b19, + 0x08080819192b0808, 0x08080819192b1919, 0x08080819192b2b08, 0x080808192b080819, + 0x080808192b081908, 0x080808192b190808, 0x080808192b19082b, 0x080808192b191919, + 0x080808192b2b0819, 0x080808192b2b1908, 0x0808082b08080808, 0x0808082b0808082b, + 0x0808082b08081919, 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, + 0x0808082b082b0808, 0x0808082b082b2b2b, 0x0808082b19080819, 0x0808082b19081908, + 0x0808082b1908192b, 0x0808082b19082b19, 0x0808082b19190808, 0x0808082b19191919, + 0x0808082b2b080808, 0x0808082b2b081919, 0x0808082b2b082b2b, 0x0808082b2b191908, + 0x0808082b2b2b082b, 0x0808190808080819, 0x0808190808081908, 0x080819080808192b, + 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, 0x0808190808191919, + 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, 0x08081908082b192b, + 0x08081908082b2b19, 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, + 0x0808190819082b08, 0x0808190819082b2b, 0x0808190819190819, 0x0808190819191908, + 0x080819081919192b, 0x0808190819192b19, 0x08081908192b0808, 0x08081908192b082b, + 0x08081908192b1919, 0x080819082b080819, 0x080819082b081908, 0x080819082b08192b, + 0x080819082b082b19, 0x080819082b190808, 0x080819082b191919, 0x080819082b192b08, + 0x080819082b2b0819, 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, + 0x0808191908081919, 0x0808191908082b08, 0x0808191908082b2b, 0x0808191908190819, + 0x0808191908191908, 0x080819190819192b, 0x0808191908192b19, 0x08081919082b0808, + 0x08081919082b1919, 0x08081919082b2b08, 0x0808191919080819, 0x0808191919081908, + 0x080819191908192b, 0x0808191919082b19, 0x0808191919190808, 0x080819191919082b, + 0x0808191919191919, 0x0808191919192b08, 0x08081919192b0819, 0x08081919192b1908, + 0x080819192b080808, 0x080819192b08082b, 0x080819192b081919, 0x080819192b082b08, + 0x080819192b190819, 0x080819192b191908, 0x080819192b2b0808, 0x0808192b08080819, + 0x0808192b08081908, 0x0808192b0808192b, 0x0808192b08082b19, 0x0808192b08190808, + 0x0808192b08191919, 0x0808192b19080808, 0x0808192b19081919, 0x0808192b19082b08, + 0x0808192b19190819, 0x0808192b19191908, 0x0808192b192b0808, 0x0808192b2b080819, + 0x0808192b2b081908, 0x0808192b2b190808, 0x08082b0808080808, 0x08082b080808082b, + 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808190819, 0x08082b0808191908, + 0x08082b080819192b, 0x08082b0808192b19, 0x08082b08082b0808, 0x08082b08082b1919, + 0x08082b08082b2b2b, 0x08082b0819080819, 0x08082b0819081908, 0x08082b081908192b, + 0x08082b0819082b19, 0x08082b0819190808, 0x08082b081919082b, 0x08082b0819191919, + 0x08082b0819192b08, 0x08082b08192b0819, 0x08082b08192b1908, 0x08082b082b080808, + 0x08082b082b081919, 0x08082b082b191908, 0x08082b082b2b2b2b, 0x08082b1908080819, + 0x08082b1908081908, 0x08082b1908190808, 0x08082b190819082b, 0x08082b1908191919, + 0x08082b1908192b08, 0x08082b19082b0819, 0x08082b1919080808, 0x08082b1919081919, + 0x08082b1919082b08, 0x08082b1919190819, 0x08082b1919191908, 0x08082b19192b0808, + 0x08082b192b080819, 0x08082b192b190808, 0x08082b2b08080808, 0x08082b2b08190819, + 0x08082b2b08191908, 0x08082b2b082b082b, 0x08082b2b082b2b08, 0x08082b2b082b2b2b, + 0x08082b2b19190808, 0x08082b2b2b192b19, 0x0819080808080819, 0x0819080808081908, + 0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, + 0x0819080808191919, 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, + 0x08190808082b192b, 0x0819080819080808, 0x081908081908082b, 0x0819080819081919, + 0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x081908081919192b, + 0x0819080819192b19, 0x08190808192b0808, 0x08190808192b082b, 0x08190808192b1919, + 0x08190808192b2b08, 0x081908082b080819, 0x081908082b081908, 0x081908082b08192b, + 0x081908082b190808, 0x081908082b191919, 0x081908082b192b08, 0x081908082b2b0819, + 0x081908082b2b1908, 0x0819081908080808, 0x081908190808082b, 0x0819081908081919, + 0x0819081908082b08, 0x0819081908082b2b, 0x0819081908190819, 0x0819081908191908, + 0x081908190819192b, 0x0819081908192b19, 0x08190819082b0808, 0x08190819082b082b, + 0x08190819082b1919, 0x08190819082b2b08, 0x0819081919080819, 0x0819081919081908, + 0x081908191908192b, 0x0819081919082b19, 0x0819081919190808, 0x081908191919082b, + 0x0819081919191919, 0x0819081919192b08, 0x08190819192b0819, 0x08190819192b1908, + 0x081908192b080808, 0x081908192b08082b, 0x081908192b081919, 0x081908192b082b08, + 0x081908192b190819, 0x081908192b191908, 0x0819082b08080819, 0x0819082b08081908, + 0x0819082b08082b19, 0x0819082b08190808, 0x0819082b08191919, 0x0819082b082b0819, + 0x0819082b082b1908, 0x0819082b19080808, 0x0819082b19081919, 0x0819082b19190819, + 0x0819082b19191908, 0x0819082b2b080819, 0x0819082b2b081908, 0x0819082b2b190808, + 0x0819190808080808, 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08, + 0x0819190808190819, 0x0819190808191908, 0x081919080819192b, 0x0819190808192b19, + 0x08191908082b0808, 0x08191908082b1919, 0x08191908082b2b08, 0x0819190819080819, + 0x0819190819081908, 0x081919081908192b, 0x0819190819082b19, 0x0819190819190808, + 0x081919081919082b, 0x0819190819191919, 0x0819190819192b08, 0x08191908192b0819, + 0x08191908192b1908, 0x081919082b080808, 0x081919082b08082b, 0x081919082b081919, + 0x081919082b082b08, 0x081919082b190819, 0x081919082b191908, 0x081919082b2b0808, + 0x0819191908080819, 0x0819191908081908, 0x081919190808192b, 0x0819191908082b19, + 0x0819191908190808, 0x081919190819082b, 0x0819191908191919, 0x0819191908192b08, + 0x08191919082b0819, 0x08191919082b1908, 0x0819191919080808, 0x081919191908082b, + 0x0819191919081919, 0x0819191919082b08, 0x0819191919190819, 0x0819191919191908, + 0x08191919192b0808, 0x081919192b080819, 0x081919192b081908, 0x081919192b190808, + 0x0819192b08080808, 0x0819192b08081919, 0x0819192b08082b08, 0x0819192b08190819, + 0x0819192b08191908, 0x0819192b082b0808, 0x0819192b19080819, 0x0819192b19081908, + 0x0819192b19190808, 0x0819192b2b080808, 0x0819192b2b2b2b2b, 0x08192b0808080819, + 0x08192b0808081908, 0x08192b080808192b, 0x08192b0808082b19, 0x08192b0808190808, + 0x08192b0808191919, 0x08192b0808192b08, 0x08192b08082b0819, 0x08192b0819080808, + 0x08192b081908082b, 0x08192b0819081919, 0x08192b0819082b08, 0x08192b0819190819, + 0x08192b0819191908, 0x08192b08192b0808, 0x08192b082b080819, 0x08192b082b081908, + 0x08192b1908080808, 0x08192b190808082b, 0x08192b1908081919, 0x08192b1908082b08, + 0x08192b1908190819, 0x08192b1908191908, 0x08192b19082b0808, 0x08192b1919080819, + 0x08192b1919081908, 0x08192b1919190808, 0x08192b19192b2b19, 0x08192b192b2b082b, + 0x08192b2b08081908, 0x08192b2b08190808, 0x08192b2b19080808, 0x08192b2b1919192b, + 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, 0x082b080808082b08, + 0x082b080808190819, 0x082b080808191908, 0x082b08080819192b, 0x082b080808192b19, + 0x082b0808082b0808, 0x082b0808082b1919, 0x082b0808082b2b2b, 0x082b080819080819, + 0x082b080819081908, 0x082b080819190808, 0x082b08081919082b, 0x082b080819191919, + 0x082b0808192b1908, 0x082b08082b080808, 0x082b08082b082b2b, 0x082b08082b191908, + 0x082b08082b2b2b2b, 0x082b081908080819, 0x082b081908081908, 0x082b081908190808, + 0x082b08190819082b, 0x082b081908191919, 0x082b0819082b0819, 0x082b081919080808, + 0x082b08191908082b, 0x082b081919081919, 0x082b081919190819, 0x082b081919191908, + 0x082b0819192b0808, 0x082b08192b080819, 0x082b08192b081908, 0x082b08192b190808, + 0x082b082b08080808, 0x082b082b08082b2b, 0x082b082b082b082b, 0x082b082b082b2b08, + 0x082b082b082b2b2b, 0x082b082b19081908, 0x082b082b19190808, 0x082b082b2b082b08, + 0x082b082b2b082b2b, 0x082b082b2b2b2b08, 0x082b190808080819, 0x082b190808081908, + 0x082b19080808192b, 0x082b190808082b19, 0x082b190808190808, 0x082b190808191919, + 0x082b190808192b08, 0x082b1908082b0819, 0x082b1908082b1908, 0x082b190819080808, + 0x082b19081908082b, 0x082b190819081919, 0x082b190819082b08, 0x082b190819190819, + 0x082b190819191908, 0x082b1908192b0808, 0x082b19082b080819, 0x082b19082b081908, + 0x082b19082b190808, 0x082b191908080808, 0x082b191908081919, 0x082b191908082b08, + 0x082b191908190819, 0x082b191908191908, 0x082b1919082b0808, 0x082b191919080819, + 0x082b191919081908, 0x082b191919190808, 0x082b1919192b192b, 0x082b19192b080808, + 0x082b192b08080819, 0x082b192b08081908, 0x082b192b08190808, 0x082b192b19080808, + 0x082b192b19192b19, 0x082b2b0808080808, 0x082b2b0808081919, 0x082b2b0808190819, + 0x082b2b0808191908, 0x082b2b0819080819, 0x082b2b0819081908, 0x082b2b0819190808, + 0x082b2b082b082b2b, 0x082b2b082b2b2b2b, 0x082b2b1908080819, 0x082b2b1908081908, + 0x082b2b1908190808, 0x082b2b192b191919, 0x082b2b2b08082b2b, 0x082b2b2b082b082b, + 0x082b2b2b192b1908, 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, + 0x1908080808081908, 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, + 0x190808080819082b, 0x1908080808191919, 0x1908080808192b08, 0x1908080808192b2b, + 0x19080808082b0819, 0x19080808082b1908, 0x19080808082b192b, 0x1908080819080808, + 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, 0x1908080819082b2b, + 0x1908080819190819, 0x1908080819191908, 0x190808081919192b, 0x1908080819192b19, + 0x19080808192b0808, 0x19080808192b082b, 0x19080808192b1919, 0x190808082b080819, + 0x190808082b081908, 0x190808082b190808, 0x190808082b191919, 0x190808082b192b08, + 0x190808082b2b0819, 0x190808082b2b1908, 0x1908081908080808, 0x190808190808082b, + 0x1908081908081919, 0x1908081908082b08, 0x1908081908190819, 0x1908081908191908, + 0x190808190819192b, 0x1908081908192b19, 0x19080819082b0808, 0x19080819082b082b, + 0x19080819082b1919, 0x1908081919080819, 0x1908081919081908, 0x190808191908192b, + 0x1908081919082b19, 0x1908081919190808, 0x190808191919082b, 0x1908081919191919, + 0x1908081919192b08, 0x19080819192b0819, 0x19080819192b1908, 0x190808192b080808, + 0x190808192b08082b, 0x190808192b081919, 0x190808192b082b08, 0x190808192b190819, + 0x190808192b191908, 0x190808192b2b0808, 0x1908082b08080819, 0x1908082b08081908, + 0x1908082b08190808, 0x1908082b0819082b, 0x1908082b08191919, 0x1908082b08192b08, + 0x1908082b082b1908, 0x1908082b19080808, 0x1908082b19081919, 0x1908082b19082b08, + 0x1908082b19190819, 0x1908082b19191908, 0x1908082b192b0808, 0x1908082b2b080819, + 0x1908082b2b081908, 0x1908190808080808, 0x190819080808082b, 0x1908190808081919, + 0x1908190808082b08, 0x1908190808082b2b, 0x1908190808190819, 0x1908190808191908, + 0x190819080819192b, 0x1908190808192b19, 0x19081908082b0808, 0x19081908082b082b, + 0x19081908082b1919, 0x19081908082b2b08, 0x1908190819080819, 0x1908190819081908, + 0x190819081908192b, 0x1908190819082b19, 0x1908190819190808, 0x190819081919082b, + 0x1908190819191919, 0x1908190819192b08, 0x19081908192b0819, 0x19081908192b1908, + 0x190819082b080808, 0x190819082b08082b, 0x190819082b081919, 0x190819082b082b08, + 0x190819082b190819, 0x190819082b191908, 0x190819082b2b0808, 0x1908191908080819, + 0x1908191908081908, 0x190819190808192b, 0x1908191908082b19, 0x1908191908190808, + 0x190819190819082b, 0x1908191908191919, 0x1908191908192b08, 0x19081919082b0819, + 0x19081919082b1908, 0x1908191919080808, 0x190819191908082b, 0x1908191919081919, + 0x1908191919082b08, 0x1908191919190819, 0x1908191919191908, 0x19081919192b0808, + 0x19081919192b2b2b, 0x190819192b080819, 0x190819192b081908, 0x190819192b190808, + 0x1908192b08080808, 0x1908192b0808082b, 0x1908192b08081919, 0x1908192b08082b08, + 0x1908192b08190819, 0x1908192b08191908, 0x1908192b082b0808, 0x1908192b19080819, + 0x1908192b19081908, 0x1908192b19190808, 0x1908192b2b080808, 0x1908192b2b2b1919, + 0x19082b0808080819, 0x19082b0808081908, 0x19082b0808082b19, 0x19082b0808190808, + 0x19082b080819082b, 0x19082b0808191919, 0x19082b0808192b08, 0x19082b08082b0819, + 0x19082b08082b1908, 0x19082b0819080808, 0x19082b081908082b, 0x19082b0819081919, + 0x19082b0819082b08, 0x19082b0819190819, 0x19082b0819191908, 0x19082b08192b0808, + 0x19082b082b081908, 0x19082b082b190808, 0x19082b1908080808, 0x19082b190808082b, + 0x19082b1908081919, 0x19082b1908082b08, 0x19082b1908190819, 0x19082b1908191908, + 0x19082b19082b0808, 0x19082b1919080819, 0x19082b1919081908, 0x19082b1919190808, + 0x19082b192b080808, 0x19082b192b19192b, 0x19082b2b08080819, 0x19082b2b08081908, + 0x19082b2b08190808, 0x19082b2b19080808, 0x1919080808080808, 0x191908080808082b, + 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, 0x1919080808191908, + 0x191908080819192b, 0x1919080808192b19, 0x19190808082b0808, 0x19190808082b082b, + 0x19190808082b1919, 0x19190808082b2b08, 0x1919080819080819, 0x1919080819081908, + 0x191908081908192b, 0x1919080819082b19, 0x1919080819190808, 0x191908081919082b, + 0x1919080819191919, 0x1919080819192b08, 0x19190808192b0819, 0x19190808192b1908, + 0x191908082b080808, 0x191908082b08082b, 0x191908082b081919, 0x191908082b082b08, + 0x191908082b190819, 0x191908082b191908, 0x1919081908080819, 0x1919081908081908, + 0x191908190808192b, 0x1919081908082b19, 0x1919081908190808, 0x191908190819082b, + 0x1919081908191919, 0x1919081908192b08, 0x19190819082b0819, 0x19190819082b1908, + 0x1919081919080808, 0x191908191908082b, 0x1919081919081919, 0x1919081919082b08, + 0x1919081919190819, 0x1919081919191908, 0x19190819192b0808, 0x191908192b080819, + 0x191908192b081908, 0x191908192b190808, 0x1919082b08080808, 0x1919082b08081919, + 0x1919082b08082b08, 0x1919082b08190819, 0x1919082b08191908, 0x1919082b082b0808, + 0x1919082b19080819, 0x1919082b19081908, 0x1919082b19190808, 0x1919082b192b2b19, + 0x1919082b2b080808, 0x1919190808080819, 0x1919190808081908, 0x191919080808192b, + 0x1919190808082b19, 0x1919190808190808, 0x191919080819082b, 0x1919190808191919, + 0x1919190808192b08, 0x19191908082b0819, 0x19191908082b1908, 0x1919190819080808, + 0x191919081908082b, 0x1919190819081919, 0x1919190819082b08, 0x1919190819190819, + 0x1919190819191908, 0x19191908192b0808, 0x191919082b080819, 0x191919082b081908, + 0x191919082b190808, 0x1919191908080808, 0x191919190808082b, 0x1919191908081919, + 0x1919191908082b08, 0x1919191908190819, 0x1919191908191908, 0x19191919082b0808, + 0x1919191919080819, 0x1919191919081908, 0x1919191919190808, 0x191919192b080808, + 0x1919192b08080819, 0x1919192b08081908, 0x1919192b08190808, 0x1919192b082b192b, + 0x1919192b19080808, 0x19192b0808080808, 0x19192b080808082b, 0x19192b0808081919, + 0x19192b0808082b08, 0x19192b0808190819, 0x19192b0808191908, 0x19192b08082b0808, + 0x19192b0819080819, 0x19192b0819081908, 0x19192b0819190808, 0x19192b0819192b2b, + 0x19192b082b080808, 0x19192b1908080819, 0x19192b1908081908, 0x19192b1908190808, + 0x19192b1919080808, 0x19192b2b08080808, 0x19192b2b08192b19, 0x19192b2b2b081919, + 0x19192b2b2b2b2b08, 0x192b080808080819, 0x192b080808081908, 0x192b08080808192b, + 0x192b080808190808, 0x192b08080819082b, 0x192b080808191919, 0x192b080808192b08, + 0x192b0808082b0819, 0x192b0808082b1908, 0x192b080819080808, 0x192b080819081919, + 0x192b080819082b08, 0x192b080819190819, 0x192b080819191908, 0x192b0808192b0808, + 0x192b08082b081908, 0x192b08082b190808, 0x192b081908080808, 0x192b08190808082b, + 0x192b081908081919, 0x192b081908082b08, 0x192b081908190819, 0x192b081908191908, + 0x192b0819082b0808, 0x192b081919080819, 0x192b081919081908, 0x192b081919190808, + 0x192b08192b080808, 0x192b08192b192b19, 0x192b082b08081908, 0x192b082b08190808, + 0x192b082b19080808, 0x192b082b1919192b, 0x192b082b2b2b0819, 0x192b190808080808, + 0x192b190808081919, 0x192b190808082b08, 0x192b190808190819, 0x192b190808191908, + 0x192b1908082b0808, 0x192b190819080819, 0x192b190819081908, 0x192b190819190808, + 0x192b19082b080808, 0x192b191908080819, 0x192b191908081908, 0x192b191908190808, + 0x192b191919080808, 0x192b191919082b2b, 0x192b1919192b2b08, 0x192b19192b19082b, + 0x192b192b08080808, 0x192b192b2b191908, 0x192b2b0808080819, 0x192b2b0808081908, + 0x192b2b0808190808, 0x192b2b08192b1919, 0x192b2b082b192b08, 0x192b2b1908080808, + 0x192b2b19082b2b2b, 0x192b2b2b1908082b, 0x192b2b2b2b2b0819, 0x2b08080808080808, + 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819, + 0x2b08080808191908, 0x2b08080808192b19, 0x2b080808082b0808, 0x2b080808082b1919, + 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808081919082b, + 0x2b08080819191919, 0x2b08080819192b08, 0x2b080808192b0819, 0x2b0808082b080808, + 0x2b0808082b081919, 0x2b0808082b190819, 0x2b0808082b191908, 0x2b08081908080819, + 0x2b08081908081908, 0x2b08081908082b19, 0x2b08081908190808, 0x2b0808190819082b, + 0x2b08081908191919, 0x2b08081908192b08, 0x2b080819082b0819, 0x2b080819082b1908, + 0x2b08081919080808, 0x2b0808191908082b, 0x2b08081919081919, 0x2b08081919082b08, + 0x2b08081919190819, 0x2b08081919191908, 0x2b0808192b080819, 0x2b0808192b081908, + 0x2b0808192b190808, 0x2b0808192b2b2b19, 0x2b08082b08080808, 0x2b08082b08081919, + 0x2b08082b08082b2b, 0x2b08082b08190819, 0x2b08082b08191908, 0x2b08082b19080819, + 0x2b08082b19081908, 0x2b08082b19190808, 0x2b08190808080819, 0x2b08190808081908, + 0x2b0819080808192b, 0x2b08190808082b19, 0x2b08190808190808, 0x2b0819080819082b, + 0x2b08190808191919, 0x2b08190808192b08, 0x2b081908082b0819, 0x2b08190819080808, + 0x2b0819081908082b, 0x2b08190819081919, 0x2b08190819082b08, 0x2b08190819190819, + 0x2b08190819191908, 0x2b081908192b0808, 0x2b0819082b080819, 0x2b0819082b081908, + 0x2b0819082b190808, 0x2b08191908080808, 0x2b0819190808082b, 0x2b08191908081919, + 0x2b08191908082b08, 0x2b08191908190819, 0x2b08191908191908, 0x2b081919082b0808, + 0x2b08191919080819, 0x2b08191919081908, 0x2b08191919190808, 0x2b0819192b080808, + 0x2b0819192b082b2b, 0x2b08192b08080819, 0x2b08192b08081908, 0x2b08192b08190808, + 0x2b08192b082b2b19, 0x2b08192b19080808, 0x2b082b0808080808, 0x2b082b0808081919, + 0x2b082b0808190819, 0x2b082b0808191908, 0x2b082b0819080819, 0x2b082b0819081908, + 0x2b082b0819190808, 0x2b082b082b2b082b, 0x2b082b1908080819, 0x2b082b1908081908, + 0x2b082b1919080808, 0x2b082b19192b1919, 0x2b082b2b082b082b, 0x2b082b2b19192b08, + 0x2b082b2b19192b2b, 0x2b082b2b2b08082b, 0x2b082b2b2b2b082b, 0x2b19080808080819, + 0x2b19080808081908, 0x2b19080808082b19, 0x2b19080808190808, 0x2b1908080819082b, + 0x2b19080808191919, 0x2b19080808192b08, 0x2b190808082b1908, 0x2b19080819080808, + 0x2b1908081908082b, 0x2b19080819081919, 0x2b19080819082b08, 0x2b19080819190819, + 0x2b19080819191908, 0x2b190808192b0808, 0x2b1908082b080819, 0x2b1908082b081908, + 0x2b1908082b190808, 0x2b19081908080808, 0x2b19081908081919, 0x2b19081908190819, + 0x2b19081908191908, 0x2b19081919080819, 0x2b19081919081908, 0x2b19081919190808, + 0x2b19081919192b2b, 0x2b19082b08080819, 0x2b19082b08081908, 0x2b19082b08190808, + 0x2b19082b19080808, 0x2b19082b2b2b192b, 0x2b19190808080808, 0x2b1919080808082b, + 0x2b19190808081919, 0x2b19190808082b08, 0x2b19190808190819, 0x2b19190808191908, + 0x2b191908082b0808, 0x2b19190819080819, 0x2b19190819081908, 0x2b19190819190808, + 0x2b1919082b080808, 0x2b1919082b19192b, 0x2b19191908080819, 0x2b19191908081908, + 0x2b19191908190808, 0x2b19191919080808, 0x2b1919192b192b08, 0x2b1919192b2b0819, + 0x2b19192b08080808, 0x2b19192b1908192b, 0x2b19192b192b1908, 0x2b192b0808080819, + 0x2b192b0808081908, 0x2b192b0808190808, 0x2b192b08082b192b, 0x2b192b0819080808, + 0x2b192b082b2b2b19, 0x2b192b1908080808, 0x2b192b1919082b19, 0x2b192b191919082b, + 0x2b192b2b2b190808, 0x2b2b080808080808, 0x2b2b080808081919, 0x2b2b080808082b2b, + 0x2b2b080808191908, 0x2b2b0808082b082b, 0x2b2b0808082b2b2b, 0x2b2b080819080819, + 0x2b2b080819081908, 0x2b2b080819190808, 0x2b2b08082b2b082b, 0x2b2b08082b2b2b2b, + 0x2b2b081919080808, 0x2b2b0819192b1919, 0x2b2b082b0808082b, 0x2b2b082b08082b2b, + 0x2b2b082b082b082b, 0x2b2b082b082b2b08, 0x2b2b082b082b2b2b, 0x2b2b082b2b08082b, + 0x2b2b082b2b082b08, 0x2b2b082b2b082b2b, 0x2b2b082b2b2b2b08, 0x2b2b190808080819, + 0x2b2b190808081908, 0x2b2b190808190808, 0x2b2b190819080808, 0x2b2b19082b082b19, + 0x2b2b19082b2b1908, 0x2b2b191908080808, 0x2b2b191908192b19, 0x2b2b192b19190819, + 0x2b2b2b0808082b2b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b082b, 0x2b2b2b1919191908, + 0x2b2b2b192b08192b, 0x2b2b2b2b08082b08, 0x2b2b2b2b08082b2b, 0x2b2b2b2b082b0808, + 0x2b2b2b2b082b082b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b2b082b08, 0x2b2b2b2b2b2b2b2b, +}; + constexpr constant static uint32_t iq3xxs_grid[256] = { 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414, 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14, @@ -4572,6 +4846,139 @@ kernel void kernel_mul_mv_iq3_s_f32( kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } +void kernel_mul_mv_iq2_s_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq2_s * x = (device const block_iq2_s *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + //threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; + //{ + // int nval = 32; + // int pos = (32*sgitg + tiisg)*nval; + // for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i]; + // threadgroup_barrier(mem_flags::mem_threadgroup); + //} + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq2_s * xr = x + ibl; + device const uint8_t * qs = xr->qs + 4 * ib; + device const uint8_t * qh = xr->qh + ib; + device const uint8_t * sc = xr->scales + ib; + device const uint8_t * signs = qs + QK_K/8; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + const float d1 = db * (0.5f + (sc[0] & 0xf)); + const float d2 = db * (0.5f + (sc[0] >> 4)); + + float2 sum = {0}; + for (int l = 0; l < 2; ++l) { + //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); + //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); + constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); + constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); + for (int j = 0; j < 8; ++j) { + sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]); + sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]); + } + } + sumf[row] += d1 * sum[0] + d2 * sum[1]; + + dh += nb*sizeof(block_iq2_s)/2; + qs += nb*sizeof(block_iq2_s); + qh += nb*sizeof(block_iq2_s); + sc += nb*sizeof(block_iq2_s); + signs += nb*sizeof(block_iq2_s); + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + } + } +} + +[[host_name("kernel_mul_mv_iq2_s_f32")]] +kernel void kernel_mul_mv_iq2_s_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + void kernel_mul_mv_iq1_s_f32_impl( device const void * src0, device const float * src1, @@ -4760,6 +5167,100 @@ void kernel_mul_mv_iq4_nl_f32_impl( } } +void kernel_mul_mv_iq4_xs_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup float * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = (r0 * 2 + sgitg) * 2; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + const int ix = tiisg/16; // 0 or 1 + const int it = tiisg%16; // 0...15 + const int ib = it/2; + const int il = it%2; + + shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + float4 yl[4]; + float sumf[2]={0.f}, all_sum; + + device const float * yb = y + ix * QK_K + ib * 32 + il * 8; + + uint32_t aux32[2]; + thread const uint8_t * q8 = (thread const uint8_t *)aux32; + + float4 qf1, qf2; + + for (int ibl = ix; ibl < nb; ibl += 2) { + + device const float4 * y4 = (device const float4 *)yb; + yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; + + for (int row = 0; row < 2; ++row) { + + device const block_iq4_xs & xb = x[row*nb + ibl]; + device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); + + float4 acc1 = {0.f}, acc2 = {0.f}; + + aux32[0] = q4[0] & 0x0f0f0f0f; + aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[0] * qf1; + acc2 += yl[1] * qf2; + + aux32[0] = q4[1] & 0x0f0f0f0f; + aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[2] * qf1; + acc2 += yl[3] * qf2; + + acc1 += acc2; + + const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32; + sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); + + } + + yb += 2 * QK_K; + } + + for (int row = 0; row < 2; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + [[host_name("kernel_mul_mv_iq1_s_f32")]] kernel void kernel_mul_mv_iq1_s_f32( device const void * src0, @@ -4817,6 +5318,35 @@ kernel void kernel_mul_mv_iq4_nl_f32( kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } +[[host_name("kernel_mul_mv_iq4_xs_f32")]] +kernel void kernel_mul_mv_iq4_xs_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup float * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + //============================= templates and their specializations ============================= // NOTE: this is not dequantizing - we are simply fitting the template @@ -5188,6 +5718,25 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & } } +template +void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint8_t * signs = qs + QK_K/8; + const uint8_t qh = xb->qh[ib32] >> 4*il; + const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; + constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300))); + constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300))); + for (int i = 0; i < 8; ++i) { + reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]); + reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]); + } +} + template void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) { // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 @@ -5219,6 +5768,26 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 } } +template +void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32; + const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4); + const float d = (float)xb->d * (ls - 32); + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + for (int i = 0; i < 4; ++i) { + aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f; + reg[i][0] = d * kvalues_iq4nl_f[q8[0]]; + reg[i][1] = d * kvalues_iq4nl_f[q8[1]]; + reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; + reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; + } +} + template kernel void kernel_get_rows( device const void * src0, @@ -5762,8 +6331,10 @@ template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_r template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows; // // matrix-matrix multiplication @@ -5804,8 +6375,10 @@ template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm; // // indirect matrix-matrix multiplication @@ -5858,8 +6431,10 @@ template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; // // matrix-vector multiplication @@ -6893,6 +7468,71 @@ kernel void kernel_mul_mv_id_iq3_s_f32( sgitg); } +[[host_name("kernel_mul_mv_id_iq2_s_f32")]] +kernel void kernel_mul_mv_id_iq2_s_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_iq2_s_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + shared_values, + tgpig, + tiisg, + sgitg); +} + [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel void kernel_mul_mv_id_iq1_s_f32( device const char * ids, @@ -7020,3 +7660,68 @@ kernel void kernel_mul_mv_id_iq4_nl_f32( tiisg, sgitg); } + +[[host_name("kernel_mul_mv_id_iq4_xs_f32")]] +kernel void kernel_mul_mv_id_iq4_xs_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + threadgroup float * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_iq4_xs_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + shared_values, + tgpig, + tiisg, + sgitg); +} diff --git a/ggml-quants.c b/ggml-quants.c index 3d94d166d..f73d17ce2 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -3495,6 +3495,265 @@ static const uint64_t iq2xs_grid[512] = { 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b, }; +static const uint64_t iq2s_grid[1024] = { + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b, + 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919, + 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b, + 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919, + 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x08080808192b192b, + 0x08080808192b2b19, 0x080808082b080808, 0x080808082b08082b, 0x080808082b081919, + 0x080808082b082b08, 0x080808082b190819, 0x080808082b191908, 0x080808082b2b0808, + 0x080808082b2b1919, 0x080808082b2b2b2b, 0x0808081908080819, 0x0808081908081908, + 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, 0x080808190819082b, + 0x0808081908191919, 0x0808081908192b08, 0x08080819082b0819, 0x08080819082b1908, + 0x0808081919080808, 0x080808191908082b, 0x0808081919081919, 0x0808081919082b08, + 0x0808081919190819, 0x0808081919191908, 0x080808191919192b, 0x0808081919192b19, + 0x08080819192b0808, 0x08080819192b1919, 0x08080819192b2b08, 0x080808192b080819, + 0x080808192b081908, 0x080808192b190808, 0x080808192b19082b, 0x080808192b191919, + 0x080808192b2b0819, 0x080808192b2b1908, 0x0808082b08080808, 0x0808082b0808082b, + 0x0808082b08081919, 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, + 0x0808082b082b0808, 0x0808082b082b2b2b, 0x0808082b19080819, 0x0808082b19081908, + 0x0808082b1908192b, 0x0808082b19082b19, 0x0808082b19190808, 0x0808082b19191919, + 0x0808082b2b080808, 0x0808082b2b081919, 0x0808082b2b082b2b, 0x0808082b2b191908, + 0x0808082b2b2b082b, 0x0808190808080819, 0x0808190808081908, 0x080819080808192b, + 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, 0x0808190808191919, + 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, 0x08081908082b192b, + 0x08081908082b2b19, 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, + 0x0808190819082b08, 0x0808190819082b2b, 0x0808190819190819, 0x0808190819191908, + 0x080819081919192b, 0x0808190819192b19, 0x08081908192b0808, 0x08081908192b082b, + 0x08081908192b1919, 0x080819082b080819, 0x080819082b081908, 0x080819082b08192b, + 0x080819082b082b19, 0x080819082b190808, 0x080819082b191919, 0x080819082b192b08, + 0x080819082b2b0819, 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, + 0x0808191908081919, 0x0808191908082b08, 0x0808191908082b2b, 0x0808191908190819, + 0x0808191908191908, 0x080819190819192b, 0x0808191908192b19, 0x08081919082b0808, + 0x08081919082b1919, 0x08081919082b2b08, 0x0808191919080819, 0x0808191919081908, + 0x080819191908192b, 0x0808191919082b19, 0x0808191919190808, 0x080819191919082b, + 0x0808191919191919, 0x0808191919192b08, 0x08081919192b0819, 0x08081919192b1908, + 0x080819192b080808, 0x080819192b08082b, 0x080819192b081919, 0x080819192b082b08, + 0x080819192b190819, 0x080819192b191908, 0x080819192b2b0808, 0x0808192b08080819, + 0x0808192b08081908, 0x0808192b0808192b, 0x0808192b08082b19, 0x0808192b08190808, + 0x0808192b08191919, 0x0808192b19080808, 0x0808192b19081919, 0x0808192b19082b08, + 0x0808192b19190819, 0x0808192b19191908, 0x0808192b192b0808, 0x0808192b2b080819, + 0x0808192b2b081908, 0x0808192b2b190808, 0x08082b0808080808, 0x08082b080808082b, + 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808190819, 0x08082b0808191908, + 0x08082b080819192b, 0x08082b0808192b19, 0x08082b08082b0808, 0x08082b08082b1919, + 0x08082b08082b2b2b, 0x08082b0819080819, 0x08082b0819081908, 0x08082b081908192b, + 0x08082b0819082b19, 0x08082b0819190808, 0x08082b081919082b, 0x08082b0819191919, + 0x08082b0819192b08, 0x08082b08192b0819, 0x08082b08192b1908, 0x08082b082b080808, + 0x08082b082b081919, 0x08082b082b191908, 0x08082b082b2b2b2b, 0x08082b1908080819, + 0x08082b1908081908, 0x08082b1908190808, 0x08082b190819082b, 0x08082b1908191919, + 0x08082b1908192b08, 0x08082b19082b0819, 0x08082b1919080808, 0x08082b1919081919, + 0x08082b1919082b08, 0x08082b1919190819, 0x08082b1919191908, 0x08082b19192b0808, + 0x08082b192b080819, 0x08082b192b190808, 0x08082b2b08080808, 0x08082b2b08190819, + 0x08082b2b08191908, 0x08082b2b082b082b, 0x08082b2b082b2b08, 0x08082b2b082b2b2b, + 0x08082b2b19190808, 0x08082b2b2b192b19, 0x0819080808080819, 0x0819080808081908, + 0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, + 0x0819080808191919, 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, + 0x08190808082b192b, 0x0819080819080808, 0x081908081908082b, 0x0819080819081919, + 0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x081908081919192b, + 0x0819080819192b19, 0x08190808192b0808, 0x08190808192b082b, 0x08190808192b1919, + 0x08190808192b2b08, 0x081908082b080819, 0x081908082b081908, 0x081908082b08192b, + 0x081908082b190808, 0x081908082b191919, 0x081908082b192b08, 0x081908082b2b0819, + 0x081908082b2b1908, 0x0819081908080808, 0x081908190808082b, 0x0819081908081919, + 0x0819081908082b08, 0x0819081908082b2b, 0x0819081908190819, 0x0819081908191908, + 0x081908190819192b, 0x0819081908192b19, 0x08190819082b0808, 0x08190819082b082b, + 0x08190819082b1919, 0x08190819082b2b08, 0x0819081919080819, 0x0819081919081908, + 0x081908191908192b, 0x0819081919082b19, 0x0819081919190808, 0x081908191919082b, + 0x0819081919191919, 0x0819081919192b08, 0x08190819192b0819, 0x08190819192b1908, + 0x081908192b080808, 0x081908192b08082b, 0x081908192b081919, 0x081908192b082b08, + 0x081908192b190819, 0x081908192b191908, 0x0819082b08080819, 0x0819082b08081908, + 0x0819082b08082b19, 0x0819082b08190808, 0x0819082b08191919, 0x0819082b082b0819, + 0x0819082b082b1908, 0x0819082b19080808, 0x0819082b19081919, 0x0819082b19190819, + 0x0819082b19191908, 0x0819082b2b080819, 0x0819082b2b081908, 0x0819082b2b190808, + 0x0819190808080808, 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08, + 0x0819190808190819, 0x0819190808191908, 0x081919080819192b, 0x0819190808192b19, + 0x08191908082b0808, 0x08191908082b1919, 0x08191908082b2b08, 0x0819190819080819, + 0x0819190819081908, 0x081919081908192b, 0x0819190819082b19, 0x0819190819190808, + 0x081919081919082b, 0x0819190819191919, 0x0819190819192b08, 0x08191908192b0819, + 0x08191908192b1908, 0x081919082b080808, 0x081919082b08082b, 0x081919082b081919, + 0x081919082b082b08, 0x081919082b190819, 0x081919082b191908, 0x081919082b2b0808, + 0x0819191908080819, 0x0819191908081908, 0x081919190808192b, 0x0819191908082b19, + 0x0819191908190808, 0x081919190819082b, 0x0819191908191919, 0x0819191908192b08, + 0x08191919082b0819, 0x08191919082b1908, 0x0819191919080808, 0x081919191908082b, + 0x0819191919081919, 0x0819191919082b08, 0x0819191919190819, 0x0819191919191908, + 0x08191919192b0808, 0x081919192b080819, 0x081919192b081908, 0x081919192b190808, + 0x0819192b08080808, 0x0819192b08081919, 0x0819192b08082b08, 0x0819192b08190819, + 0x0819192b08191908, 0x0819192b082b0808, 0x0819192b19080819, 0x0819192b19081908, + 0x0819192b19190808, 0x0819192b2b080808, 0x0819192b2b2b2b2b, 0x08192b0808080819, + 0x08192b0808081908, 0x08192b080808192b, 0x08192b0808082b19, 0x08192b0808190808, + 0x08192b0808191919, 0x08192b0808192b08, 0x08192b08082b0819, 0x08192b0819080808, + 0x08192b081908082b, 0x08192b0819081919, 0x08192b0819082b08, 0x08192b0819190819, + 0x08192b0819191908, 0x08192b08192b0808, 0x08192b082b080819, 0x08192b082b081908, + 0x08192b1908080808, 0x08192b190808082b, 0x08192b1908081919, 0x08192b1908082b08, + 0x08192b1908190819, 0x08192b1908191908, 0x08192b19082b0808, 0x08192b1919080819, + 0x08192b1919081908, 0x08192b1919190808, 0x08192b19192b2b19, 0x08192b192b2b082b, + 0x08192b2b08081908, 0x08192b2b08190808, 0x08192b2b19080808, 0x08192b2b1919192b, + 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, 0x082b080808082b08, + 0x082b080808190819, 0x082b080808191908, 0x082b08080819192b, 0x082b080808192b19, + 0x082b0808082b0808, 0x082b0808082b1919, 0x082b0808082b2b2b, 0x082b080819080819, + 0x082b080819081908, 0x082b080819190808, 0x082b08081919082b, 0x082b080819191919, + 0x082b0808192b1908, 0x082b08082b080808, 0x082b08082b082b2b, 0x082b08082b191908, + 0x082b08082b2b2b2b, 0x082b081908080819, 0x082b081908081908, 0x082b081908190808, + 0x082b08190819082b, 0x082b081908191919, 0x082b0819082b0819, 0x082b081919080808, + 0x082b08191908082b, 0x082b081919081919, 0x082b081919190819, 0x082b081919191908, + 0x082b0819192b0808, 0x082b08192b080819, 0x082b08192b081908, 0x082b08192b190808, + 0x082b082b08080808, 0x082b082b08082b2b, 0x082b082b082b082b, 0x082b082b082b2b08, + 0x082b082b082b2b2b, 0x082b082b19081908, 0x082b082b19190808, 0x082b082b2b082b08, + 0x082b082b2b082b2b, 0x082b082b2b2b2b08, 0x082b190808080819, 0x082b190808081908, + 0x082b19080808192b, 0x082b190808082b19, 0x082b190808190808, 0x082b190808191919, + 0x082b190808192b08, 0x082b1908082b0819, 0x082b1908082b1908, 0x082b190819080808, + 0x082b19081908082b, 0x082b190819081919, 0x082b190819082b08, 0x082b190819190819, + 0x082b190819191908, 0x082b1908192b0808, 0x082b19082b080819, 0x082b19082b081908, + 0x082b19082b190808, 0x082b191908080808, 0x082b191908081919, 0x082b191908082b08, + 0x082b191908190819, 0x082b191908191908, 0x082b1919082b0808, 0x082b191919080819, + 0x082b191919081908, 0x082b191919190808, 0x082b1919192b192b, 0x082b19192b080808, + 0x082b192b08080819, 0x082b192b08081908, 0x082b192b08190808, 0x082b192b19080808, + 0x082b192b19192b19, 0x082b2b0808080808, 0x082b2b0808081919, 0x082b2b0808190819, + 0x082b2b0808191908, 0x082b2b0819080819, 0x082b2b0819081908, 0x082b2b0819190808, + 0x082b2b082b082b2b, 0x082b2b082b2b2b2b, 0x082b2b1908080819, 0x082b2b1908081908, + 0x082b2b1908190808, 0x082b2b192b191919, 0x082b2b2b08082b2b, 0x082b2b2b082b082b, + 0x082b2b2b192b1908, 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, + 0x1908080808081908, 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, + 0x190808080819082b, 0x1908080808191919, 0x1908080808192b08, 0x1908080808192b2b, + 0x19080808082b0819, 0x19080808082b1908, 0x19080808082b192b, 0x1908080819080808, + 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, 0x1908080819082b2b, + 0x1908080819190819, 0x1908080819191908, 0x190808081919192b, 0x1908080819192b19, + 0x19080808192b0808, 0x19080808192b082b, 0x19080808192b1919, 0x190808082b080819, + 0x190808082b081908, 0x190808082b190808, 0x190808082b191919, 0x190808082b192b08, + 0x190808082b2b0819, 0x190808082b2b1908, 0x1908081908080808, 0x190808190808082b, + 0x1908081908081919, 0x1908081908082b08, 0x1908081908190819, 0x1908081908191908, + 0x190808190819192b, 0x1908081908192b19, 0x19080819082b0808, 0x19080819082b082b, + 0x19080819082b1919, 0x1908081919080819, 0x1908081919081908, 0x190808191908192b, + 0x1908081919082b19, 0x1908081919190808, 0x190808191919082b, 0x1908081919191919, + 0x1908081919192b08, 0x19080819192b0819, 0x19080819192b1908, 0x190808192b080808, + 0x190808192b08082b, 0x190808192b081919, 0x190808192b082b08, 0x190808192b190819, + 0x190808192b191908, 0x190808192b2b0808, 0x1908082b08080819, 0x1908082b08081908, + 0x1908082b08190808, 0x1908082b0819082b, 0x1908082b08191919, 0x1908082b08192b08, + 0x1908082b082b1908, 0x1908082b19080808, 0x1908082b19081919, 0x1908082b19082b08, + 0x1908082b19190819, 0x1908082b19191908, 0x1908082b192b0808, 0x1908082b2b080819, + 0x1908082b2b081908, 0x1908190808080808, 0x190819080808082b, 0x1908190808081919, + 0x1908190808082b08, 0x1908190808082b2b, 0x1908190808190819, 0x1908190808191908, + 0x190819080819192b, 0x1908190808192b19, 0x19081908082b0808, 0x19081908082b082b, + 0x19081908082b1919, 0x19081908082b2b08, 0x1908190819080819, 0x1908190819081908, + 0x190819081908192b, 0x1908190819082b19, 0x1908190819190808, 0x190819081919082b, + 0x1908190819191919, 0x1908190819192b08, 0x19081908192b0819, 0x19081908192b1908, + 0x190819082b080808, 0x190819082b08082b, 0x190819082b081919, 0x190819082b082b08, + 0x190819082b190819, 0x190819082b191908, 0x190819082b2b0808, 0x1908191908080819, + 0x1908191908081908, 0x190819190808192b, 0x1908191908082b19, 0x1908191908190808, + 0x190819190819082b, 0x1908191908191919, 0x1908191908192b08, 0x19081919082b0819, + 0x19081919082b1908, 0x1908191919080808, 0x190819191908082b, 0x1908191919081919, + 0x1908191919082b08, 0x1908191919190819, 0x1908191919191908, 0x19081919192b0808, + 0x19081919192b2b2b, 0x190819192b080819, 0x190819192b081908, 0x190819192b190808, + 0x1908192b08080808, 0x1908192b0808082b, 0x1908192b08081919, 0x1908192b08082b08, + 0x1908192b08190819, 0x1908192b08191908, 0x1908192b082b0808, 0x1908192b19080819, + 0x1908192b19081908, 0x1908192b19190808, 0x1908192b2b080808, 0x1908192b2b2b1919, + 0x19082b0808080819, 0x19082b0808081908, 0x19082b0808082b19, 0x19082b0808190808, + 0x19082b080819082b, 0x19082b0808191919, 0x19082b0808192b08, 0x19082b08082b0819, + 0x19082b08082b1908, 0x19082b0819080808, 0x19082b081908082b, 0x19082b0819081919, + 0x19082b0819082b08, 0x19082b0819190819, 0x19082b0819191908, 0x19082b08192b0808, + 0x19082b082b081908, 0x19082b082b190808, 0x19082b1908080808, 0x19082b190808082b, + 0x19082b1908081919, 0x19082b1908082b08, 0x19082b1908190819, 0x19082b1908191908, + 0x19082b19082b0808, 0x19082b1919080819, 0x19082b1919081908, 0x19082b1919190808, + 0x19082b192b080808, 0x19082b192b19192b, 0x19082b2b08080819, 0x19082b2b08081908, + 0x19082b2b08190808, 0x19082b2b19080808, 0x1919080808080808, 0x191908080808082b, + 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, 0x1919080808191908, + 0x191908080819192b, 0x1919080808192b19, 0x19190808082b0808, 0x19190808082b082b, + 0x19190808082b1919, 0x19190808082b2b08, 0x1919080819080819, 0x1919080819081908, + 0x191908081908192b, 0x1919080819082b19, 0x1919080819190808, 0x191908081919082b, + 0x1919080819191919, 0x1919080819192b08, 0x19190808192b0819, 0x19190808192b1908, + 0x191908082b080808, 0x191908082b08082b, 0x191908082b081919, 0x191908082b082b08, + 0x191908082b190819, 0x191908082b191908, 0x1919081908080819, 0x1919081908081908, + 0x191908190808192b, 0x1919081908082b19, 0x1919081908190808, 0x191908190819082b, + 0x1919081908191919, 0x1919081908192b08, 0x19190819082b0819, 0x19190819082b1908, + 0x1919081919080808, 0x191908191908082b, 0x1919081919081919, 0x1919081919082b08, + 0x1919081919190819, 0x1919081919191908, 0x19190819192b0808, 0x191908192b080819, + 0x191908192b081908, 0x191908192b190808, 0x1919082b08080808, 0x1919082b08081919, + 0x1919082b08082b08, 0x1919082b08190819, 0x1919082b08191908, 0x1919082b082b0808, + 0x1919082b19080819, 0x1919082b19081908, 0x1919082b19190808, 0x1919082b192b2b19, + 0x1919082b2b080808, 0x1919190808080819, 0x1919190808081908, 0x191919080808192b, + 0x1919190808082b19, 0x1919190808190808, 0x191919080819082b, 0x1919190808191919, + 0x1919190808192b08, 0x19191908082b0819, 0x19191908082b1908, 0x1919190819080808, + 0x191919081908082b, 0x1919190819081919, 0x1919190819082b08, 0x1919190819190819, + 0x1919190819191908, 0x19191908192b0808, 0x191919082b080819, 0x191919082b081908, + 0x191919082b190808, 0x1919191908080808, 0x191919190808082b, 0x1919191908081919, + 0x1919191908082b08, 0x1919191908190819, 0x1919191908191908, 0x19191919082b0808, + 0x1919191919080819, 0x1919191919081908, 0x1919191919190808, 0x191919192b080808, + 0x1919192b08080819, 0x1919192b08081908, 0x1919192b08190808, 0x1919192b082b192b, + 0x1919192b19080808, 0x19192b0808080808, 0x19192b080808082b, 0x19192b0808081919, + 0x19192b0808082b08, 0x19192b0808190819, 0x19192b0808191908, 0x19192b08082b0808, + 0x19192b0819080819, 0x19192b0819081908, 0x19192b0819190808, 0x19192b0819192b2b, + 0x19192b082b080808, 0x19192b1908080819, 0x19192b1908081908, 0x19192b1908190808, + 0x19192b1919080808, 0x19192b2b08080808, 0x19192b2b08192b19, 0x19192b2b2b081919, + 0x19192b2b2b2b2b08, 0x192b080808080819, 0x192b080808081908, 0x192b08080808192b, + 0x192b080808190808, 0x192b08080819082b, 0x192b080808191919, 0x192b080808192b08, + 0x192b0808082b0819, 0x192b0808082b1908, 0x192b080819080808, 0x192b080819081919, + 0x192b080819082b08, 0x192b080819190819, 0x192b080819191908, 0x192b0808192b0808, + 0x192b08082b081908, 0x192b08082b190808, 0x192b081908080808, 0x192b08190808082b, + 0x192b081908081919, 0x192b081908082b08, 0x192b081908190819, 0x192b081908191908, + 0x192b0819082b0808, 0x192b081919080819, 0x192b081919081908, 0x192b081919190808, + 0x192b08192b080808, 0x192b08192b192b19, 0x192b082b08081908, 0x192b082b08190808, + 0x192b082b19080808, 0x192b082b1919192b, 0x192b082b2b2b0819, 0x192b190808080808, + 0x192b190808081919, 0x192b190808082b08, 0x192b190808190819, 0x192b190808191908, + 0x192b1908082b0808, 0x192b190819080819, 0x192b190819081908, 0x192b190819190808, + 0x192b19082b080808, 0x192b191908080819, 0x192b191908081908, 0x192b191908190808, + 0x192b191919080808, 0x192b191919082b2b, 0x192b1919192b2b08, 0x192b19192b19082b, + 0x192b192b08080808, 0x192b192b2b191908, 0x192b2b0808080819, 0x192b2b0808081908, + 0x192b2b0808190808, 0x192b2b08192b1919, 0x192b2b082b192b08, 0x192b2b1908080808, + 0x192b2b19082b2b2b, 0x192b2b2b1908082b, 0x192b2b2b2b2b0819, 0x2b08080808080808, + 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819, + 0x2b08080808191908, 0x2b08080808192b19, 0x2b080808082b0808, 0x2b080808082b1919, + 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808081919082b, + 0x2b08080819191919, 0x2b08080819192b08, 0x2b080808192b0819, 0x2b0808082b080808, + 0x2b0808082b081919, 0x2b0808082b190819, 0x2b0808082b191908, 0x2b08081908080819, + 0x2b08081908081908, 0x2b08081908082b19, 0x2b08081908190808, 0x2b0808190819082b, + 0x2b08081908191919, 0x2b08081908192b08, 0x2b080819082b0819, 0x2b080819082b1908, + 0x2b08081919080808, 0x2b0808191908082b, 0x2b08081919081919, 0x2b08081919082b08, + 0x2b08081919190819, 0x2b08081919191908, 0x2b0808192b080819, 0x2b0808192b081908, + 0x2b0808192b190808, 0x2b0808192b2b2b19, 0x2b08082b08080808, 0x2b08082b08081919, + 0x2b08082b08082b2b, 0x2b08082b08190819, 0x2b08082b08191908, 0x2b08082b19080819, + 0x2b08082b19081908, 0x2b08082b19190808, 0x2b08190808080819, 0x2b08190808081908, + 0x2b0819080808192b, 0x2b08190808082b19, 0x2b08190808190808, 0x2b0819080819082b, + 0x2b08190808191919, 0x2b08190808192b08, 0x2b081908082b0819, 0x2b08190819080808, + 0x2b0819081908082b, 0x2b08190819081919, 0x2b08190819082b08, 0x2b08190819190819, + 0x2b08190819191908, 0x2b081908192b0808, 0x2b0819082b080819, 0x2b0819082b081908, + 0x2b0819082b190808, 0x2b08191908080808, 0x2b0819190808082b, 0x2b08191908081919, + 0x2b08191908082b08, 0x2b08191908190819, 0x2b08191908191908, 0x2b081919082b0808, + 0x2b08191919080819, 0x2b08191919081908, 0x2b08191919190808, 0x2b0819192b080808, + 0x2b0819192b082b2b, 0x2b08192b08080819, 0x2b08192b08081908, 0x2b08192b08190808, + 0x2b08192b082b2b19, 0x2b08192b19080808, 0x2b082b0808080808, 0x2b082b0808081919, + 0x2b082b0808190819, 0x2b082b0808191908, 0x2b082b0819080819, 0x2b082b0819081908, + 0x2b082b0819190808, 0x2b082b082b2b082b, 0x2b082b1908080819, 0x2b082b1908081908, + 0x2b082b1919080808, 0x2b082b19192b1919, 0x2b082b2b082b082b, 0x2b082b2b19192b08, + 0x2b082b2b19192b2b, 0x2b082b2b2b08082b, 0x2b082b2b2b2b082b, 0x2b19080808080819, + 0x2b19080808081908, 0x2b19080808082b19, 0x2b19080808190808, 0x2b1908080819082b, + 0x2b19080808191919, 0x2b19080808192b08, 0x2b190808082b1908, 0x2b19080819080808, + 0x2b1908081908082b, 0x2b19080819081919, 0x2b19080819082b08, 0x2b19080819190819, + 0x2b19080819191908, 0x2b190808192b0808, 0x2b1908082b080819, 0x2b1908082b081908, + 0x2b1908082b190808, 0x2b19081908080808, 0x2b19081908081919, 0x2b19081908190819, + 0x2b19081908191908, 0x2b19081919080819, 0x2b19081919081908, 0x2b19081919190808, + 0x2b19081919192b2b, 0x2b19082b08080819, 0x2b19082b08081908, 0x2b19082b08190808, + 0x2b19082b19080808, 0x2b19082b2b2b192b, 0x2b19190808080808, 0x2b1919080808082b, + 0x2b19190808081919, 0x2b19190808082b08, 0x2b19190808190819, 0x2b19190808191908, + 0x2b191908082b0808, 0x2b19190819080819, 0x2b19190819081908, 0x2b19190819190808, + 0x2b1919082b080808, 0x2b1919082b19192b, 0x2b19191908080819, 0x2b19191908081908, + 0x2b19191908190808, 0x2b19191919080808, 0x2b1919192b192b08, 0x2b1919192b2b0819, + 0x2b19192b08080808, 0x2b19192b1908192b, 0x2b19192b192b1908, 0x2b192b0808080819, + 0x2b192b0808081908, 0x2b192b0808190808, 0x2b192b08082b192b, 0x2b192b0819080808, + 0x2b192b082b2b2b19, 0x2b192b1908080808, 0x2b192b1919082b19, 0x2b192b191919082b, + 0x2b192b2b2b190808, 0x2b2b080808080808, 0x2b2b080808081919, 0x2b2b080808082b2b, + 0x2b2b080808191908, 0x2b2b0808082b082b, 0x2b2b0808082b2b2b, 0x2b2b080819080819, + 0x2b2b080819081908, 0x2b2b080819190808, 0x2b2b08082b2b082b, 0x2b2b08082b2b2b2b, + 0x2b2b081919080808, 0x2b2b0819192b1919, 0x2b2b082b0808082b, 0x2b2b082b08082b2b, + 0x2b2b082b082b082b, 0x2b2b082b082b2b08, 0x2b2b082b082b2b2b, 0x2b2b082b2b08082b, + 0x2b2b082b2b082b08, 0x2b2b082b2b082b2b, 0x2b2b082b2b2b2b08, 0x2b2b190808080819, + 0x2b2b190808081908, 0x2b2b190808190808, 0x2b2b190819080808, 0x2b2b19082b082b19, + 0x2b2b19082b2b1908, 0x2b2b191908080808, 0x2b2b191908192b19, 0x2b2b192b19190819, + 0x2b2b2b0808082b2b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b082b, 0x2b2b2b1919191908, + 0x2b2b2b192b08192b, 0x2b2b2b2b08082b08, 0x2b2b2b2b08082b2b, 0x2b2b2b2b082b0808, + 0x2b2b2b2b082b082b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b2b082b08, 0x2b2b2b2b2b2b2b2b, +}; + static const uint32_t iq3xxs_grid[256] = { 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414, 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14, @@ -3796,6 +4055,38 @@ void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y, } } +// ====================== 2.5625 bpw (de)-quantization + +void dequantize_row_iq2_s(const block_iq2_s * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + float db[2]; + + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d); + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint8_t * signs = qs + QK_K/8; + + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + db[0] = d * (0.5f + (x[i].scales[ib32] & 0xf)) * 0.25f; + db[1] = d * (0.5f + (x[i].scales[ib32] >> 4)) * 0.25f; + for (int l = 0; l < 4; ++l) { + const float dl = db[l/2]; + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); + for (int j = 0; j < 8; ++j) { + y[j] = dl * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); + } + y += 8; + } + qs += 4; + signs += 4; + } + } +} + // ====================== 3.0625 bpw (de)-quantization void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y, int k) { @@ -3934,6 +4225,29 @@ void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y, } } +void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const uint8_t * qs = x[i].qs; + + const float d = GGML_FP16_TO_FP32(x[i].d); + + for (int ib = 0; ib < QK_K/32; ++ib) { + const int ls = ((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4); + const float dl = d * (ls - 32); + for (int j = 0; j < 16; ++j) { + y[j+ 0] = dl * kvalues_iq4nl[qs[j] & 0xf]; + y[j+16] = dl * kvalues_iq4nl[qs[j] >> 4]; + } + y += 32; + qs += 16; + } + } +} + //===================================== Q8_K ============================================== void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) { @@ -9330,6 +9644,210 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * #endif } +void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_s * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; + + const uint8x16x2_t mask1 = vld1q_u8_x2(k_mask1); + const uint8x16_t mask2 = vld1q_u8(k_mask2); + const uint8x16_t m1 = vdupq_n_u8(1); + const int32x4_t vzero = vdupq_n_s32(0); + + uint8x16x2_t vs; + ggml_int8x16x4_t q2s; + ggml_int8x16x4_t q8b; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); + const int8_t * restrict q8 = y[i].qs; + + int sumi1 = 0, sumi2 = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + q2s.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[0] | ((qh[ib32+0] << 8) & 0x300)))), + vld1_s8((const int8_t *)(iq2s_grid + (qs[1] | ((qh[ib32+0] << 6) & 0x300))))); + q2s.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[2] | ((qh[ib32+0] << 4) & 0x300)))), + vld1_s8((const int8_t *)(iq2s_grid + (qs[3] | ((qh[ib32+0] << 2) & 0x300))))); + q2s.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[4] | ((qh[ib32+1] << 8) & 0x300)))), + vld1_s8((const int8_t *)(iq2s_grid + (qs[5] | ((qh[ib32+1] << 6) & 0x300))))); + q2s.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[6] | ((qh[ib32+1] << 4) & 0x300)))), + vld1_s8((const int8_t *)(iq2s_grid + (qs[7] | ((qh[ib32+1] << 2) & 0x300))))); + qs += 8; + + vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16))); + vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); + vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); + vs.val[0] = vceqq_u8(vs.val[0], mask2); + vs.val[1] = vceqq_u8(vs.val[1], mask2); + + q2s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[0]); + q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]); + + vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16))); + vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); + vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); + vs.val[0] = vceqq_u8(vs.val[0], mask2); + vs.val[1] = vceqq_u8(vs.val[1], mask2); + + signs += 4; + + q2s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[2]); + q2s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[3]); + + const int32x4_t p1 = ggml_vdotq_s32(vzero, q2s.val[0], q8b.val[0]); + const int32x4_t p2 = ggml_vdotq_s32(vzero, q2s.val[1], q8b.val[1]); + const int32x4_t p3 = ggml_vdotq_s32(vzero, q2s.val[2], q8b.val[2]); + const int32x4_t p4 = ggml_vdotq_s32(vzero, q2s.val[3], q8b.val[3]); + + sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32+0] & 0xf)); + sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32+0] >> 4)); + sumi1 += vaddvq_s32(p3) * (1 + 2*(x[i].scales[ib32+1] & 0xf)); + sumi2 += vaddvq_s32(p4) * (1 + 2*(x[i].scales[ib32+1] >> 4)); + } + sumf += d*(sumi1 + sumi2); + } + + *s = 0.125f * sumf; + +#elif defined(__AVX2__) + + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m128i m4 = _mm_set1_epi8(0xf); + const __m128i m1 = _mm_set1_epi8(1); + + const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1); + const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2); + + uint64_t aux64; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); + const int8_t * restrict q8 = y[i].qs; + + memcpy(&aux64, x[i].scales, 8); + const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1); + const __m256i scales16 = _mm256_cvtepi8_epi16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15 + + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q2_1 = _mm256_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)], + iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)], + iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)], + iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]); + const __m256i q2_2 = _mm256_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)], + iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)], + iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)], + iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]); + qs += 8; + + __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); + const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2); + const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1); + + aux256 = _mm256_set1_epi32(signs[2] | (signs[3] << 16)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); + const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2); + const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2); + + signs += 4; + + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1 + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3 + + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+0))); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+1))); + sumi1 = _mm256_add_epi32(sumi1, p1); + sumi2 = _mm256_add_epi32(sumi2, p2); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + +#else + + float sumf = 0; + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint8_t * signs = qs + QK_K/8; + + int bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + int ls1 = 1 + 2*(x[i].scales[ib32] & 0xf); + int ls2 = 1 + 2*(x[i].scales[ib32] >> 4); + int sumi1 = 0, sumi2 = 0; + for (int l = 0; l < 2; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); + for (int j = 0; j < 8; ++j) { + sumi1 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + for (int l = 2; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); + for (int j = 0; j < 8; ++j) { + sumi2 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += ls1 * sumi1 + ls2 * sumi2; + qs += 4; + signs += 4; + } + + sumf += d * bsum; + } + + *s = 0.125f * sumf; + +#endif + +} + void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -9753,8 +10271,12 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const uint64_t aux64; - __m256i v_gindex; - const uint16_t * gindex = (const uint16_t *)&v_gindex; + typedef union m256i_uint16 { + __m256i reg; + uint16_t s[16]; + } m256i_uint16_t; + + m256i_uint16_t v_gindex; __m256 accum = _mm256_setzero_ps(); for (int i = 0; i < nb; ++i) { @@ -9769,13 +10291,13 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const memcpy(&aux64, sc, 8); sc += 8; const __m128i qh = _mm_shuffle_epi8(_mm_set_epi64x(aux64 >> 4, aux64), shuffle_h); const __m256i hbit = _mm256_cvtepu8_epi16(_mm_and_si128(qh, m8)); - v_gindex = _mm256_or_si256(_mm256_cvtepu8_epi16(ql), _mm256_slli_epi16(hbit, 5)); + v_gindex.reg = _mm256_or_si256(_mm256_cvtepu8_epi16(ql), _mm256_slli_epi16(hbit, 5)); const __m128i scales = _mm_or_si128(_mm_slli_epi16(_mm_and_si128(qh, m7), 1), m1); for (int i32 = 0; i32 < 4; ++i32) { const __m256i q8b = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q1b = _mm256_set_epi64x(iq1s_grid[gindex[4*i32+3]], iq1s_grid[gindex[4*i32+2]], - iq1s_grid[gindex[4*i32+1]], iq1s_grid[gindex[4*i32+0]]); + const __m256i q1b = _mm256_set_epi64x(iq1s_grid[v_gindex.s[4*i32+3]], iq1s_grid[v_gindex.s[4*i32+2]], + iq1s_grid[v_gindex.s[4*i32+1]], iq1s_grid[v_gindex.s[4*i32+0]]); const __m256i dot = mul_add_epi8(q1b, q8b); const __m256i s16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, shuffle_s[i32])); const __m256i p = _mm256_madd_epi16(s16, dot); @@ -9926,6 +10448,134 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * #endif } +void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); + + const block_iq4_xs * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined __ARM_NEON + const int8x16_t values = vld1q_s8(kvalues_iq4nl); + const uint8x16_t m4b = vdupq_n_u8(0x0f); + ggml_uint8x16x2_t q4bits; + ggml_int8x16x4_t q4b; + ggml_int8x16x4_t q8b; + int32x4_t prod_1, prod_2; + + float sumf = 0; + + for (int ibl = 0; ibl < nb; ++ibl) { + + const int8_t * q8 = y[ibl].qs; + const uint8_t * q4 = x[ibl].qs; + uint16_t h = x[ibl].scales_h; + + int sumi1 = 0, sumi2 = 0; + for (int ib = 0; ib < QK_K/64; ++ib) { + + q4bits = ggml_vld1q_u8_x2(q4); q4 += 32; + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + + q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b)); + q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4)); + q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b)); + q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4)); + + prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]); + prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]); + + int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32; + int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32; + h >>= 4; + sumi1 += vaddvq_s32(prod_1) * ls1; + sumi2 += vaddvq_s32(prod_2) * ls2; + + } + + sumf += GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2); + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); + const __m128i m4b = _mm_set1_epi8(0x0f); + + __m256 accum = _mm256_setzero_ps(); + for (int ibl = 0; ibl < nb; ++ibl) { + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + uint16_t sh = x[ibl].scales_h; + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)qs); qs += 16; + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)qs); qs += 16; + const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q4b_1 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)), + _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b))); + const __m256i q4b_2 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)), + _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b))); + const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); + const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); + const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32; + const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32; + sh >>= 4; + const __m256i p_1 = _mm256_madd_epi16(p16_1, _mm256_set1_epi16(ls1)); + const __m256i p_2 = _mm256_madd_epi16(p16_2, _mm256_set1_epi16(ls2)); + sumi1 = _mm256_add_epi32(p_1, sumi1); + sumi2 = _mm256_add_epi32(p_2, sumi2); + } + accum = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d), + _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accum); + } + + *s = hsum_float_8(accum); + +#else + float sumf = 0; + for (int ibl = 0; ibl < nb; ++ibl) { + const float d4d8 = GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d; + uint16_t h = x[ibl].scales_h; + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + for (int ib = 0; ib < QK_K/32; ib += 2) { + const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30); + const uint8_t ls2 = (x[ibl].scales_l[ib/2] >> 4) | ((h << 2) & 0x30); + h >>= 4; + const float d1 = d4d8*(ls1 - 32); + const float d2 = d4d8*(ls2 - 32); + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < 16; ++j) { + sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; + sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; + } + sumf += d1 * (sumi1 + sumi2); + qs += 16; + q8 += 32; + sumi1 = sumi2 = 0; + for (int j = 0; j < 16; ++j) { + sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; + sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; + } + sumf += d2 * (sumi1 + sumi2); + qs += 16; + q8 += 32; + } + } + *s = sumf; +#endif +} + // ================================ IQ2 quantization ============================================= typedef struct { @@ -9934,22 +10584,25 @@ typedef struct { uint16_t * neighbours; } iq2_entry_t; -static iq2_entry_t iq2_data[3] = { +static iq2_entry_t iq2_data[4] = { + {NULL, NULL, NULL}, {NULL, NULL, NULL}, {NULL, NULL, NULL}, {NULL, NULL, NULL}, }; static inline int iq2_data_index(enum ggml_type type) { - GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S); + GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ2_S); return type == GGML_TYPE_IQ2_XXS ? 0 : - type == GGML_TYPE_IQ2_XS ? 1 : 2; + type == GGML_TYPE_IQ2_XS ? 1 : + type == GGML_TYPE_IQ1_S ? 2 : 3; } static inline int iq2_grid_size(enum ggml_type type) { - GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S); + GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ2_S); return type == GGML_TYPE_IQ2_XXS ? 256 : - type == GGML_TYPE_IQ2_XS ? 512 : 512; + type == GGML_TYPE_IQ2_XS ? 512 : + type == GGML_TYPE_IQ1_S ? 512 : 1024; } static int iq2_compare_func(const void * left, const void * right) { @@ -10050,11 +10703,79 @@ void iq2xs_init_impl(enum ggml_type type) { 41557, 41633, 41989, 42021, 42056, 42068, 42074, 42113, 42242, 42265, 42274, 42325, 42340, 42402, 42501, 42512, 42533, 42624, 42632, 42666, 43040, 43093, 43106, 43168, 43176, 43264, 43286, 43345, 43429, 43590, 43618, 43680, }; + static const uint16_t kgrid_2bit_1024[1024] = { + 0, 2, 5, 8, 10, 17, 20, 22, 25, 32, 34, 37, 40, 65, 68, 70, + 73, 80, 82, 85, 88, 97, 100, 102, 105, 128, 130, 133, 136, 145, 148, 160, + 165, 170, 257, 260, 262, 265, 272, 274, 277, 280, 289, 292, 320, 322, 325, 328, + 337, 340, 342, 345, 352, 357, 360, 385, 388, 400, 402, 405, 417, 420, 512, 514, + 517, 520, 529, 532, 544, 554, 577, 580, 582, 585, 592, 597, 640, 645, 650, 660, + 674, 1025, 1028, 1030, 1033, 1040, 1042, 1045, 1048, 1057, 1060, 1062, 1065, 1088, 1090, 1093, + 1096, 1098, 1105, 1108, 1110, 1113, 1120, 1122, 1125, 1153, 1156, 1158, 1161, 1168, 1173, 1176, + 1185, 1188, 1280, 1282, 1285, 1288, 1290, 1297, 1300, 1302, 1305, 1312, 1317, 1320, 1345, 1348, + 1350, 1353, 1360, 1362, 1365, 1368, 1377, 1380, 1408, 1410, 1413, 1416, 1425, 1428, 1440, 1537, + 1540, 1542, 1545, 1552, 1557, 1600, 1605, 1608, 1617, 1620, 1632, 1665, 1668, 1680, 2048, 2050, + 2053, 2056, 2065, 2068, 2070, 2073, 2080, 2085, 2090, 2113, 2116, 2118, 2121, 2128, 2130, 2133, + 2136, 2145, 2148, 2176, 2181, 2196, 2218, 2305, 2308, 2320, 2322, 2325, 2328, 2337, 2368, 2373, + 2376, 2385, 2388, 2400, 2433, 2448, 2560, 2577, 2580, 2594, 2600, 2602, 2640, 2713, 4097, 4100, + 4102, 4105, 4112, 4114, 4117, 4120, 4129, 4132, 4134, 4160, 4162, 4165, 4168, 4177, 4180, 4182, + 4185, 4192, 4194, 4197, 4200, 4225, 4228, 4230, 4240, 4245, 4248, 4257, 4260, 4352, 4354, 4357, + 4360, 4362, 4369, 4372, 4374, 4377, 4384, 4386, 4389, 4392, 4417, 4420, 4422, 4425, 4432, 4434, + 4437, 4440, 4449, 4452, 4480, 4482, 4485, 4488, 4497, 4500, 4609, 4612, 4617, 4624, 4629, 4641, + 4644, 4672, 4677, 4689, 4692, 4737, 4740, 4752, 5120, 5122, 5125, 5128, 5137, 5140, 5142, 5145, + 5152, 5157, 5160, 5185, 5188, 5190, 5193, 5200, 5202, 5205, 5208, 5217, 5220, 5248, 5250, 5253, + 5256, 5265, 5268, 5280, 5377, 5380, 5382, 5385, 5392, 5394, 5397, 5400, 5409, 5412, 5440, 5442, + 5445, 5448, 5457, 5460, 5472, 5505, 5508, 5520, 5632, 5637, 5640, 5649, 5652, 5664, 5697, 5700, + 5712, 5760, 5802, 6145, 6148, 6150, 6153, 6160, 6165, 6168, 6177, 6208, 6210, 6213, 6216, 6225, + 6228, 6240, 6273, 6276, 6400, 6402, 6405, 6408, 6417, 6420, 6432, 6465, 6468, 6480, 6505, 6562, + 6660, 6672, 6720, 6742, 8192, 8194, 8197, 8200, 8209, 8212, 8214, 8217, 8224, 8229, 8234, 8257, + 8260, 8272, 8274, 8277, 8292, 8320, 8330, 8340, 8362, 8449, 8452, 8464, 8466, 8469, 8481, 8512, + 8514, 8517, 8529, 8532, 8544, 8577, 8580, 8592, 8704, 8714, 8738, 8744, 8746, 8772, 8784, 8840, + 8842, 8872, 9217, 9220, 9222, 9225, 9232, 9237, 9240, 9249, 9252, 9280, 9282, 9285, 9288, 9297, + 9300, 9312, 9345, 9348, 9360, 9472, 9477, 9480, 9489, 9492, 9504, 9537, 9540, 9552, 9574, 9600, + 9729, 9732, 9744, 9792, 9817, 10240, 10245, 10257, 10260, 10305, 10308, 10320, 10378, 10410, 10497, 10500, + 10512, 10645, 10762, 10786, 10852, 10888, 10890, 16385, 16388, 16390, 16393, 16400, 16402, 16405, 16408, 16410, + 16417, 16420, 16422, 16448, 16450, 16453, 16456, 16458, 16465, 16468, 16470, 16473, 16480, 16482, 16485, 16513, + 16516, 16528, 16533, 16536, 16545, 16548, 16640, 16642, 16645, 16648, 16657, 16660, 16662, 16665, 16672, 16674, + 16677, 16705, 16708, 16710, 16713, 16720, 16722, 16725, 16728, 16737, 16740, 16768, 16770, 16773, 16776, 16785, + 16788, 16800, 16897, 16900, 16912, 16914, 16917, 16920, 16932, 16960, 16965, 16968, 16977, 16980, 16992, 17025, + 17028, 17408, 17410, 17413, 17416, 17418, 17425, 17428, 17430, 17433, 17440, 17442, 17445, 17448, 17473, 17476, + 17478, 17481, 17488, 17490, 17493, 17496, 17505, 17508, 17536, 17538, 17541, 17544, 17553, 17556, 17568, 17665, + 17668, 17670, 17673, 17680, 17682, 17685, 17688, 17697, 17700, 17728, 17730, 17733, 17736, 17745, 17748, 17760, + 17770, 17793, 17796, 17808, 17920, 17922, 17925, 17928, 17937, 17940, 17952, 17985, 17988, 18000, 18048, 18085, + 18433, 18436, 18441, 18448, 18450, 18453, 18456, 18465, 18468, 18496, 18498, 18501, 18504, 18513, 18516, 18528, + 18564, 18576, 18688, 18690, 18693, 18696, 18705, 18708, 18720, 18753, 18756, 18768, 18816, 18838, 18945, 18948, + 18960, 19008, 20480, 20482, 20485, 20488, 20497, 20500, 20502, 20505, 20512, 20514, 20517, 20520, 20545, 20548, + 20550, 20553, 20560, 20562, 20565, 20568, 20577, 20580, 20608, 20610, 20613, 20616, 20625, 20628, 20737, 20740, + 20742, 20745, 20752, 20754, 20757, 20760, 20769, 20772, 20800, 20802, 20805, 20808, 20817, 20820, 20832, 20865, + 20868, 20880, 20992, 20997, 21000, 21009, 21012, 21024, 21057, 21060, 21072, 21097, 21120, 21505, 21508, 21510, + 21513, 21520, 21522, 21525, 21528, 21537, 21540, 21568, 21570, 21573, 21576, 21585, 21588, 21600, 21633, 21636, + 21648, 21760, 21762, 21765, 21768, 21777, 21780, 21792, 21825, 21828, 21840, 21888, 22017, 22020, 22032, 22054, + 22080, 22528, 22530, 22533, 22536, 22545, 22548, 22560, 22593, 22596, 22608, 22618, 22656, 22785, 22788, 22800, + 22848, 23040, 23065, 23173, 23208, 24577, 24580, 24582, 24592, 24594, 24597, 24600, 24609, 24612, 24640, 24645, + 24648, 24657, 24660, 24672, 24708, 24720, 24832, 24834, 24837, 24840, 24849, 24852, 24864, 24897, 24900, 24912, + 24960, 24985, 25092, 25104, 25152, 25174, 25249, 25600, 25605, 25608, 25617, 25620, 25632, 25665, 25668, 25680, + 25728, 25857, 25860, 25872, 25920, 25930, 25960, 26002, 26112, 26260, 26625, 26628, 26640, 26725, 26776, 26880, + 26922, 27202, 27297, 32768, 32770, 32773, 32776, 32785, 32788, 32793, 32800, 32805, 32833, 32836, 32848, 32850, + 32853, 32856, 32865, 32896, 32901, 32913, 32916, 33025, 33028, 33033, 33040, 33042, 33045, 33048, 33057, 33060, + 33088, 33090, 33093, 33096, 33105, 33108, 33153, 33156, 33168, 33193, 33280, 33285, 33290, 33297, 33300, 33345, + 33348, 33360, 33793, 33796, 33798, 33801, 33808, 33810, 33813, 33816, 33825, 33856, 33858, 33861, 33864, 33873, + 33876, 33888, 33921, 33924, 33936, 34048, 34050, 34053, 34056, 34065, 34068, 34080, 34113, 34116, 34128, 34176, + 34186, 34305, 34308, 34320, 34345, 34368, 34816, 34821, 34833, 34836, 34881, 34884, 34896, 34978, 35073, 35076, + 35136, 35173, 35362, 35416, 35418, 35458, 35490, 36865, 36868, 36873, 36880, 36882, 36885, 36888, 36900, 36928, + 36930, 36933, 36936, 36945, 36948, 36960, 36993, 36996, 37008, 37120, 37125, 37137, 37140, 37185, 37188, 37200, + 37210, 37377, 37380, 37392, 37440, 37542, 37888, 37890, 37893, 37896, 37905, 37908, 37920, 37953, 37956, 37968, + 38016, 38038, 38145, 38148, 38160, 38208, 38296, 38305, 38400, 38470, 38500, 38913, 38916, 38928, 38950, 38976, + 39081, 39168, 39241, 39250, 39568, 40960, 40965, 40970, 40980, 40994, 41002, 41025, 41028, 41040, 41122, 41130, + 41280, 41317, 41474, 41482, 41506, 41512, 41514, 41602, 41608, 41610, 41640, 41985, 41988, 42000, 42048, 42121, + 42148, 42240, 42265, 42577, 43018, 43048, 43170, 43348, 43398, 43528, 43530, 43552, 43554, 43560, 43656, 43690, + }; const int kmap_size = 43692; - const int nwant = type == GGML_TYPE_IQ1_S ? 3 : 2; + //const int nwant = type == GGML_TYPE_IQ1_S ? 3 : 2; + const int nwant = type == GGML_TYPE_IQ1_S ? 3 : type == GGML_TYPE_IQ2_S ? 1 : 2; const uint16_t * kgrid = type == GGML_TYPE_IQ2_XXS ? kgrid_2bit_256 : - type == GGML_TYPE_IQ2_XS ? kgrid_2bit_512 : kgrid_1bit_512; + type == GGML_TYPE_IQ2_XS ? kgrid_2bit_512 : + type == GGML_TYPE_IQ1_S ? kgrid_1bit_512 : kgrid_2bit_1024; uint64_t * kgrid_q2xs; int * kmap_q2xs; uint16_t * kneighbors_q2xs; @@ -10151,7 +10872,7 @@ void iq2xs_init_impl(enum ggml_type type) { } void iq2xs_free_impl(enum ggml_type type) { - GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S); + GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ2_S); const int gindex = iq2_data_index(type); if (iq2_data[gindex].grid) { free(iq2_data[gindex].grid); iq2_data[gindex].grid = NULL; @@ -11451,23 +12172,23 @@ static inline int best_index_int8(int n, const int8_t * val, float x) { return x - val[mu-1] < val[mu] - x ? mu-1 : mu; } -static void quantize_row_iq4_nl_impl(const int block_size, const float * GGML_RESTRICT x, - ggml_fp16_t * dh, uint8_t * q4, - float * weight, uint8_t * L, +static void quantize_row_iq4_nl_impl(const int super_block_size, const int block_size, const float * GGML_RESTRICT x, + ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l, + float * scales, float * weight, uint8_t * L, const int8_t * values, const float * quant_weights) { const int ntry = 7; float sigma2 = 0; - for (int j = 0; j < QK4_NL; ++j) sigma2 += x[j]*x[j]; - sigma2 *= 2.f/QK4_NL; + for (int j = 0; j < super_block_size; ++j) sigma2 += x[j]*x[j]; + sigma2 *= 2.f/super_block_size; - const int nb = QK4_NL/block_size; + memset(q4, 0, super_block_size/2); + dh[0] = GGML_FP32_TO_FP16(0.f); - memset(q4, 0, QK4_NL/2); - for (int ib = 0; ib < nb; ++ib) { - dh[ib] = GGML_FP32_TO_FP16(0.f); + float max_scale = 0, amax_scale = 0; + for (int ib = 0; ib < super_block_size/block_size; ++ib) { const float * xb = x + ib*block_size; if (quant_weights) { const float * qw = quant_weights + ib*block_size; @@ -11483,6 +12204,7 @@ static void quantize_row_iq4_nl_impl(const int block_size, const float * GGML_RE } } if (!amax) { + scales[ib] = 0; continue; } float d = -max/values[0]; @@ -11496,7 +12218,6 @@ static void quantize_row_iq4_nl_impl(const int block_size, const float * GGML_RE sumqx += w*q*xb[j]; sumq2 += w*q*q; } - float best_id = id; d = sumqx/sumq2; float best = d*sumqx; for (int itry = -ntry; itry <= ntry; ++itry) { @@ -11512,15 +12233,47 @@ static void quantize_row_iq4_nl_impl(const int block_size, const float * GGML_RE } if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { d = sumqx/sumq2; best = d * sumqx; - best_id = id; } } - dh[ib] = GGML_FP32_TO_FP16(d); - for (int j = 0; j < block_size; ++j) { - L[ib*block_size + j] = best_index_int8(16, values, best_id*xb[j]); + scales[ib] = d; + float abs_d = fabsf(d); + if (abs_d > amax_scale) { + amax_scale = abs_d; max_scale = d; } } - for (int i = 0; i < QK4_NL/32; ++i) { + + if (super_block_size/block_size > 1) { + int nb = super_block_size/block_size; + memset(scales_h, 0, ((nb+7)/8)*sizeof(uint16_t)); + float d = -max_scale/32; + dh[0] = GGML_FP32_TO_FP16(d); + float id = d ? 1/d : 0.f; + for (int ib = 0; ib < super_block_size/block_size; ++ib) { + int l = nearest_int(id*scales[ib]); + l = MAX(-32, MIN(31, l)); + float dl = d * l; + float idl = dl ? 1/dl : 0.f; + uint8_t * Lb = L + ib*block_size; + const float * xb = x + ib*block_size; + for (int j = 0; j < block_size; ++j) { + Lb[j] = best_index_int8(16, values, idl*xb[j]); + } + l += 32; + uint8_t l_l = l & 0xf; + uint8_t l_h = l >> 4; + if (ib%2 == 0) scales_l[ib/2] = l_l; + else scales_l[ib/2] |= (l_l << 4); + scales_h[ib/8] |= (l_h << 2*(ib%8)); + } + } else { + dh[0] = GGML_FP32_TO_FP16(scales[0]); + float id = scales[0] ? 1/scales[0] : 0; + for (int j = 0; j < super_block_size; ++j) { + L[j] = best_index_int8(16, values, id*x[j]); + } + } + + for (int i = 0; i < super_block_size/32; ++i) { for (int j = 0; j < 16; ++j) { q4[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4); } @@ -11533,12 +12286,16 @@ size_t quantize_iq4_nl(const float * src, void * dst, int nrow, int n_per_row, i int nblock = n_per_row/QK4_NL; char * qrow = (char *)dst; uint8_t L[QK4_NL]; - float weight[32]; + float weight[QK4_NL]; + uint16_t unused_h; + uint8_t * unused_l = NULL; + float scale; for (int row = 0; row < nrow; ++row) { block_iq4_nl * iq4 = (block_iq4_nl *)qrow; for (int ibl = 0; ibl < nblock; ++ibl) { const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL; - quantize_row_iq4_nl_impl(32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, weight, L, kvalues_iq4nl, qw); + quantize_row_iq4_nl_impl(QK4_NL, 32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l, + &scale, weight, L, kvalues_iq4nl, qw); } src += n_per_row; qrow += nblock*sizeof(block_iq4_nl); @@ -11557,3 +12314,228 @@ void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * rest quantize_iq4_nl(x, y, 1, k, NULL, NULL); } +size_t quantize_iq4_xs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) { + (void)hist; + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + uint8_t L[QK_K]; + float weight[32]; + float scales[QK_K/32]; + for (int row = 0; row < nrow; ++row) { + block_iq4_xs * iq4 = (block_iq4_xs *)qrow; + for (int ibl = 0; ibl < nblock; ++ibl) { + const float * qw = quant_weights ? quant_weights + QK_K*ibl : NULL; + quantize_row_iq4_nl_impl(QK_K, 32, src + QK_K*ibl, &iq4[ibl].d, iq4[ibl].qs, &iq4[ibl].scales_h, iq4[ibl].scales_l, + scales, weight, L, kvalues_iq4nl, qw); + } + src += n_per_row; + qrow += nblock*sizeof(block_iq4_xs); + } + return nrow * nblock * sizeof(block_iq4_xs); +} + +void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int k) { + assert(k % QK_K == 0); + block_iq4_xs * restrict y = vy; + quantize_row_iq4_xs_reference(x, y, k); +} + +void quantize_row_iq4_xs_reference(const float * restrict x, block_iq4_xs * restrict y, int k) { + assert(k % QK_K == 0); + quantize_iq4_xs(x, y, 1, k, NULL, NULL); +} + +// =============================== 2.5625 bpw + +static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) { + + const int gindex = iq2_data_index(GGML_TYPE_IQ2_S); + + const uint64_t * kgrid_q2xs = iq2_data[gindex].grid; + const int * kmap_q2xs = iq2_data[gindex].map; + const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours; + + GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(n%QK_K == 0); + + const int kMaxQ = 3; + + const int nbl = n/256; + + block_iq2_s * y = vy; + + float scales[QK_K/16]; + float weight[16]; + float xval[16]; + int8_t L[16]; + int8_t Laux[16]; + float waux[16]; + bool is_on_grid[2]; + bool is_on_grid_aux[2]; + uint8_t block_signs[2]; + + for (int ibl = 0; ibl < nbl; ++ibl) { + + memset(&y[ibl], 0, sizeof(block_iq2_s)); + y[ibl].d = GGML_FP32_TO_FP16(0.f); + + float max_scale = 0; + + const float * xbl = x + QK_K*ibl; + float sumx2 = 0; + for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + float sigma2 = 2*sumx2/QK_K; + + for (int ib = 0; ib < QK_K/16; ++ib) { + const float * xb = xbl + 16*ib; + if (quant_weights) { + const float * qw = quant_weights + QK_K*ibl + 16*ib; + for (int i = 0; i < 16; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + } else { + for (int i = 0; i < 16; ++i) weight[i] = 0.25f*sigma2 + xb[i]*xb[i]; + } + for (int i = 0; i < 16; ++i) waux[i] = sqrtf(weight[i]); + for (int k = 0; k < 2; ++k) { + uint8_t s = 0; + for (int i = 0; i < 8; ++i) { + if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i]; + else { + xval[8*k + i] = -xb[8*k + i]; s |= (1 << i); + } + } + block_signs[k] = s; + } + float max = xval[0]; + for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]); + if (!max) { + scales[ib] = 0; + continue; + } + float best = 0; + float scale = max/(2*kMaxQ-1); + is_on_grid[0] = is_on_grid[1] = true; + for (int is = -9; is <= 9; ++is) { + float id = (2*kMaxQ-1+is*0.1f)/max; + float this_scale = 1/id; + for (int k = 0; k < 2; ++k) { + for (int i = 0; i < 8; ++i) { + int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); + Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l)); + } + uint16_t u = 0; + for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i); + int grid_index = kmap_q2xs[u]; + is_on_grid_aux[k] = true; + if (grid_index < 0) { + is_on_grid_aux[k] = false; + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k); + } + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 16; ++i) { + float w = weight[i]; + float q = 2*Laux[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + scale = sumqx/sumq2; best = scale*sumqx; + for (int i = 0; i < 16; ++i) L[i] = Laux[i]; + for (int k = 0; k < 2; ++k) is_on_grid[k] = is_on_grid_aux[k]; + } + } + int n_not_ongrid = 0; + for (int k = 0; k < 2; ++k) if (!is_on_grid[k]) ++n_not_ongrid; + if (n_not_ongrid > 0 && scale > 0) { + float id = 1/scale; + for (int k = 0; k < 2; ++k) { + if (is_on_grid[k]) continue; + uint16_t u = 0; + for (int i = 0; i < 8; ++i) { + int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); + l = MAX(0, MIN(kMaxQ-1, l)); + u |= (l << 2*i); + L[8*k + i] = l; + } + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k); + } + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 16; ++i) { + float w = weight[i]; + float q = 2*L[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0) scale = sumqx/sumq2; + } + if (scale < 0) { + scale = -scale; + for (int k = 0; k < 2; ++k) block_signs[k] = ~block_signs[k]; + } + for (int k = 0; k < 2; ++k) { + uint16_t u = 0; + for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i); + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + printf("Oops: found point %u not on grid:", u); + for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]); + printf("\n"); + GGML_ASSERT(false); + } + const int i8 = 2*ib + k; + y[ibl].qs[i8] = grid_index & 255; + y[ibl].qh[i8/4] |= ((grid_index >> 8) << 2*(i8%4)); + y[ibl].qs[QK_K/8 + i8] = block_signs[k]; + } + GGML_ASSERT(scale >= 0); + scales[ib] = scale; + max_scale = MAX(max_scale, scale); + } + + if (!max_scale) { + continue; + } + + float d = max_scale/31; + y[ibl].d = GGML_FP32_TO_FP16(d * 0.9875f); + float id = 1/d; + for (int ib = 0; ib < QK_K/16; ++ib) { + int l = nearest_int(0.5f*(id*scales[ib]-1)); + l = MAX(0, MIN(15, l)); + if (ib%2 == 0) y[ibl].scales[ib/2] = l; + else y[ibl].scales[ib/2] |= (l << 4); + } + } +} + +size_t quantize_iq2_s(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) { + (void)hist; + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int row = 0; row < nrow; ++row) { + quantize_row_iq2_s_impl(src, qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += nblock*sizeof(block_iq2_s); + } + return nrow * nblock * sizeof(block_iq2_s); +} + +void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restrict y, int k) { + assert(k % QK_K == 0); + quantize_iq2_s(x, y, 1, k, NULL, NULL); +} + +void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int k) { + assert(k % QK_K == 0); + block_iq2_s * restrict y = vy; + quantize_row_iq2_s_reference(x, y, k); +} diff --git a/ggml-quants.h b/ggml-quants.h index 303b0b6f9..2c61134c4 100644 --- a/ggml-quants.h +++ b/ggml-quants.h @@ -182,6 +182,15 @@ typedef struct { } block_iq2_xs; static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding"); +// 2.5625 bpw quants +typedef struct { + ggml_fp16_t d; + uint8_t qs[QK_K/4]; + uint8_t qh[QK_K/32]; + uint8_t scales[QK_K/32]; +} block_iq2_s; +static_assert(sizeof(block_iq2_s) == sizeof(ggml_fp16_t) + QK_K/4 + QK_K/16, "wrong iq2_s block size/padding"); + // (Almost) "true" 3-bit quantization. // Due to the need to use blocks as per ggml design, it ends up using // 3.0625 bpw because of the 16-bit scale for each block of 256. @@ -221,6 +230,14 @@ typedef struct { } block_iq4_nl; static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding"); +typedef struct { + ggml_fp16_t d; + uint16_t scales_h; + uint8_t scales_l[QK_K/64]; + uint8_t qs[QK_K/2]; +} block_iq4_xs; +static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); + #ifdef __cplusplus extern "C" { #endif @@ -241,7 +258,9 @@ void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGM void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int k); void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int k); void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int k); +void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int k); void quantize_row_iq3_s_reference (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int k); +void quantize_row_iq2_s_reference (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int k); void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); @@ -258,7 +277,9 @@ void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); +void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); +void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); // Dequantization void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); @@ -276,9 +297,11 @@ void dequantize_row_q6_K(const block_q6_K * GGML_RESTRICT x, float * GGML_RESTRI void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); +void dequantize_row_iq2_s (const block_iq2_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); +void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); // Dot product @@ -295,9 +318,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq2_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); // @@ -305,9 +330,11 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const // size_t quantize_iq2_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_iq2_xs (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); +size_t quantize_iq2_s (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_iq3_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_iq1_s (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_iq4_nl (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); +size_t quantize_iq4_xs (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_iq3_s (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_q2_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_q3_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index c6c3c6e6f..835967fb6 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -8126,23 +8126,51 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX; } -static void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale, - const sycl::nd_item<3> &item_ct1, float *buf) { + +template +static void soft_max_f32(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par, + const int nrows_y, const float scale, const float max_bias, const float m0, + const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) { + const int ncols = ncols_template == 0 ? ncols_par : ncols_template; + const int tid = item_ct1.get_local_id(2); const int rowx = item_ct1.get_group(2); const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension - const int block_size = item_ct1.get_local_range(2); + const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template; const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; + float slope = 0.0f; + + // ALiBi + if (max_bias > 0.0f) { + const uint32_t h = rowx/nrows_y; // head index + + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = sycl::pow(base, float(exp)); + } + + float * vals = vals_smem ? buf + WARP_SIZE : dst + rowx*ncols; float max_val = -INFINITY; - for (int col = tid; col < ncols; col += block_size) { + for (int col0 = 0; col0 < ncols; col0 += block_size) { + const int col = col0 + tid; + + if (ncols_template == 0 && col >= ncols) { + break; + } + const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - max_val = sycl::max(max_val, x[ix] * scale + (y ? y[iy] : 0.0f)); + + const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f); + + vals[col] = val; + max_val = sycl::max(max_val, val); } // find the max value in the block @@ -8151,30 +8179,12 @@ static void soft_max_f32(const float * x, const float * y, float * dst, const in if (warp_id == 0) { buf[lane_id] = -INFINITY; } - /* - DPCT1118:12: SYCL group functions and algorithms must be encountered in - converged control flow. You may need to adjust the code. - */ - /* - DPCT1065:60: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. - */ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); if (lane_id == 0) { buf[warp_id] = max_val; } - /* - DPCT1118:13: SYCL group functions and algorithms must be encountered in - converged control flow. You may need to adjust the code. - */ - /* - DPCT1065:61: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. - */ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); max_val = buf[lane_id]; max_val = warp_reduce_max(max_val, item_ct1); @@ -8182,13 +8192,16 @@ static void soft_max_f32(const float * x, const float * y, float * dst, const in float tmp = 0.f; - for (int col = tid; col < ncols; col += block_size) { - const int ix = rowx*ncols + col; - const int iy = rowy*ncols + col; - const float val = - sycl::native::exp((x[ix] * scale + (y ? y[iy] : 0.0f)) - max_val); +#pragma unroll + for (int col0 = 0; col0 < ncols; col0 += block_size) { + const int col = col0 + tid; + if (ncols_template == 0 && col >= ncols) { + break; + } + + const float val = sycl::native::exp(vals[col] - max_val); tmp += val; - dst[ix] = val; + vals[col] = val; } // find the sum of exps in the block @@ -8197,40 +8210,29 @@ static void soft_max_f32(const float * x, const float * y, float * dst, const in if (warp_id == 0) { buf[lane_id] = 0.f; } - /* - DPCT1118:14: SYCL group functions and algorithms must be encountered in - converged control flow. You may need to adjust the code. - */ - /* - DPCT1065:62: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. - */ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); if (lane_id == 0) { buf[warp_id] = tmp; } - /* - DPCT1118:15: SYCL group functions and algorithms must be encountered in - converged control flow. You may need to adjust the code. - */ - /* - DPCT1065:63: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. - */ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); tmp = buf[lane_id]; tmp = warp_reduce_sum(tmp, item_ct1); } - const float inv_tmp = 1.f / tmp; + const float inv_sum = 1.f / tmp; - for (int col = tid; col < ncols; col += block_size) { - const int i = rowx*ncols + col; - dst[i] *= inv_tmp; +#pragma unroll + for (int col0 = 0; col0 < ncols; col0 += block_size) { + const int col = col0 + tid; + + if (ncols_template == 0 && col >= ncols) { + return; + } + + const int idst = rowx*ncols + col; + dst[idst] = vals[col] * inv_sum; } } @@ -10867,35 +10869,96 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst, }); } -static void soft_max_f32_sycl(const float *x, const float *y, float *dst, - const int ncols_x, const int nrows_x, - const int nrows_y, const float scale, +template +static void soft_max_f32_submitter(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par, + const int nrows_y, const float scale, const float max_bias, const float m0, + const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims, + const size_t n_local_scratch, dpct::queue_ptr stream) { + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor local_buf_acc(n_local_scratch, cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + soft_max_f32(x, mask, pos, dst, ncols_par, + nrows_y, scale, max_bias, m0, + m1, n_head_log2, item_ct1, + local_buf_acc.get_pointer()); + }); + }); +} + +static void soft_max_f32_sycl(const float * x, const float * mask, const float * pos, + float * dst, const int ncols_x, const int nrows_x, + const int nrows_y, const float scale, const float max_bias, dpct::queue_ptr stream) { int nth = WARP_SIZE; while (nth < ncols_x && nth < SYCL_SOFT_MAX_BLOCK_SIZE) nth *= 2; const sycl::range<3> block_dims(1, 1, nth); const sycl::range<3> block_nums(1, 1, nrows_x); - /* - DPCT1049:46: The work-group size passed to the SYCL kernel may exceed the - limit. To get the device limit, query info::device::max_work_group_size. - Adjust the work-group size if needed. - */ - stream->submit([&](sycl::handler &cgh) { - /* - DPCT1101:96: 'SYCL_SOFT_MAX_BLOCK_SIZE/WARP_SIZE' expression was - replaced with a value. Modify the code to use the original expression, - provided in comments, if it is correct. - */ - sycl::local_accessor buf_acc_ct1( - sycl::range<1>(32 /*SYCL_SOFT_MAX_BLOCK_SIZE/WARP_SIZE*/), cgh); + const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE); + static_assert(SYCL_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { - soft_max_f32(x, y, dst, ncols_x, nrows_y, scale, item_ct1, - buf_acc_ct1.get_pointer()); - }); - }); + const uint32_t n_head_kv = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + const size_t local_mem_size = stream->get_device().get_info(); + if (n_local_scratch*sizeof(float) < local_mem_size) { + switch (ncols_x) { + case 32: + soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + max_bias, m0, m1, n_head_log2, block_nums, + block_dims, n_local_scratch, stream); + break; + case 64: + soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + max_bias, m0, m1, n_head_log2, block_nums, + block_dims, n_local_scratch, stream); + break; + case 128: + soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + max_bias, m0, m1, n_head_log2, block_nums, + block_dims, n_local_scratch, stream); + break; + case 256: + soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + max_bias, m0, m1, n_head_log2, block_nums, + block_dims, n_local_scratch, stream); + break; + case 512: + soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + max_bias, m0, m1, n_head_log2, block_nums, + block_dims, n_local_scratch, stream); + break; + case 1024: + soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + max_bias, m0, m1, n_head_log2, block_nums, + block_dims, n_local_scratch, stream); + break; + case 2048: + soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + max_bias, m0, m1, n_head_log2, block_nums, + block_dims, n_local_scratch, stream); + break; + case 4096: + soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + max_bias, m0, m1, n_head_log2, block_nums, + block_dims, n_local_scratch, stream); + break; + default: + soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + max_bias, m0, m1, n_head_log2, block_nums, + block_dims, n_local_scratch, stream); + break; + } + } else { + soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + max_bias, m0, m1, n_head_log2, block_nums, + block_dims, WARP_SIZE, stream); + } } template @@ -12435,14 +12498,35 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0, const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1; + const int64_t nrows_y = src0->ne[1]; float scale = 1.0f; - memcpy(&scale, dst->op_params, sizeof(float)); + float max_bias = 0.0f; - soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); + memcpy(&scale, dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, dst->op_params + 1, sizeof(float)); - (void) dst; + // positions tensor + float * src2_dd = nullptr; + sycl_pool_alloc src2_f; + + ggml_tensor * src2 = dst->src[2]; + const bool use_src2 = src2 != nullptr; + + if (use_src2) { + const bool src2_on_device = src2->backend == GGML_BACKEND_TYPE_GPU; + + if (src2_on_device) { + ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra; + src2_dd = (float *) src2_extra->data_device[g_main_device]; + } else { + src2_dd = src2_f.alloc(ggml_nelements(src2)); + SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream)); + } + } + + soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00, + nrows_x, nrows_y, scale, max_bias, main_stream); } inline void ggml_sycl_op_scale(const ggml_tensor *src0, const ggml_tensor *src1, diff --git a/ggml.c b/ggml.c index 1d81553f4..d66db3352 100644 --- a/ggml.c +++ b/ggml.c @@ -690,6 +690,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, }, + [GGML_TYPE_IQ2_S] = { + .type_name = "iq2_s", + .blck_size = QK_K, + .type_size = sizeof(block_iq2_s), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq2_s, + .from_float = quantize_row_iq2_s, + .from_float_reference = (ggml_from_float_t)quantize_row_iq2_s_reference, + .vec_dot = ggml_vec_dot_iq2_s_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, [GGML_TYPE_IQ1_S] = { .type_name = "iq1_s", .blck_size = QK_K, @@ -714,6 +726,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, }, + [GGML_TYPE_IQ4_XS] = { + .type_name = "iq4_xs", + .blck_size = QK_K, + .type_size = sizeof(block_iq4_xs), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq4_xs, + .from_float = quantize_row_iq4_xs, + .from_float_reference = (ggml_from_float_t)quantize_row_iq4_xs_reference, + .vec_dot = ggml_vec_dot_iq4_xs_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, [GGML_TYPE_Q8_K] = { .type_name = "q8_K", .blck_size = QK_K, @@ -2316,7 +2340,9 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break; case GGML_FTYPE_MOSTLY_IQ1_S: wtype = GGML_TYPE_IQ1_S; break; case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break; + case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break; case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break; + case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break; case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break; case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break; } @@ -7751,7 +7777,9 @@ static void ggml_compute_forward_add( case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ2_S: { ggml_compute_forward_add_q_f32(params, dst); } break; @@ -8031,7 +8059,9 @@ static void ggml_compute_forward_add1( case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ2_S: { ggml_compute_forward_add1_q_f32(params, dst); } break; @@ -8156,7 +8186,9 @@ static void ggml_compute_forward_acc( case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ2_S: default: { GGML_ASSERT(false); @@ -11055,7 +11087,9 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ2_S: { ggml_compute_forward_out_prod_q_f32(params, dst); } break; @@ -11244,7 +11278,9 @@ static void ggml_compute_forward_set( case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ2_S: default: { GGML_ASSERT(false); @@ -11447,7 +11483,9 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ2_S: { ggml_compute_forward_get_rows_q(params, dst); } break; @@ -12148,7 +12186,9 @@ static void ggml_compute_forward_alibi( case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ2_S: case GGML_TYPE_Q8_K: case GGML_TYPE_I8: case GGML_TYPE_I16: @@ -12232,7 +12272,9 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ2_S: case GGML_TYPE_Q8_K: case GGML_TYPE_I8: case GGML_TYPE_I16: @@ -19482,6 +19524,7 @@ void ggml_quantize_init(enum ggml_type type) { switch (type) { case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ1_S: iq2xs_init_impl(type); break; case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break; case GGML_TYPE_IQ3_S: iq3xs_init_impl(512); break; @@ -19768,6 +19811,15 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i result = quantize_iq3_s(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix); GGML_ASSERT(result == row_size * nrows); } break; + case GGML_TYPE_IQ2_S: + { + GGML_ASSERT(start % QK_K == 0); + GGML_ASSERT(start % n_per_row == 0); + size_t start_row = start / n_per_row; + size_t row_size = ggml_row_size(type, n_per_row); + result = quantize_iq2_s(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix); + GGML_ASSERT(result == row_size * nrows); + } break; case GGML_TYPE_IQ1_S: { GGML_ASSERT(start % QK_K == 0); @@ -19786,6 +19838,15 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i result = quantize_iq4_nl(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix); GGML_ASSERT(result == row_size * nrows); } break; + case GGML_TYPE_IQ4_XS: + { + GGML_ASSERT(start % QK4_NL == 0); + GGML_ASSERT(start % n_per_row == 0); + size_t start_row = start / n_per_row; + size_t row_size = ggml_row_size(type, n_per_row); + result = quantize_iq4_xs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix); + GGML_ASSERT(result == row_size * nrows); + } break; case GGML_TYPE_F16: { size_t elemsize = sizeof(ggml_fp16_t); diff --git a/ggml.h b/ggml.h index 75fd035a4..23b768640 100644 --- a/ggml.h +++ b/ggml.h @@ -351,6 +351,8 @@ extern "C" { GGML_TYPE_IQ1_S = 19, GGML_TYPE_IQ4_NL = 20, GGML_TYPE_IQ3_S = 21, + GGML_TYPE_IQ2_S = 22, + GGML_TYPE_IQ4_XS = 23, GGML_TYPE_I8, GGML_TYPE_I16, GGML_TYPE_I32, @@ -391,6 +393,8 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ1_S = 18, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_NL = 19, // except 1d tensors GGML_FTYPE_MOSTLY_IQ3_S = 20, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ2_S = 21, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors }; // available tensor operations: diff --git a/llama.cpp b/llama.cpp index 28430254f..464e1b89b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1641,6 +1641,7 @@ struct llama_cparams { float yarn_attn_factor; float yarn_beta_fast; float yarn_beta_slow; + float defrag_thold; bool mul_mat_q; bool offload_kqv; @@ -2579,9 +2580,11 @@ struct llama_model_loader { case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break; case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break; case GGML_TYPE_IQ2_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS; break; + case GGML_TYPE_IQ2_S: ftype = LLAMA_FTYPE_MOSTLY_IQ2_S; break; case GGML_TYPE_IQ3_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS; break; case GGML_TYPE_IQ1_S: ftype = LLAMA_FTYPE_MOSTLY_IQ1_S; break; case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; + case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break; default: { @@ -2933,10 +2936,13 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K"; case LLAMA_FTYPE_MOSTLY_IQ2_XXS:return "IQ2_XXS - 2.0625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw"; - case LLAMA_FTYPE_MOSTLY_Q3_K_XS:return "Q3_K - Extra small"; + case LLAMA_FTYPE_MOSTLY_IQ2_S: return "IQ2_S - 2.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_M: return "IQ2_M - 2.7 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_XS: return "IQ3_XS - 3.3 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_XXS:return "IQ3_XXS - 3.0625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ1_S :return "IQ1_S - 1.5625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw"; @@ -4894,8 +4900,8 @@ static struct ggml_tensor * llm_build_kqv( ggml_mul_mat_set_prec(kq, GGML_PREC_F32); } -#if defined(GGML_USE_VULKAN) || defined(GGML_USE_KOMPUTE) || defined(GGML_USE_SYCL) -#pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Vulkan, Kompute, and SYCL") +#if defined(GGML_USE_VULKAN) || defined(GGML_USE_KOMPUTE) +#pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Vulkan, and Kompute") #pragma message(" Falling back to ggml_alibi(). Will become an error in Mar 2024") #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5488") if (hparams.f_max_alibi_bias > 0.0f) { @@ -5114,16 +5120,16 @@ struct llm_build_context { struct ggml_cgraph * build_defrag(const std::vector & ids) { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); - for (int i = 0; i < n_kv; ++i) { - const int id = ids[i]; + for (uint32_t i = 0; i < ids.size(); ++i) { + const uint32_t id = ids[i]; - if (i == id || id == n_kv) { + if (i == id || id == ids.size()) { continue; } - int nm = 1; + uint32_t nm = 1; - while (i + nm < n_kv && (int) ids[i + nm] == id + nm) { + while (i + nm < ids.size() && ids[i + nm] == id + nm) { nm++; } @@ -5155,6 +5161,8 @@ struct llm_build_context { i += nm - 1; } + //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes); + return gf; } @@ -7935,6 +7943,8 @@ static int llama_decode_internal( batch.seq_id = seq_id_arr.data(); } + llama_kv_cache_update(&lctx); + // if we have enough unused cells before the current head -> // better to start searching from the beginning of the cache, hoping to fill it if (kv_self.head > kv_self.used + 2*n_tokens) { @@ -7953,8 +7963,6 @@ static int llama_decode_internal( //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); - llama_kv_cache_update(&lctx); - ggml_backend_sched_reset(lctx.sched); ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data); @@ -8004,6 +8012,18 @@ static int llama_decode_internal( } } + // decide if we need to defrag the kv cache + if (cparams.defrag_thold >= 0.0f) { + const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used + n_tokens)/float(kv_self.n) : 0.0f; + + // queue defragmentation for next llama_kv_cache_update + if (fragmentation > cparams.defrag_thold) { + //LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation); + + llama_kv_cache_defrag(kv_self); + } + } + #ifdef GGML_PERF // print timing information per ggml operation (for debugging purposes) // requires GGML_PERF to be defined @@ -8095,12 +8115,16 @@ static int llama_decode_internal( static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { auto & kv_self = lctx.kv_self; + const auto & hparams = lctx.model.hparams; + + const uint32_t n_layer = hparams.n_layer; + const uint32_t n_kv = llama_kv_cache_cell_max(kv_self); const uint32_t n_used = kv_self.used; assert(n_used <= n_kv); - const int64_t t_start = ggml_time_us(); + //const int64_t t_start = ggml_time_us(); // number of cells moved uint32_t n_moves = 0; @@ -8124,15 +8148,26 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { // found a hole - fill it with data from the end of the cache - // determine the size of the hole uint32_t nh = 1; + + // determine the size of the hole while (i0 + nh < n_used && kv_self.cells[i0 + nh].is_empty()) { nh++; } - // starting from the end, find nh non-empty cells + // each move requires 6*n_layer tensors (see build_defrag) + // - source view, destination view, copy operation + // - x2 for keys and values + // + if (6*(n_moves + nh)*n_layer >= LLAMA_MAX_NODES) { + // the graph is too big, we cannot move more cells + break; + } + uint32_t nf = 0; uint32_t is = n_kv - 1; + + // starting from the end, find nh non-empty cells for (; is > i0; --is) { const auto & cell1 = kv_self.cells[is]; @@ -8153,11 +8188,17 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { nf = 0; + uint32_t i1 = is; + + // are we moving a continuous block of memory? + bool cont = false; + // go back and move the nf cells to the hole - for (uint32_t i1 = is; i1 < n_kv; ++i1) { - const auto & cell1 = kv_self.cells[i1]; + for (; i1 < n_kv; ++i1) { + auto & cell1 = kv_self.cells[i1]; if (cell1.is_empty() || ids[i1] != n_kv) { + cont = false; continue; } @@ -8167,11 +8208,23 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { // move the cell meta data kv_self.cells[i0 + nf] = cell1; - n_moves++; + // clear the old cell and move the head there + cell1 = llama_kv_cell(); + kv_self.head = n_used; + + if (!cont) { + n_moves++; + cont = true; + } + nf++; + + if (nf == nh) { + break; + } } - LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, n_kv, i0, i0 + nh); + //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh); i0 += nh - 1; } @@ -8180,15 +8233,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { return; } - LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves); + //LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves); - kv_self.head = n_used; - kv_self.used = n_used; - - // zero the rest of the cells - for (uint32_t i = n_used; i < n_kv; ++i) { - kv_self.cells[i] = llama_kv_cell(); - } + //LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*n_moves*n_layer); #if 0 // CPU defrag @@ -8200,9 +8247,6 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { // likely not worth the effort, as we have ggml_graph based defrag // - const auto & hparams = lctx.model.hparams; - - const uint32_t n_layer = hparams.n_layer; const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); @@ -8271,9 +8315,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { llama_graph_compute(lctx, gf, lctx.cparams.n_threads); #endif - const int64_t t_end = ggml_time_us(); + //const int64_t t_end = ggml_time_us(); - LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0); + //LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0); } static void llama_kv_cache_update_internal(struct llama_context & lctx) { @@ -10761,31 +10805,47 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty if (arch == LLM_ARCH_FALCON || nx % QK_K != 0) { new_type = GGML_TYPE_Q8_0; } - else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) { + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || + ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) { new_type = GGML_TYPE_Q5_K; } else if (new_type != GGML_TYPE_Q8_0) { new_type = GGML_TYPE_Q6_K; } } else if (name == "token_embd.weight") { - if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) { + if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || + ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) { new_type = GGML_TYPE_Q2_K; } - else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { - new_type = GGML_TYPE_Q4_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) { + new_type = GGML_TYPE_IQ3_S; } - } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) { + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { + new_type = GGML_TYPE_IQ3_S; + } + } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || + ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) { if (name.find("attn_v.weight") != std::string::npos) { if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_Q4_K; - else new_type = GGML_TYPE_Q2_K; + else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K; ++qs.i_attention_wv; } + else if (qs.model.hparams.n_expert == 8 && name.find("attn_k.weight") != std::string::npos) { + new_type = GGML_TYPE_Q4_K; + } else if (name.find("ffn_down") != std::string::npos) { - if (qs.i_ffn_down < qs.n_ffn_down/8) new_type = GGML_TYPE_Q2_K; + if (qs.i_ffn_down < qs.n_ffn_down/8) { + new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K; + } ++qs.i_ffn_down; } else if (name.find("attn_output.weight") != std::string::npos) { - if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) new_type = GGML_TYPE_IQ2_XXS; + if (qs.model.hparams.n_expert == 8) { + new_type = GGML_TYPE_Q5_K; + } else { + if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) new_type = GGML_TYPE_IQ2_XXS; + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) new_type = GGML_TYPE_IQ3_S; + } } } else if (name.find("attn_v.weight") != std::string::npos) { if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) { @@ -10795,7 +10855,13 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty new_type = GGML_TYPE_Q4_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { - new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : !qs.has_imatrix ? GGML_TYPE_Q3_K : GGML_TYPE_IQ3_XXS; + new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : !qs.has_imatrix ? GGML_TYPE_IQ3_S : GGML_TYPE_IQ3_XXS; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_S && qs.model.hparams.n_gqa() >= 4) { + new_type = GGML_TYPE_Q4_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) { + new_type = GGML_TYPE_Q4_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_S && qs.model.hparams.n_gqa() >= 4) { new_type = GGML_TYPE_Q4_K; @@ -10807,7 +10873,7 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; - else if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL && qs.model.hparams.n_gqa() >= 4) { + else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && qs.model.hparams.n_gqa() >= 4) { new_type = GGML_TYPE_Q5_K; } else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) && @@ -10833,13 +10899,19 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty // TODO: explore better strategies new_type = GGML_TYPE_Q8_0; } - else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS) { + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) { new_type = GGML_TYPE_IQ3_XXS; } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { + new_type = GGML_TYPE_IQ2_S; + } } else if (name.find("attn_q.weight") != std::string::npos) { - if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS) { + if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) { new_type = GGML_TYPE_IQ3_XXS; } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { + new_type = GGML_TYPE_IQ2_S; + } } else if (name.find("ffn_down") != std::string::npos) { auto info = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str()); int i_layer = info.first, n_layer = info.second; @@ -10870,8 +10942,8 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty if (use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K; } } - else if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL && !qs.has_imatrix) { - if (i_layer < n_layer/8) new_type = GGML_TYPE_Q5_K; + else if (i_layer < n_layer/8 && (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && !qs.has_imatrix) { + new_type = GGML_TYPE_Q5_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && arch != LLM_ARCH_FALCON && i_layer < n_layer/8) { @@ -10888,15 +10960,15 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty } else if (name.find("attn_output.weight") != std::string::npos) { if (arch != LLM_ARCH_FALCON) { if (qs.model.hparams.n_expert == 8) { - if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || + if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S || - ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) { + ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) { new_type = GGML_TYPE_Q5_K; } } else { if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K ) new_type = GGML_TYPE_Q3_K; - else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) new_type = GGML_TYPE_Q3_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) new_type = GGML_TYPE_IQ3_S; else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M ) new_type = GGML_TYPE_Q4_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L ) new_type = GGML_TYPE_Q5_K; else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M ) new_type = GGML_TYPE_Q4_K; @@ -10915,7 +10987,7 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty else if (name.find("ffn_gate") != std::string::npos) { auto info = layer_info(qs.i_ffn_gate, qs.n_ffn_gate, name.c_str()); int i_layer = info.first, n_layer = info.second; - if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) { + if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) { new_type = GGML_TYPE_IQ3_XXS; } ++qs.i_ffn_gate; @@ -10923,7 +10995,7 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty else if (name.find("ffn_up") != std::string::npos) { auto info = layer_info(qs.i_ffn_up, qs.n_ffn_up, name.c_str()); int i_layer = info.first, n_layer = info.second; - if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) { + if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) { new_type = GGML_TYPE_IQ3_XXS; } ++qs.i_ffn_up; @@ -10942,8 +11014,8 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty //} bool convert_incompatible_tensor = false; if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K || - new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K || - new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS || + new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K || new_type == GGML_TYPE_IQ4_XS || + new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S || new_type == GGML_TYPE_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || new_type == GGML_TYPE_IQ3_S) { int nx = tensor->ne[0]; int ny = tensor->ne[1]; @@ -10958,14 +11030,16 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty switch (new_type) { case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ1_S: case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: new_type = GGML_TYPE_IQ4_NL; break; - case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; - case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break; - case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break; + case GGML_TYPE_Q3_K: + case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; + case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; + case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break; + case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break; default: throw std::runtime_error("\nUnsupported tensor size encountered\n"); } LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type)); @@ -10991,7 +11065,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s // K-quants case LLAMA_FTYPE_MOSTLY_Q2_K_S: case LLAMA_FTYPE_MOSTLY_Q2_K: quantized_type = GGML_TYPE_Q2_K; break; - case LLAMA_FTYPE_MOSTLY_Q3_K_XS: quantized_type = GGML_TYPE_IQ3_S; break; + case LLAMA_FTYPE_MOSTLY_IQ3_XS: quantized_type = GGML_TYPE_IQ3_S; break; case LLAMA_FTYPE_MOSTLY_Q3_K_S: case LLAMA_FTYPE_MOSTLY_Q3_K_M: case LLAMA_FTYPE_MOSTLY_Q3_K_L: quantized_type = GGML_TYPE_Q3_K; break; @@ -11002,9 +11076,12 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q6_K: quantized_type = GGML_TYPE_Q6_K; break; case LLAMA_FTYPE_MOSTLY_IQ2_XXS: quantized_type = GGML_TYPE_IQ2_XXS; break; case LLAMA_FTYPE_MOSTLY_IQ2_XS: quantized_type = GGML_TYPE_IQ2_XS; break; + case LLAMA_FTYPE_MOSTLY_IQ2_S: quantized_type = GGML_TYPE_IQ2_XS; break; + case LLAMA_FTYPE_MOSTLY_IQ2_M: quantized_type = GGML_TYPE_IQ2_S; break; case LLAMA_FTYPE_MOSTLY_IQ3_XXS: quantized_type = GGML_TYPE_IQ3_XXS; break; case LLAMA_FTYPE_MOSTLY_IQ1_S: quantized_type = GGML_TYPE_IQ1_S; break; case LLAMA_FTYPE_MOSTLY_IQ4_NL: quantized_type = GGML_TYPE_IQ4_NL; break; + case LLAMA_FTYPE_MOSTLY_IQ4_XS: quantized_type = GGML_TYPE_IQ4_XS; break; case LLAMA_FTYPE_MOSTLY_IQ3_S: quantized_type = GGML_TYPE_IQ3_S; break; case LLAMA_FTYPE_MOSTLY_IQ3_M: quantized_type = GGML_TYPE_IQ3_S; break; @@ -11180,6 +11257,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } if ((new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_XS || + new_type == GGML_TYPE_IQ2_S || new_type == GGML_TYPE_IQ1_S || (new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0)) && !imatrix) { LLAMA_LOG_ERROR("\n\n============================================================\n"); @@ -11635,6 +11713,7 @@ struct llama_context_params llama_context_default_params() { /*.yarn_beta_fast =*/ 32.0f, /*.yarn_beta_slow =*/ 1.0f, /*.yarn_orig_ctx =*/ 0, + /*.defrag_thold =*/ -1.0f, /*.cb_eval =*/ nullptr, /*.cb_eval_user_data =*/ nullptr, /*.type_k =*/ GGML_TYPE_F16, @@ -11799,6 +11878,7 @@ struct llama_context * llama_new_context_with_model( cparams.yarn_attn_factor = params.yarn_attn_factor; cparams.yarn_beta_fast = params.yarn_beta_fast; cparams.yarn_beta_slow = params.yarn_beta_slow; + cparams.defrag_thold = params.defrag_thold; cparams.mul_mat_q = params.mul_mat_q; cparams.offload_kqv = params.offload_kqv; cparams.do_pooling = params.do_pooling; @@ -12000,7 +12080,7 @@ struct llama_context * llama_new_context_with_model( } // buffer used to store the computation graph and the tensor meta data - ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead()); + ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead_custom(LLAMA_MAX_NODES, false)); ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), LLAMA_MAX_NODES); diff --git a/llama.h b/llama.h index ff131996d..16e28e91d 100644 --- a/llama.h +++ b/llama.h @@ -107,12 +107,15 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ2_XXS = 19, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ2_XS = 20, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q2_K_S = 21, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q3_K_XS = 22, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ3_XS = 22, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ3_XXS = 23, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ1_S = 24, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_NL = 25, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ3_S = 26, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ3_M = 27, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ2_S = 28, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ2_M = 29, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; @@ -243,6 +246,7 @@ extern "C" { float yarn_beta_fast; // YaRN low correction dim float yarn_beta_slow; // YaRN high correction dim uint32_t yarn_orig_ctx; // YaRN original context size + float defrag_thold; // defragment the KV cache if holes/size > thold, < 0 disabled (default) ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 24d12ef14..d4cea805f 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1916,9 +1916,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, GGML_TYPE_Q6_K, - GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, + GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S, GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, - GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, + GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS, }; // unary ops diff --git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp index 04656bb9e..f615b612d 100644 --- a/tests/test-quantize-fns.cpp +++ b/tests/test-quantize-fns.cpp @@ -150,6 +150,7 @@ int main(int argc, char * argv[]) { const float total_error = total_quantization_error(qfns, test_size, test_data.data()); const float max_quantization_error = type == GGML_TYPE_Q2_K ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS : + type == GGML_TYPE_IQ2_S ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS : type == GGML_TYPE_Q3_K ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS : type == GGML_TYPE_IQ3_S ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS : type == GGML_TYPE_IQ3_XXS ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS : MAX_QUANTIZATION_TOTAL_ERROR; @@ -168,7 +169,8 @@ int main(int argc, char * argv[]) { const float vec_dot_error = dot_product_error(qfns, test_size, test_data.data(), test_data2.data()); const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS || - type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S ? MAX_DOT_PRODUCT_ERROR_LOWBIT + type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S + ? MAX_DOT_PRODUCT_ERROR_LOWBIT : MAX_DOT_PRODUCT_ERROR; failed = !(vec_dot_error < max_allowed_error); num_failed += failed;