Merge branch 'master' into xsn/server_twice_ctrl_c

Commit 66c7480a42
19 changed files with 2716 additions and 245 deletions

Makefile (9 changed lines)
@@ -381,8 +381,13 @@ ifdef LLAMA_BLIS
endif # LLAMA_BLIS

ifdef LLAMA_CUBLAS
MK_CPPFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include -I/usr/local/cuda/targets/aarch64-linux/include
MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib -L/usr/local/cuda/targets/aarch64-linux/lib -L/usr/lib/wsl/lib
ifneq ('', '$(wildcard /opt/cuda)')
CUDA_PATH ?= /opt/cuda
else
CUDA_PATH ?= /usr/local/cuda
endif
MK_CPPFLAGS += -DGGML_USE_CUBLAS -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib
OBJS += ggml-cuda.o
MK_NVCCFLAGS += -use_fast_math
ifdef LLAMA_FATAL_WARNINGS
@@ -159,6 +159,7 @@ Unless otherwise noted these projects are open-source with permissive licensing:
- [withcatai/catai](https://github.com/withcatai/catai)
- [Mobile-Artificial-Intelligence/maid](https://github.com/Mobile-Artificial-Intelligence/maid) (MIT)
- [Msty](https://msty.app) (proprietary)
- [LLMFarm](https://github.com/guinmoon/LLMFarm?tab=readme-ov-file) (MIT)

---

@@ -335,6 +335,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                break;
            }
            params.yarn_beta_slow = std::stof(argv[i]);
        } else if (arg == "--defrag-thold" || arg == "-dt") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params.defrag_thold = std::stof(argv[i]);
        } else if (arg == "--samplers") {
            if (++i >= argc) {
                invalid_param = true;
@@ -1004,6 +1010,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
    printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
    printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
    printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
    printf(" -dt N, --defrag-thold N\n");
    printf(" KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold);
    printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
    printf(" --no-penalize-nl do not penalize newline token\n");
    printf(" --temp N temperature (default: %.1f)\n", (double)sparams.temp);
@@ -1285,6 +1293,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
    cparams.yarn_beta_fast = params.yarn_beta_fast;
    cparams.yarn_beta_slow = params.yarn_beta_slow;
    cparams.yarn_orig_ctx = params.yarn_orig_ctx;
    cparams.defrag_thold = params.defrag_thold;
    cparams.offload_kqv = !params.no_kv_offload;

    cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
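For context, a hedged sketch of what the new -dt / --defrag-thold knob is meant to gate (not the library's exact internals; the fragmentation measure and the used_cells/total_cells names are illustrative): KV-cache defragmentation only runs when fragmentation exceeds the threshold, and a negative value disables it, matching the "< 0 - disabled" help text above.

    // Illustrative only: how a caller might act on defrag_thold.
    float fragmentation = 1.0f - (float) used_cells / (float) total_cells; // hypothetical measure
    if (cparams.defrag_thold > 0.0f && fragmentation > cparams.defrag_thold) {
        llama_kv_cache_defrag(ctx); // schedule defragmentation
        llama_kv_cache_update(ctx); // apply pending cache operations
    }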
@@ -75,6 +75,7 @@ struct gpt_params {
    float yarn_beta_fast = 32.0f; // YaRN low correction dim
    float yarn_beta_slow = 1.0f; // YaRN high correction dim
    int32_t yarn_orig_ctx = 0; // YaRN original context length
    float defrag_thold = -1.0f; // KV cache defragmentation threshold
    int32_t rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
    ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;

@@ -182,7 +182,7 @@ int main(int argc, char ** argv) {

        llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
        llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
        llama_kv_cache_defrag (ctx);
        //llama_kv_cache_defrag (ctx);
        llama_kv_cache_update (ctx);

        n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;

@@ -213,7 +213,7 @@ int main(int argc, char ** argv) {

        llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
        llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
        llama_kv_cache_defrag (ctx);
        //llama_kv_cache_defrag (ctx);
        llama_kv_cache_update (ctx);

        n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
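The calls above implement the usual context-shift pattern; a commented restatement with made-up numbers (not taken from the example program) to make the arithmetic concrete:

    // Suppose n_ctx = 512, n_keep = 4 and n_discard = 256 when the cache fills up.
    llama_kv_cache_seq_rm (ctx, 0, n_keep, n_keep + n_discard);            // drop positions [4, 260)
    llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); // shift [260, 512) left by 256
    llama_kv_cache_update (ctx);                                           // apply the pending shift
    n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;                       // now 512 - 256 = 256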
@@ -23,18 +23,21 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
    { "Q5_1", LLAMA_FTYPE_MOSTLY_Q5_1, " 4.70G, +0.0349 ppl @ LLaMA-v1-7B", },
    { "IQ2_XXS",LLAMA_FTYPE_MOSTLY_IQ2_XXS," 2.06 bpw quantization", },
    { "IQ2_XS", LLAMA_FTYPE_MOSTLY_IQ2_XS, " 2.31 bpw quantization", },
    { "IQ2_S", LLAMA_FTYPE_MOSTLY_IQ2_S, " 2.5 bpw quantization", },
    { "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", },
    { "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", },
    { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
    { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
    { "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", },
    { "IQ3_S", LLAMA_FTYPE_MOSTLY_IQ3_S, " 3.44 bpw quantization", },
    { "IQ3_M", LLAMA_FTYPE_MOSTLY_IQ3_M, " 3.66 bpw quantization mix", },
    { "IQ3_M", LLAMA_FTYPE_MOSTLY_IQ3_M, " 3.66 bpw quantization mix", },
    { "Q3_K", LLAMA_FTYPE_MOSTLY_Q3_K_M, "alias for Q3_K_M" },
    { "Q3_K_XS",LLAMA_FTYPE_MOSTLY_Q3_K_XS,"3-bit extra small quantization" , },
    { "IQ3_XS", LLAMA_FTYPE_MOSTLY_IQ3_XS, " 3.3 bpw quantization" , },
    { "Q3_K_S", LLAMA_FTYPE_MOSTLY_Q3_K_S, " 2.75G, +0.5551 ppl @ LLaMA-v1-7B", },
    { "Q3_K_M", LLAMA_FTYPE_MOSTLY_Q3_K_M, " 3.07G, +0.2496 ppl @ LLaMA-v1-7B", },
    { "Q3_K_L", LLAMA_FTYPE_MOSTLY_Q3_K_L, " 3.35G, +0.1764 ppl @ LLaMA-v1-7B", },
    { "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL, " 4.25 bpw non-linear quantization", },
    { "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL, " 4.50 bpw non-linear quantization", },
    { "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", },
    { "Q4_K", LLAMA_FTYPE_MOSTLY_Q4_K_M, "alias for Q4_K_M", },
    { "Q4_K_S", LLAMA_FTYPE_MOSTLY_Q4_K_S, " 3.59G, +0.0992 ppl @ LLaMA-v1-7B", },
    { "Q4_K_M", LLAMA_FTYPE_MOSTLY_Q4_K_M, " 3.80G, +0.0532 ppl @ LLaMA-v1-7B", },
@@ -292,6 +295,7 @@ int main(int argc, char ** argv) {
    }

    if ((params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS ||
         params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_S ||
         params.ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S || params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) && imatrix_data.empty()) {
        fprintf(stderr, "\n===============================================================================================\n");
        fprintf(stderr, "Please do not use IQ1_S, IQ2_XXS, IQ2_XS or Q2_K_S quantization without an importance matrix\n");
@@ -1336,6 +1336,10 @@ struct llama_server_context
                split_multiprompt_task(task_id, task);
            }
        } else {
            // an empty prompt can make slot become buggy
            if (task.data.contains("prompt") && task.data["prompt"].is_string() && task.data["prompt"].get<std::string>().empty()) {
                task.data["prompt"] = " "; // add a space so that we have one token
            }
            queue_tasks.post(task);
        }
    }

ggml-cuda.cu (556 changed lines)
@@ -523,6 +523,17 @@ typedef struct {
} block_iq2_xs;
static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");

// 2.5625 bpw quants
#define QR2_S 8
#define QI2_S (QK_K / (4*QR2_S))
typedef struct {
    half d;
    uint8_t qs[QK_K/4];
    uint8_t qh[QK_K/32];
    uint8_t scales[QK_K/32];
} block_iq2_s;
static_assert(sizeof(block_iq2_s) == sizeof(ggml_fp16_t) + QK_K/4 + QK_K/16, "wrong iq2_s block size/padding");

#define QR3_XXS 8
#define QI3_XXS (QK_K / (4*QR3_XXS))
typedef struct {
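As a sanity check on the static_assert above, a minimal standalone sketch (assuming QK_K == 256, the default, and using uint16_t as a stand-in for the device-side half type) of how block_iq2_s reaches the advertised 2.5625 bpw:

    #include <stdint.h>

    enum { QK_K = 256 };

    // Same field order as block_iq2_s above, with uint16_t standing in for 'half'.
    struct block_iq2_s_sketch {
        uint16_t d;               // fp16 super-block scale
        uint8_t  qs[QK_K/4];      // 64 bytes: grid indices (first QK_K/8) then sign bytes (next QK_K/8)
        uint8_t  qh[QK_K/32];     //  8 bytes: high bits of the grid indices
        uint8_t  scales[QK_K/32]; //  8 bytes: two 4-bit sub-block scales per byte
    };

    // 2 + 64 + 8 + 8 = 82 bytes per 256 weights -> 82*8/256 = 2.5625 bits per weight.
    static_assert(sizeof(block_iq2_s_sketch) == 82, "82-byte block, i.e. 2.5625 bpw");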
@@ -560,6 +571,18 @@ typedef struct {
} block_iq4_nl;
static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");

// QR4_XS = 8 is very slightly faster than QR4_XS = 4
#define QR4_XS 8
#define QI4_XS (QK_K / (4*QR4_XS))
typedef struct {
    half d;
    uint16_t scales_h;
    uint8_t scales_l[QK_K/64];
    uint8_t qs[QK_K/2];
} block_iq4_xs;
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");


#define WARP_SIZE 32
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
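The same kind of size check for the new iq4_xs block, again assuming QK_K == 256 (illustrative, mirroring the static_assert above):

    // 2 (half d) + 2 (scales_h) + 4 (scales_l = QK_K/64) + 128 (qs = QK_K/2) = 136 bytes
    // for 256 weights -> 136*8/256 = 4.25 bits per weight, matching the quantize help text.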
@@ -685,18 +708,20 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
    return a;
}

//static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
//#pragma unroll
//    for (int mask = 16; mask > 0; mask >>= 1) {
//        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
//    }
//    return a;
//#else
//    (void) a;
//    NO_DEVICE_CODE;
//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
//}
#ifdef GGML_CUDA_F16
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
    }
    return a;
#else
    (void) a;
    NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}
#endif // GGML_CUDA_F16

static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
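How the shuffle-based butterfly reduction used throughout these kernels converges, traced on a toy 4-lane warp holding {1, 2, 3, 4} (the real code runs on 32 lanes with masks 16, 8, 4, 2, 1):

    // mask 2: lane i adds lane i^2 -> {1+3, 2+4, 3+1, 4+2} = {4, 6, 4, 6}
    // mask 1: lane i adds lane i^1 -> {4+6, 6+4, 4+6, 6+4} = {10, 10, 10, 10}
    // After log2(width) __shfl_xor_sync steps every lane holds the full sum,
    // which is why the callers further down can simply let thread 0 write dst[row].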
@ -1689,6 +1714,265 @@ static const __device__ uint64_t iq2xs_grid[512] = {
|
|||
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
|
||||
};
|
||||
|
||||
static const __device__ uint64_t iq2s_grid[1024] = {
|
||||
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
|
||||
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
|
||||
0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
|
||||
0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
|
||||
0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
|
||||
0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x08080808192b192b,
|
||||
0x08080808192b2b19, 0x080808082b080808, 0x080808082b08082b, 0x080808082b081919,
|
||||
0x080808082b082b08, 0x080808082b190819, 0x080808082b191908, 0x080808082b2b0808,
|
||||
0x080808082b2b1919, 0x080808082b2b2b2b, 0x0808081908080819, 0x0808081908081908,
|
||||
0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, 0x080808190819082b,
|
||||
0x0808081908191919, 0x0808081908192b08, 0x08080819082b0819, 0x08080819082b1908,
|
||||
0x0808081919080808, 0x080808191908082b, 0x0808081919081919, 0x0808081919082b08,
|
||||
0x0808081919190819, 0x0808081919191908, 0x080808191919192b, 0x0808081919192b19,
|
||||
0x08080819192b0808, 0x08080819192b1919, 0x08080819192b2b08, 0x080808192b080819,
|
||||
0x080808192b081908, 0x080808192b190808, 0x080808192b19082b, 0x080808192b191919,
|
||||
0x080808192b2b0819, 0x080808192b2b1908, 0x0808082b08080808, 0x0808082b0808082b,
|
||||
0x0808082b08081919, 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908,
|
||||
0x0808082b082b0808, 0x0808082b082b2b2b, 0x0808082b19080819, 0x0808082b19081908,
|
||||
0x0808082b1908192b, 0x0808082b19082b19, 0x0808082b19190808, 0x0808082b19191919,
|
||||
0x0808082b2b080808, 0x0808082b2b081919, 0x0808082b2b082b2b, 0x0808082b2b191908,
|
||||
0x0808082b2b2b082b, 0x0808190808080819, 0x0808190808081908, 0x080819080808192b,
|
||||
0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, 0x0808190808191919,
|
||||
0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, 0x08081908082b192b,
|
||||
0x08081908082b2b19, 0x0808190819080808, 0x080819081908082b, 0x0808190819081919,
|
||||
0x0808190819082b08, 0x0808190819082b2b, 0x0808190819190819, 0x0808190819191908,
|
||||
0x080819081919192b, 0x0808190819192b19, 0x08081908192b0808, 0x08081908192b082b,
|
||||
0x08081908192b1919, 0x080819082b080819, 0x080819082b081908, 0x080819082b08192b,
|
||||
0x080819082b082b19, 0x080819082b190808, 0x080819082b191919, 0x080819082b192b08,
|
||||
0x080819082b2b0819, 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b,
|
||||
0x0808191908081919, 0x0808191908082b08, 0x0808191908082b2b, 0x0808191908190819,
|
||||
0x0808191908191908, 0x080819190819192b, 0x0808191908192b19, 0x08081919082b0808,
|
||||
0x08081919082b1919, 0x08081919082b2b08, 0x0808191919080819, 0x0808191919081908,
|
||||
0x080819191908192b, 0x0808191919082b19, 0x0808191919190808, 0x080819191919082b,
|
||||
0x0808191919191919, 0x0808191919192b08, 0x08081919192b0819, 0x08081919192b1908,
|
||||
0x080819192b080808, 0x080819192b08082b, 0x080819192b081919, 0x080819192b082b08,
|
||||
0x080819192b190819, 0x080819192b191908, 0x080819192b2b0808, 0x0808192b08080819,
|
||||
0x0808192b08081908, 0x0808192b0808192b, 0x0808192b08082b19, 0x0808192b08190808,
|
||||
0x0808192b08191919, 0x0808192b19080808, 0x0808192b19081919, 0x0808192b19082b08,
|
||||
0x0808192b19190819, 0x0808192b19191908, 0x0808192b192b0808, 0x0808192b2b080819,
|
||||
0x0808192b2b081908, 0x0808192b2b190808, 0x08082b0808080808, 0x08082b080808082b,
|
||||
0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808190819, 0x08082b0808191908,
|
||||
0x08082b080819192b, 0x08082b0808192b19, 0x08082b08082b0808, 0x08082b08082b1919,
|
||||
0x08082b08082b2b2b, 0x08082b0819080819, 0x08082b0819081908, 0x08082b081908192b,
|
||||
0x08082b0819082b19, 0x08082b0819190808, 0x08082b081919082b, 0x08082b0819191919,
|
||||
0x08082b0819192b08, 0x08082b08192b0819, 0x08082b08192b1908, 0x08082b082b080808,
|
||||
0x08082b082b081919, 0x08082b082b191908, 0x08082b082b2b2b2b, 0x08082b1908080819,
|
||||
0x08082b1908081908, 0x08082b1908190808, 0x08082b190819082b, 0x08082b1908191919,
|
||||
0x08082b1908192b08, 0x08082b19082b0819, 0x08082b1919080808, 0x08082b1919081919,
|
||||
0x08082b1919082b08, 0x08082b1919190819, 0x08082b1919191908, 0x08082b19192b0808,
|
||||
0x08082b192b080819, 0x08082b192b190808, 0x08082b2b08080808, 0x08082b2b08190819,
|
||||
0x08082b2b08191908, 0x08082b2b082b082b, 0x08082b2b082b2b08, 0x08082b2b082b2b2b,
|
||||
0x08082b2b19190808, 0x08082b2b2b192b19, 0x0819080808080819, 0x0819080808081908,
|
||||
0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b,
|
||||
0x0819080808191919, 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908,
|
||||
0x08190808082b192b, 0x0819080819080808, 0x081908081908082b, 0x0819080819081919,
|
||||
0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x081908081919192b,
|
||||
0x0819080819192b19, 0x08190808192b0808, 0x08190808192b082b, 0x08190808192b1919,
|
||||
0x08190808192b2b08, 0x081908082b080819, 0x081908082b081908, 0x081908082b08192b,
|
||||
0x081908082b190808, 0x081908082b191919, 0x081908082b192b08, 0x081908082b2b0819,
|
||||
0x081908082b2b1908, 0x0819081908080808, 0x081908190808082b, 0x0819081908081919,
|
||||
0x0819081908082b08, 0x0819081908082b2b, 0x0819081908190819, 0x0819081908191908,
|
||||
0x081908190819192b, 0x0819081908192b19, 0x08190819082b0808, 0x08190819082b082b,
|
||||
0x08190819082b1919, 0x08190819082b2b08, 0x0819081919080819, 0x0819081919081908,
|
||||
0x081908191908192b, 0x0819081919082b19, 0x0819081919190808, 0x081908191919082b,
|
||||
0x0819081919191919, 0x0819081919192b08, 0x08190819192b0819, 0x08190819192b1908,
|
||||
0x081908192b080808, 0x081908192b08082b, 0x081908192b081919, 0x081908192b082b08,
|
||||
0x081908192b190819, 0x081908192b191908, 0x0819082b08080819, 0x0819082b08081908,
|
||||
0x0819082b08082b19, 0x0819082b08190808, 0x0819082b08191919, 0x0819082b082b0819,
|
||||
0x0819082b082b1908, 0x0819082b19080808, 0x0819082b19081919, 0x0819082b19190819,
|
||||
0x0819082b19191908, 0x0819082b2b080819, 0x0819082b2b081908, 0x0819082b2b190808,
|
||||
0x0819190808080808, 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08,
|
||||
0x0819190808190819, 0x0819190808191908, 0x081919080819192b, 0x0819190808192b19,
|
||||
0x08191908082b0808, 0x08191908082b1919, 0x08191908082b2b08, 0x0819190819080819,
|
||||
0x0819190819081908, 0x081919081908192b, 0x0819190819082b19, 0x0819190819190808,
|
||||
0x081919081919082b, 0x0819190819191919, 0x0819190819192b08, 0x08191908192b0819,
|
||||
0x08191908192b1908, 0x081919082b080808, 0x081919082b08082b, 0x081919082b081919,
|
||||
0x081919082b082b08, 0x081919082b190819, 0x081919082b191908, 0x081919082b2b0808,
|
||||
0x0819191908080819, 0x0819191908081908, 0x081919190808192b, 0x0819191908082b19,
|
||||
0x0819191908190808, 0x081919190819082b, 0x0819191908191919, 0x0819191908192b08,
|
||||
0x08191919082b0819, 0x08191919082b1908, 0x0819191919080808, 0x081919191908082b,
|
||||
0x0819191919081919, 0x0819191919082b08, 0x0819191919190819, 0x0819191919191908,
|
||||
0x08191919192b0808, 0x081919192b080819, 0x081919192b081908, 0x081919192b190808,
|
||||
0x0819192b08080808, 0x0819192b08081919, 0x0819192b08082b08, 0x0819192b08190819,
|
||||
0x0819192b08191908, 0x0819192b082b0808, 0x0819192b19080819, 0x0819192b19081908,
|
||||
0x0819192b19190808, 0x0819192b2b080808, 0x0819192b2b2b2b2b, 0x08192b0808080819,
|
||||
0x08192b0808081908, 0x08192b080808192b, 0x08192b0808082b19, 0x08192b0808190808,
|
||||
0x08192b0808191919, 0x08192b0808192b08, 0x08192b08082b0819, 0x08192b0819080808,
|
||||
0x08192b081908082b, 0x08192b0819081919, 0x08192b0819082b08, 0x08192b0819190819,
|
||||
0x08192b0819191908, 0x08192b08192b0808, 0x08192b082b080819, 0x08192b082b081908,
|
||||
0x08192b1908080808, 0x08192b190808082b, 0x08192b1908081919, 0x08192b1908082b08,
|
||||
0x08192b1908190819, 0x08192b1908191908, 0x08192b19082b0808, 0x08192b1919080819,
|
||||
0x08192b1919081908, 0x08192b1919190808, 0x08192b19192b2b19, 0x08192b192b2b082b,
|
||||
0x08192b2b08081908, 0x08192b2b08190808, 0x08192b2b19080808, 0x08192b2b1919192b,
|
||||
0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, 0x082b080808082b08,
|
||||
0x082b080808190819, 0x082b080808191908, 0x082b08080819192b, 0x082b080808192b19,
|
||||
0x082b0808082b0808, 0x082b0808082b1919, 0x082b0808082b2b2b, 0x082b080819080819,
|
||||
0x082b080819081908, 0x082b080819190808, 0x082b08081919082b, 0x082b080819191919,
|
||||
0x082b0808192b1908, 0x082b08082b080808, 0x082b08082b082b2b, 0x082b08082b191908,
|
||||
0x082b08082b2b2b2b, 0x082b081908080819, 0x082b081908081908, 0x082b081908190808,
|
||||
0x082b08190819082b, 0x082b081908191919, 0x082b0819082b0819, 0x082b081919080808,
|
||||
0x082b08191908082b, 0x082b081919081919, 0x082b081919190819, 0x082b081919191908,
|
||||
0x082b0819192b0808, 0x082b08192b080819, 0x082b08192b081908, 0x082b08192b190808,
|
||||
0x082b082b08080808, 0x082b082b08082b2b, 0x082b082b082b082b, 0x082b082b082b2b08,
|
||||
0x082b082b082b2b2b, 0x082b082b19081908, 0x082b082b19190808, 0x082b082b2b082b08,
|
||||
0x082b082b2b082b2b, 0x082b082b2b2b2b08, 0x082b190808080819, 0x082b190808081908,
|
||||
0x082b19080808192b, 0x082b190808082b19, 0x082b190808190808, 0x082b190808191919,
|
||||
0x082b190808192b08, 0x082b1908082b0819, 0x082b1908082b1908, 0x082b190819080808,
|
||||
0x082b19081908082b, 0x082b190819081919, 0x082b190819082b08, 0x082b190819190819,
|
||||
0x082b190819191908, 0x082b1908192b0808, 0x082b19082b080819, 0x082b19082b081908,
|
||||
0x082b19082b190808, 0x082b191908080808, 0x082b191908081919, 0x082b191908082b08,
|
||||
0x082b191908190819, 0x082b191908191908, 0x082b1919082b0808, 0x082b191919080819,
|
||||
0x082b191919081908, 0x082b191919190808, 0x082b1919192b192b, 0x082b19192b080808,
|
||||
0x082b192b08080819, 0x082b192b08081908, 0x082b192b08190808, 0x082b192b19080808,
|
||||
0x082b192b19192b19, 0x082b2b0808080808, 0x082b2b0808081919, 0x082b2b0808190819,
|
||||
0x082b2b0808191908, 0x082b2b0819080819, 0x082b2b0819081908, 0x082b2b0819190808,
|
||||
0x082b2b082b082b2b, 0x082b2b082b2b2b2b, 0x082b2b1908080819, 0x082b2b1908081908,
|
||||
0x082b2b1908190808, 0x082b2b192b191919, 0x082b2b2b08082b2b, 0x082b2b2b082b082b,
|
||||
0x082b2b2b192b1908, 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819,
|
||||
0x1908080808081908, 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808,
|
||||
0x190808080819082b, 0x1908080808191919, 0x1908080808192b08, 0x1908080808192b2b,
|
||||
0x19080808082b0819, 0x19080808082b1908, 0x19080808082b192b, 0x1908080819080808,
|
||||
0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, 0x1908080819082b2b,
|
||||
0x1908080819190819, 0x1908080819191908, 0x190808081919192b, 0x1908080819192b19,
|
||||
0x19080808192b0808, 0x19080808192b082b, 0x19080808192b1919, 0x190808082b080819,
|
||||
0x190808082b081908, 0x190808082b190808, 0x190808082b191919, 0x190808082b192b08,
|
||||
0x190808082b2b0819, 0x190808082b2b1908, 0x1908081908080808, 0x190808190808082b,
|
||||
0x1908081908081919, 0x1908081908082b08, 0x1908081908190819, 0x1908081908191908,
|
||||
0x190808190819192b, 0x1908081908192b19, 0x19080819082b0808, 0x19080819082b082b,
|
||||
0x19080819082b1919, 0x1908081919080819, 0x1908081919081908, 0x190808191908192b,
|
||||
0x1908081919082b19, 0x1908081919190808, 0x190808191919082b, 0x1908081919191919,
|
||||
0x1908081919192b08, 0x19080819192b0819, 0x19080819192b1908, 0x190808192b080808,
|
||||
0x190808192b08082b, 0x190808192b081919, 0x190808192b082b08, 0x190808192b190819,
|
||||
0x190808192b191908, 0x190808192b2b0808, 0x1908082b08080819, 0x1908082b08081908,
|
||||
0x1908082b08190808, 0x1908082b0819082b, 0x1908082b08191919, 0x1908082b08192b08,
|
||||
0x1908082b082b1908, 0x1908082b19080808, 0x1908082b19081919, 0x1908082b19082b08,
|
||||
0x1908082b19190819, 0x1908082b19191908, 0x1908082b192b0808, 0x1908082b2b080819,
|
||||
0x1908082b2b081908, 0x1908190808080808, 0x190819080808082b, 0x1908190808081919,
|
||||
0x1908190808082b08, 0x1908190808082b2b, 0x1908190808190819, 0x1908190808191908,
|
||||
0x190819080819192b, 0x1908190808192b19, 0x19081908082b0808, 0x19081908082b082b,
|
||||
0x19081908082b1919, 0x19081908082b2b08, 0x1908190819080819, 0x1908190819081908,
|
||||
0x190819081908192b, 0x1908190819082b19, 0x1908190819190808, 0x190819081919082b,
|
||||
0x1908190819191919, 0x1908190819192b08, 0x19081908192b0819, 0x19081908192b1908,
|
||||
0x190819082b080808, 0x190819082b08082b, 0x190819082b081919, 0x190819082b082b08,
|
||||
0x190819082b190819, 0x190819082b191908, 0x190819082b2b0808, 0x1908191908080819,
|
||||
0x1908191908081908, 0x190819190808192b, 0x1908191908082b19, 0x1908191908190808,
|
||||
0x190819190819082b, 0x1908191908191919, 0x1908191908192b08, 0x19081919082b0819,
|
||||
0x19081919082b1908, 0x1908191919080808, 0x190819191908082b, 0x1908191919081919,
|
||||
0x1908191919082b08, 0x1908191919190819, 0x1908191919191908, 0x19081919192b0808,
|
||||
0x19081919192b2b2b, 0x190819192b080819, 0x190819192b081908, 0x190819192b190808,
|
||||
0x1908192b08080808, 0x1908192b0808082b, 0x1908192b08081919, 0x1908192b08082b08,
|
||||
0x1908192b08190819, 0x1908192b08191908, 0x1908192b082b0808, 0x1908192b19080819,
|
||||
0x1908192b19081908, 0x1908192b19190808, 0x1908192b2b080808, 0x1908192b2b2b1919,
|
||||
0x19082b0808080819, 0x19082b0808081908, 0x19082b0808082b19, 0x19082b0808190808,
|
||||
0x19082b080819082b, 0x19082b0808191919, 0x19082b0808192b08, 0x19082b08082b0819,
|
||||
0x19082b08082b1908, 0x19082b0819080808, 0x19082b081908082b, 0x19082b0819081919,
|
||||
0x19082b0819082b08, 0x19082b0819190819, 0x19082b0819191908, 0x19082b08192b0808,
|
||||
0x19082b082b081908, 0x19082b082b190808, 0x19082b1908080808, 0x19082b190808082b,
|
||||
0x19082b1908081919, 0x19082b1908082b08, 0x19082b1908190819, 0x19082b1908191908,
|
||||
0x19082b19082b0808, 0x19082b1919080819, 0x19082b1919081908, 0x19082b1919190808,
|
||||
0x19082b192b080808, 0x19082b192b19192b, 0x19082b2b08080819, 0x19082b2b08081908,
|
||||
0x19082b2b08190808, 0x19082b2b19080808, 0x1919080808080808, 0x191908080808082b,
|
||||
0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, 0x1919080808191908,
|
||||
0x191908080819192b, 0x1919080808192b19, 0x19190808082b0808, 0x19190808082b082b,
|
||||
0x19190808082b1919, 0x19190808082b2b08, 0x1919080819080819, 0x1919080819081908,
|
||||
0x191908081908192b, 0x1919080819082b19, 0x1919080819190808, 0x191908081919082b,
|
||||
0x1919080819191919, 0x1919080819192b08, 0x19190808192b0819, 0x19190808192b1908,
|
||||
0x191908082b080808, 0x191908082b08082b, 0x191908082b081919, 0x191908082b082b08,
|
||||
0x191908082b190819, 0x191908082b191908, 0x1919081908080819, 0x1919081908081908,
|
||||
0x191908190808192b, 0x1919081908082b19, 0x1919081908190808, 0x191908190819082b,
|
||||
0x1919081908191919, 0x1919081908192b08, 0x19190819082b0819, 0x19190819082b1908,
|
||||
0x1919081919080808, 0x191908191908082b, 0x1919081919081919, 0x1919081919082b08,
|
||||
0x1919081919190819, 0x1919081919191908, 0x19190819192b0808, 0x191908192b080819,
|
||||
0x191908192b081908, 0x191908192b190808, 0x1919082b08080808, 0x1919082b08081919,
|
||||
0x1919082b08082b08, 0x1919082b08190819, 0x1919082b08191908, 0x1919082b082b0808,
|
||||
0x1919082b19080819, 0x1919082b19081908, 0x1919082b19190808, 0x1919082b192b2b19,
|
||||
0x1919082b2b080808, 0x1919190808080819, 0x1919190808081908, 0x191919080808192b,
|
||||
0x1919190808082b19, 0x1919190808190808, 0x191919080819082b, 0x1919190808191919,
|
||||
0x1919190808192b08, 0x19191908082b0819, 0x19191908082b1908, 0x1919190819080808,
|
||||
0x191919081908082b, 0x1919190819081919, 0x1919190819082b08, 0x1919190819190819,
|
||||
0x1919190819191908, 0x19191908192b0808, 0x191919082b080819, 0x191919082b081908,
|
||||
0x191919082b190808, 0x1919191908080808, 0x191919190808082b, 0x1919191908081919,
|
||||
0x1919191908082b08, 0x1919191908190819, 0x1919191908191908, 0x19191919082b0808,
|
||||
0x1919191919080819, 0x1919191919081908, 0x1919191919190808, 0x191919192b080808,
|
||||
0x1919192b08080819, 0x1919192b08081908, 0x1919192b08190808, 0x1919192b082b192b,
|
||||
0x1919192b19080808, 0x19192b0808080808, 0x19192b080808082b, 0x19192b0808081919,
|
||||
0x19192b0808082b08, 0x19192b0808190819, 0x19192b0808191908, 0x19192b08082b0808,
|
||||
0x19192b0819080819, 0x19192b0819081908, 0x19192b0819190808, 0x19192b0819192b2b,
|
||||
0x19192b082b080808, 0x19192b1908080819, 0x19192b1908081908, 0x19192b1908190808,
|
||||
0x19192b1919080808, 0x19192b2b08080808, 0x19192b2b08192b19, 0x19192b2b2b081919,
|
||||
0x19192b2b2b2b2b08, 0x192b080808080819, 0x192b080808081908, 0x192b08080808192b,
|
||||
0x192b080808190808, 0x192b08080819082b, 0x192b080808191919, 0x192b080808192b08,
|
||||
0x192b0808082b0819, 0x192b0808082b1908, 0x192b080819080808, 0x192b080819081919,
|
||||
0x192b080819082b08, 0x192b080819190819, 0x192b080819191908, 0x192b0808192b0808,
|
||||
0x192b08082b081908, 0x192b08082b190808, 0x192b081908080808, 0x192b08190808082b,
|
||||
0x192b081908081919, 0x192b081908082b08, 0x192b081908190819, 0x192b081908191908,
|
||||
0x192b0819082b0808, 0x192b081919080819, 0x192b081919081908, 0x192b081919190808,
|
||||
0x192b08192b080808, 0x192b08192b192b19, 0x192b082b08081908, 0x192b082b08190808,
|
||||
0x192b082b19080808, 0x192b082b1919192b, 0x192b082b2b2b0819, 0x192b190808080808,
|
||||
0x192b190808081919, 0x192b190808082b08, 0x192b190808190819, 0x192b190808191908,
|
||||
0x192b1908082b0808, 0x192b190819080819, 0x192b190819081908, 0x192b190819190808,
|
||||
0x192b19082b080808, 0x192b191908080819, 0x192b191908081908, 0x192b191908190808,
|
||||
0x192b191919080808, 0x192b191919082b2b, 0x192b1919192b2b08, 0x192b19192b19082b,
|
||||
0x192b192b08080808, 0x192b192b2b191908, 0x192b2b0808080819, 0x192b2b0808081908,
|
||||
0x192b2b0808190808, 0x192b2b08192b1919, 0x192b2b082b192b08, 0x192b2b1908080808,
|
||||
0x192b2b19082b2b2b, 0x192b2b2b1908082b, 0x192b2b2b2b2b0819, 0x2b08080808080808,
|
||||
0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819,
|
||||
0x2b08080808191908, 0x2b08080808192b19, 0x2b080808082b0808, 0x2b080808082b1919,
|
||||
0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808081919082b,
|
||||
0x2b08080819191919, 0x2b08080819192b08, 0x2b080808192b0819, 0x2b0808082b080808,
|
||||
0x2b0808082b081919, 0x2b0808082b190819, 0x2b0808082b191908, 0x2b08081908080819,
|
||||
0x2b08081908081908, 0x2b08081908082b19, 0x2b08081908190808, 0x2b0808190819082b,
|
||||
0x2b08081908191919, 0x2b08081908192b08, 0x2b080819082b0819, 0x2b080819082b1908,
|
||||
0x2b08081919080808, 0x2b0808191908082b, 0x2b08081919081919, 0x2b08081919082b08,
|
||||
0x2b08081919190819, 0x2b08081919191908, 0x2b0808192b080819, 0x2b0808192b081908,
|
||||
0x2b0808192b190808, 0x2b0808192b2b2b19, 0x2b08082b08080808, 0x2b08082b08081919,
|
||||
0x2b08082b08082b2b, 0x2b08082b08190819, 0x2b08082b08191908, 0x2b08082b19080819,
|
||||
0x2b08082b19081908, 0x2b08082b19190808, 0x2b08190808080819, 0x2b08190808081908,
|
||||
0x2b0819080808192b, 0x2b08190808082b19, 0x2b08190808190808, 0x2b0819080819082b,
|
||||
0x2b08190808191919, 0x2b08190808192b08, 0x2b081908082b0819, 0x2b08190819080808,
|
||||
0x2b0819081908082b, 0x2b08190819081919, 0x2b08190819082b08, 0x2b08190819190819,
|
||||
0x2b08190819191908, 0x2b081908192b0808, 0x2b0819082b080819, 0x2b0819082b081908,
|
||||
0x2b0819082b190808, 0x2b08191908080808, 0x2b0819190808082b, 0x2b08191908081919,
|
||||
0x2b08191908082b08, 0x2b08191908190819, 0x2b08191908191908, 0x2b081919082b0808,
|
||||
0x2b08191919080819, 0x2b08191919081908, 0x2b08191919190808, 0x2b0819192b080808,
|
||||
0x2b0819192b082b2b, 0x2b08192b08080819, 0x2b08192b08081908, 0x2b08192b08190808,
|
||||
0x2b08192b082b2b19, 0x2b08192b19080808, 0x2b082b0808080808, 0x2b082b0808081919,
|
||||
0x2b082b0808190819, 0x2b082b0808191908, 0x2b082b0819080819, 0x2b082b0819081908,
|
||||
0x2b082b0819190808, 0x2b082b082b2b082b, 0x2b082b1908080819, 0x2b082b1908081908,
|
||||
0x2b082b1919080808, 0x2b082b19192b1919, 0x2b082b2b082b082b, 0x2b082b2b19192b08,
|
||||
0x2b082b2b19192b2b, 0x2b082b2b2b08082b, 0x2b082b2b2b2b082b, 0x2b19080808080819,
|
||||
0x2b19080808081908, 0x2b19080808082b19, 0x2b19080808190808, 0x2b1908080819082b,
|
||||
0x2b19080808191919, 0x2b19080808192b08, 0x2b190808082b1908, 0x2b19080819080808,
|
||||
0x2b1908081908082b, 0x2b19080819081919, 0x2b19080819082b08, 0x2b19080819190819,
|
||||
0x2b19080819191908, 0x2b190808192b0808, 0x2b1908082b080819, 0x2b1908082b081908,
|
||||
0x2b1908082b190808, 0x2b19081908080808, 0x2b19081908081919, 0x2b19081908190819,
|
||||
0x2b19081908191908, 0x2b19081919080819, 0x2b19081919081908, 0x2b19081919190808,
|
||||
0x2b19081919192b2b, 0x2b19082b08080819, 0x2b19082b08081908, 0x2b19082b08190808,
|
||||
0x2b19082b19080808, 0x2b19082b2b2b192b, 0x2b19190808080808, 0x2b1919080808082b,
|
||||
0x2b19190808081919, 0x2b19190808082b08, 0x2b19190808190819, 0x2b19190808191908,
|
||||
0x2b191908082b0808, 0x2b19190819080819, 0x2b19190819081908, 0x2b19190819190808,
|
||||
0x2b1919082b080808, 0x2b1919082b19192b, 0x2b19191908080819, 0x2b19191908081908,
|
||||
0x2b19191908190808, 0x2b19191919080808, 0x2b1919192b192b08, 0x2b1919192b2b0819,
|
||||
0x2b19192b08080808, 0x2b19192b1908192b, 0x2b19192b192b1908, 0x2b192b0808080819,
|
||||
0x2b192b0808081908, 0x2b192b0808190808, 0x2b192b08082b192b, 0x2b192b0819080808,
|
||||
0x2b192b082b2b2b19, 0x2b192b1908080808, 0x2b192b1919082b19, 0x2b192b191919082b,
|
||||
0x2b192b2b2b190808, 0x2b2b080808080808, 0x2b2b080808081919, 0x2b2b080808082b2b,
|
||||
0x2b2b080808191908, 0x2b2b0808082b082b, 0x2b2b0808082b2b2b, 0x2b2b080819080819,
|
||||
0x2b2b080819081908, 0x2b2b080819190808, 0x2b2b08082b2b082b, 0x2b2b08082b2b2b2b,
|
||||
0x2b2b081919080808, 0x2b2b0819192b1919, 0x2b2b082b0808082b, 0x2b2b082b08082b2b,
|
||||
0x2b2b082b082b082b, 0x2b2b082b082b2b08, 0x2b2b082b082b2b2b, 0x2b2b082b2b08082b,
|
||||
0x2b2b082b2b082b08, 0x2b2b082b2b082b2b, 0x2b2b082b2b2b2b08, 0x2b2b190808080819,
|
||||
0x2b2b190808081908, 0x2b2b190808190808, 0x2b2b190819080808, 0x2b2b19082b082b19,
|
||||
0x2b2b19082b2b1908, 0x2b2b191908080808, 0x2b2b191908192b19, 0x2b2b192b19190819,
|
||||
0x2b2b2b0808082b2b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b082b, 0x2b2b2b1919191908,
|
||||
0x2b2b2b192b08192b, 0x2b2b2b2b08082b08, 0x2b2b2b2b08082b2b, 0x2b2b2b2b082b0808,
|
||||
0x2b2b2b2b082b082b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b2b082b08, 0x2b2b2b2b2b2b2b2b,
|
||||
};
|
||||
|
||||
static const __device__ uint32_t iq3xxs_grid[256] = {
|
||||
0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
|
||||
0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
|
||||
|
@@ -2037,6 +2321,27 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst

}

template<typename dst_t>
static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int i = blockIdx.x;
    const block_iq2_s * x = (const block_iq2_s *) vx;

    const int tid = threadIdx.x;
#if QK_K == 256
    const int il = tid/8; // 0...3
    const int ib = tid%8; // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
    const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
    const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
#else
    assert(false);
#endif

}

template<typename dst_t>
static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {

@@ -2134,6 +2439,25 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst

}

template<typename dst_t>
static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int i = blockIdx.x;
    const block_iq4_xs * x = (const block_iq4_xs *)vx;

    const int tid = threadIdx.x;
    const int il = tid/8; // 0...3
    const int ib = tid%8; // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
    const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
    const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
    for (int j = 0; j < 4; ++j) {
        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
        y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
    }

}

static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
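A worked instance of the iq4_xs scale decode used on the line above (numbers are made up for illustration): each of the 8 sub-blocks of 32 weights carries a 6-bit scale whose low nibble lives in scales_l and whose top two bits live in scales_h, re-biased by -32.

    // sub-block ib = 5, scales_l[2] = 0xA7, scales_h = 0x0C00
    // low  = (0xA7 >> 4) & 0xF     = 0xA   (ib is odd, so the high nibble is used)
    // high = (0x0C00 >> (2*5)) & 3 = 3
    // ls   = 0xA | (3 << 4)        = 58    -> effective scale = (float)d * (58 - 32) = 26 * d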
@@ -2230,10 +2554,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
#endif

    // sum up partial sums and write back result
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }
    tmp = warp_reduce_sum(tmp);

    if (threadIdx.x == 0) {
        dst[row] = tmp;

@@ -2334,10 +2655,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,
#endif

    // sum up partial sums and write back result
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }
    tmp = warp_reduce_sum(tmp);

    if (threadIdx.x == 0) {
        dst[row] = tmp;

@@ -2470,10 +2788,7 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
#endif

    // sum up partial sums and write back result
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }
    tmp = warp_reduce_sum(tmp);

    if (tid == 0) {
        dst[row] = tmp;

@@ -2586,10 +2901,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
#endif

    // sum up partial sums and write back result
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }
    tmp = warp_reduce_sum(tmp);

    if (threadIdx.x == 0) {
        dst[row] = tmp;

@@ -2696,10 +3008,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
#endif

    // sum up partial sums and write back result
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }
    tmp = warp_reduce_sum(tmp);

    if (tid == 0) {
        dst[row] = tmp;

@@ -2734,11 +3043,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
    float amax = fabsf(xi);
    float sum = xi;

#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
        sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
    }
    amax = warp_reduce_max(amax);
    sum = warp_reduce_sum(sum);

    const float d = amax / 127;
    const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
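To make the quantize_q8_1 arithmetic concrete, a small worked example (values chosen for round numbers, not taken from any real run):

    // Suppose the warp-wide max |x| is amax = 2.54 for this group of values.
    // Then d = amax / 127 = 0.02, so a value xi = 1.27 quantizes to q = roundf(1.27 / 0.02) = 64,
    // and the stored int8 values stay within [-127, 127] by construction.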
@ -4800,6 +5106,54 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
|
|||
#endif
|
||||
}
|
||||
|
||||
// TODO
|
||||
static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
|
||||
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
||||
#if QK_K == 256
|
||||
const block_iq2_s * bq2 = (const block_iq2_s *) vbq;
|
||||
|
||||
const int ib32 = iqs;
|
||||
const int8_t * q8 = bq8_1[ib32].qs;
|
||||
const uint8_t * signs = bq2->qs + QK_K/8 + 4*ib32;
|
||||
const uint8_t ls1 = bq2->scales[ib32] & 0xf;
|
||||
const uint8_t ls2 = bq2->scales[ib32] >> 4;
|
||||
int sumi1 = 0;
|
||||
for (int l = 0; l < 2; ++l) {
|
||||
const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));
|
||||
const uint32_t signs0 = __vcmpeq4(((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
|
||||
const uint32_t signs1 = __vcmpeq4(((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
|
||||
const int grid_l = __vsub4(grid[0] ^ signs0, signs0);
|
||||
const int grid_h = __vsub4(grid[1] ^ signs1, signs1);
|
||||
sumi1 = __dp4a(grid_l, *((const int *)q8 + 0), sumi1);
|
||||
sumi1 = __dp4a(grid_h, *((const int *)q8 + 1), sumi1);
|
||||
q8 += 8;
|
||||
}
|
||||
int sumi2 = 0;
|
||||
for (int l = 2; l < 4; ++l) {
|
||||
const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));
|
||||
const uint32_t signs0 = __vcmpeq4(((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
|
||||
const uint32_t signs1 = __vcmpeq4(((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
|
||||
const int grid_l = __vsub4(grid[0] ^ signs0, signs0);
|
||||
const int grid_h = __vsub4(grid[1] ^ signs1, signs1);
|
||||
sumi2 = __dp4a(grid_l, *((const int *)q8 + 0), sumi2);
|
||||
sumi2 = __dp4a(grid_h, *((const int *)q8 + 1), sumi2);
|
||||
q8 += 8;
|
||||
}
|
||||
const float d = (float)bq2->d * __low2float(bq8_1[ib32].ds) * 0.25f;
|
||||
return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
|
||||
#else
|
||||
(void) ksigns64;
|
||||
assert(false);
|
||||
return 0.f;
|
||||
#endif
|
||||
#else
|
||||
(void) ksigns64;
|
||||
assert(false);
|
||||
return 0.f;
|
||||
#endif
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
|
||||
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
||||
|
@ -4963,6 +5317,76 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
|
|||
return d * (sumi1 + sumi2);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
|
||||
|
||||
#if QK_K == 256
|
||||
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
||||
|
||||
const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq;
|
||||
const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
|
||||
|
||||
//// iqs is 0...7
|
||||
//const int ib64 = iqs/2;
|
||||
//const int il = iqs%2;
|
||||
//const int32_t * q8_1 = (const int *)bq8_1[2*ib64+0].qs + 2*il;
|
||||
//const int32_t * q8_2 = (const int *)bq8_1[2*ib64+1].qs + 2*il;
|
||||
//const uint32_t * q4_1 = (const uint32_t *)bq4->qs + 8*ib64 + 2*il;
|
||||
//const uint32_t * q4_2 = q4_1 + 4;
|
||||
//const int8_t ls1 = (bq4->scales_l[ib64] & 0xf) | (((bq4->scales_h >> (4*ib64+0)) & 3) << 4);
|
||||
//const int8_t ls2 = (bq4->scales_l[ib64] >> 4) | (((bq4->scales_h >> (4*ib64+2)) & 3) << 4);
|
||||
//const float d1 = (float)bq4->d * (ls1 - 32) * __low2float(bq8_1[2*ib64+0].ds);
|
||||
//const float d2 = (float)bq4->d * (ls2 - 32) * __low2float(bq8_1[2*ib64+1].ds);
|
||||
//int v1, v2;
|
||||
//int sumi1 = 0, sumi2 = 0;
|
||||
//for (int j = 0; j < 2; ++j) {
|
||||
// get_int_from_table_16(q4_1[j], values, v1, v2);
|
||||
// sumi1 = __dp4a(v2, q8_1[j+4], __dp4a(v1, q8_1[j+0], sumi1));
|
||||
// get_int_from_table_16(q4_2[j], values, v1, v2);
|
||||
// sumi2 = __dp4a(v2, q8_2[j+4], __dp4a(v1, q8_2[j+0], sumi2));
|
||||
//}
|
||||
//return d1 * sumi1 + d2 * sumi2;
|
||||
|
||||
// iqs is 0...7
|
||||
const int ib32 = iqs;
|
||||
const int32_t * q8 = (const int *)bq8_1[ib32].qs;
|
||||
const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32;
|
||||
const int8_t ls = ((bq4->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((bq4->scales_h >> 2*ib32) & 3) << 4);
|
||||
const float d = (float)bq4->d * (ls - 32) * __low2float(bq8_1[ib32].ds);
|
||||
int v1, v2;
|
||||
int sumi1 = 0, sumi2 = 0;
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
get_int_from_table_16(q4[j], values, v1, v2);
|
||||
sumi1 = __dp4a(v1, q8[j+0], sumi1);
|
||||
sumi2 = __dp4a(v2, q8[j+4], sumi2);
|
||||
}
|
||||
return d * (sumi1 + sumi2);
|
||||
|
||||
//// iqs is 0...15
|
||||
//const int ib32 = iqs/2;
|
||||
//const int il = iqs%2;
|
||||
//const int32_t * q8 = (const int *)bq8_1[ib32].qs + 2*il;
|
||||
//const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32 + 2*il;
|
||||
//const int8_t ls = ((bq4->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((bq4->scales_h >> 2*ib32) & 3) << 4);
|
||||
//const float d = (float)bq4->d * (ls - 32) * __low2float(bq8_1[ib32].ds);
|
||||
//int v1, v2;
|
||||
//int sumi1 = 0, sumi2 = 0;
|
||||
//for (int j = 0; j < 2; ++j) {
|
||||
// get_int_from_table_16(q4[j], values, v1, v2);
|
||||
// sumi1 = __dp4a(v1, q8[j+0], sumi1);
|
||||
// sumi2 = __dp4a(v2, q8[j+4], sumi2);
|
||||
//}
|
||||
//return d * (sumi1 + sumi2);
|
||||
#else
|
||||
assert(false);
|
||||
return 0.f;
|
||||
#endif
|
||||
#else
|
||||
assert(false);
|
||||
return 0.f;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
|
||||
allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
|
||||
static __device__ __forceinline__ void mul_mat_q(
|
||||
|
@@ -5883,10 +6307,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
    }

    // sum up partial sums and write back result
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }
    tmp = warp_reduce_sum(tmp);

    if (tid == 0) {
#ifdef GGML_CUDA_F16

@@ -5936,10 +6357,7 @@ static __global__ void mul_mat_p021_f16_f32(
    const int idst = channel*nrows_dst + row_dst;

    // sum up partial sums and write back result
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }
    tmp = warp_reduce_sum(tmp);

    if (threadIdx.x == 0) {
        dst[idst] = tmp;

@@ -5982,10 +6400,7 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
    }

    // sum up partial sums and write back result
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }
    tmp = warp_reduce_sum(tmp);

    if (threadIdx.x == 0) {
        dst[idst] = tmp;
@@ -6996,6 +7411,12 @@ static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k,
    dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int nb = k / QK_K;
    dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int nb = k / QK_K;

@@ -7020,6 +7441,12 @@ static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k,
    dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int nb = (k + QK_K - 1) / QK_K;
    dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
}

template <typename src_t, typename dst_t>
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
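Launch-geometry sanity check for the two new wrappers (assuming QK_K == 256): each launch uses nb thread blocks of 32 threads, one block per super-block. In the kernels above every thread writes 8 destination values, so 32 * 8 = 256 = QK_K outputs per block and the nb blocks cover all k values of the row.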
@ -7057,12 +7484,16 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
|||
return dequantize_row_iq2_xxs_cuda;
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
return dequantize_row_iq2_xs_cuda;
|
||||
case GGML_TYPE_IQ2_S:
|
||||
return dequantize_row_iq2_s_cuda;
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
return dequantize_row_iq3_xxs_cuda;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
return dequantize_row_iq1_s_cuda;
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
return dequantize_row_iq4_nl_cuda;
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
return dequantize_row_iq4_xs_cuda;
|
||||
case GGML_TYPE_IQ3_S:
|
||||
return dequantize_row_iq3_s_cuda;
|
||||
case GGML_TYPE_F32:
|
||||
|
@ -7098,12 +7529,16 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
|||
return dequantize_row_iq2_xxs_cuda;
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
return dequantize_row_iq2_xs_cuda;
|
||||
case GGML_TYPE_IQ2_S:
|
||||
return dequantize_row_iq2_s_cuda;
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
return dequantize_row_iq3_xxs_cuda;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
return dequantize_row_iq1_s_cuda;
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
return dequantize_row_iq4_nl_cuda;
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
return dequantize_row_iq4_xs_cuda;
|
||||
case GGML_TYPE_IQ3_S:
|
||||
return dequantize_row_iq3_s_cuda;
|
||||
case GGML_TYPE_F16:
|
||||
|
@@ -8079,8 +8514,8 @@ static void * ggml_cuda_pool_malloc_leg(int device, size_t size, size_t * actual
    *actual_size = look_ahead_size;
    g_cuda_pool_size[device] += look_ahead_size;
#ifdef DEBUG_CUDA_MALLOC
    fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz,
        (uint32_t)(max_size/1024/1024), (uint32_t)(g_cuda_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024));
    fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, device, nnz,
        (uint32_t)(max_size/1024/1024), (uint32_t)(g_cuda_pool_size[device]/1024/1024), (uint32_t)(size/1024/1024));
#endif
    return ptr;
}

@@ -8166,7 +8601,7 @@ static void * ggml_cuda_pool_malloc_vmm(int device, size_t size, size_t * actual
    g_cuda_pool_used[device] += size;

#ifdef DEBUG_CUDA_MALLOC
    printf("cuda pool[%d]: allocated %llu bytes at %llx [%s]\n", id, (unsigned long long) size, ptr);
    printf("cuda pool[%d]: allocated %llu bytes at %llx\n", device, (unsigned long long) size, ptr);
#endif

    return ptr;

@@ -8176,7 +8611,7 @@ static void ggml_cuda_pool_free_vmm(int device, void * ptr, size_t size) {
    scoped_spin_lock lock(g_cuda_pool_lock);

#ifdef DEBUG_CUDA_MALLOC
    printf("cuda pool[%d]: freed %llu bytes at %llx\n", id, (unsigned long long) size, ptr);
    printf("cuda pool[%d]: freed %llu bytes at %llx\n", device, (unsigned long long) size, ptr);
#endif

    g_cuda_pool_used[device] -= size;
@ -8848,9 +9283,11 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
|
|||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
|
||||
default:
|
||||
|
@ -8874,9 +9311,11 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
|
|||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
return max_compute_capability >= CC_VOLTA ? 128 : 64;
|
||||
case GGML_TYPE_Q6_K:
|
||||
|
@ -8971,6 +9410,10 @@ static void ggml_cuda_op_mul_mat_vec_q(
|
|||
mul_mat_vec_q_cuda<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
|
||||
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_S:
|
||||
mul_mat_vec_q_cuda<QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
|
||||
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
mul_mat_vec_q_cuda<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
|
||||
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
|
@ -8983,6 +9426,10 @@ static void ggml_cuda_op_mul_mat_vec_q(
|
|||
mul_mat_vec_q_cuda<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
|
||||
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
mul_mat_vec_q_cuda<QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
|
||||
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_S:
|
||||
mul_mat_vec_q_cuda<QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
|
||||
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
|
@@ -11710,7 +12157,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
            }
            ggml_type a_type = a->type;
            if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
                a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S) {
                a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S ||
                a_type == GGML_TYPE_IQ2_S || a_type == GGML_TYPE_IQ4_XS) {
                if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
                    return false;
                }

ggml-metal.m (66 changed lines)
@ -62,8 +62,10 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
||||
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
||||
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
||||
|
@ -87,8 +89,10 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
|
||||
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
|
||||
|
@ -108,8 +112,10 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
|
||||
|
@ -126,8 +132,10 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
|
||||
|
@ -144,8 +152,10 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_F32,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_F16,
|
||||
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
|
||||
|
@ -458,8 +468,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
||||
|
@ -483,8 +495,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
|
||||
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
|
||||
|
@ -504,8 +518,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
|
||||
|
@ -522,8 +538,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
|
||||
|
@ -540,8 +558,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
||||
|
@ -1358,8 +1378,10 @@ static bool ggml_metal_graph_compute(
|
|||
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
|
||||
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
|
||||
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
||||
}
|
||||
|
||||
|
@ -1500,6 +1522,12 @@ static bool ggml_metal_graph_compute(
|
|||
nth1 = 16;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ2_S:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
{
|
||||
nth0 = 4;
|
||||
|
@ -1512,6 +1540,12 @@ static bool ggml_metal_graph_compute(
|
|||
nth1 = 16;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
||||
|
@ -1544,9 +1578,9 @@ static bool ggml_metal_graph_compute(
|
|||
[encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
|
||||
[encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
|
||||
|
||||
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
||||
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
|
||||
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S) { // || src0t == GGML_TYPE_Q4_K) {
|
||||
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
||||
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
|
||||
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ2_S) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
||||
|
@ -1559,7 +1593,7 @@ static bool ggml_metal_graph_compute(
|
|||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_IQ4_NL) {
|
||||
else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
|
||||
const int mem_size = 32*sizeof(float);
|
||||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
|
@ -1658,8 +1692,10 @@ static bool ggml_metal_graph_compute(
|
|||
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
|
||||
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
|
||||
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
||||
}
|
||||
|
||||
|
@ -1803,6 +1839,12 @@ static bool ggml_metal_graph_compute(
|
|||
nth1 = 16;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ2_S:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
{
|
||||
nth0 = 4;
|
||||
|
@ -1815,6 +1857,12 @@ static bool ggml_metal_graph_compute(
|
|||
nth1 = 16;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
|
||||
|
@ -1863,9 +1911,9 @@ static bool ggml_metal_graph_compute(
|
|||
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
|
||||
}
|
||||
|
||||
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
|
||||
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
|
||||
src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S) { // || src2t == GGML_TYPE_Q4_K) {
|
||||
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
|
||||
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
|
||||
src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ2_S) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
|
||||
|
@ -1878,7 +1926,7 @@ static bool ggml_metal_graph_compute(
|
|||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src2t == GGML_TYPE_IQ4_NL) {
|
||||
else if (src2t == GGML_TYPE_IQ4_NL || src2t == GGML_TYPE_IQ4_XS) {
|
||||
const int mem_size = 32*sizeof(float);
|
||||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
|
@ -1925,8 +1973,10 @@ static bool ggml_metal_graph_compute(
|
|||
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
|
||||
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
|
||||
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break;
|
||||
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
|
||||
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
|
||||
default: GGML_ASSERT(false && "not implemented");
|
||||
}
|
||||
711
ggml-metal.metal
@ -2519,6 +2519,14 @@ typedef struct {
} block_iq2_xs;
// 74 bytes / block for QK_K = 256, so 2.3125 bpw

// 2.5625 bpw quants
typedef struct {
    half d;
    uint8_t qs[QK_K/4];
    uint8_t qh[QK_K/32];
    uint8_t scales[QK_K/32];
} block_iq2_s;
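The 2.5625 bpw figure follows directly from this layout for QK_K = 256. A rough host-side sketch in plain C (mirror struct name is hypothetical; assumes a typical ABI with 2-byte alignment and no padding):

    /* sketch only: mirrors the Metal block_iq2_s layout above, assuming QK_K == 256 */
    #include <assert.h>
    #include <stdint.h>
    #define QK_K 256
    typedef struct {
        uint16_t d;               /* fp16 super-block scale (bit pattern)              */
        uint8_t  qs[QK_K/4];      /* 64 bytes: low 8 bits of each grid index           */
        uint8_t  qh[QK_K/32];     /*  8 bytes: 2 extra index bits per group of 8       */
        uint8_t  scales[QK_K/32]; /*  8 bytes: two 4-bit group scales per byte         */
    } block_iq2_s_mirror;
    /* 2 + 64 + 8 + 8 = 82 bytes per 256 weights -> 82*8/256 = 2.5625 bits per weight */
    static_assert(sizeof(block_iq2_s_mirror) == 82, "unexpected padding");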

typedef struct {
    half d;
    uint8_t qs[3*QK_K/8];
@ -2552,6 +2560,13 @@ typedef struct {
    uint8_t qs[QK4_NL/2];
} block_iq4_nl;

typedef struct {
    half d;
    uint16_t scales_h;
    uint8_t scales_l[QK_K/64];
    uint8_t qs[QK_K/2];
} block_iq4_xs;
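Likewise, for QK_K = 256 the iq4_xs block packs a half scale, a 16-bit bundle of high scale bits, QK_K/64 bytes of low scale nibbles, and QK_K/2 bytes of 4-bit quants. A small sketch of the byte count in plain C (constant name is illustrative), which matches the static_assert on block_iq4_xs further down in this diff:

    /* sketch: bytes per iq4_xs super-block, assuming QK_K == 256 */
    enum {
        IQ4_XS_BLOCK_BYTES = 2 /* d */ + 2 /* scales_h */ + 256/64 /* scales_l */ + 256/2 /* qs */
        /* = 136 bytes per 256 weights -> 136*8/256 = 4.25 bits per weight */
    };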

//====================================== dot products =========================

void kernel_mul_mv_q2_K_f32_impl(
|
@ -3774,6 +3789,265 @@ constexpr constant static uint64_t iq2xs_grid[512] = {
|
|||
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
|
||||
};
|
||||
|
||||
constexpr constant static uint64_t iq2s_grid[1024] = {
|
||||
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
|
||||
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
|
||||
0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
|
||||
0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
|
||||
0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
|
||||
0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x08080808192b192b,
|
||||
0x08080808192b2b19, 0x080808082b080808, 0x080808082b08082b, 0x080808082b081919,
|
||||
0x080808082b082b08, 0x080808082b190819, 0x080808082b191908, 0x080808082b2b0808,
|
||||
0x080808082b2b1919, 0x080808082b2b2b2b, 0x0808081908080819, 0x0808081908081908,
|
||||
0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, 0x080808190819082b,
|
||||
0x0808081908191919, 0x0808081908192b08, 0x08080819082b0819, 0x08080819082b1908,
|
||||
0x0808081919080808, 0x080808191908082b, 0x0808081919081919, 0x0808081919082b08,
|
||||
0x0808081919190819, 0x0808081919191908, 0x080808191919192b, 0x0808081919192b19,
|
||||
0x08080819192b0808, 0x08080819192b1919, 0x08080819192b2b08, 0x080808192b080819,
|
||||
0x080808192b081908, 0x080808192b190808, 0x080808192b19082b, 0x080808192b191919,
|
||||
0x080808192b2b0819, 0x080808192b2b1908, 0x0808082b08080808, 0x0808082b0808082b,
|
||||
0x0808082b08081919, 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908,
|
||||
0x0808082b082b0808, 0x0808082b082b2b2b, 0x0808082b19080819, 0x0808082b19081908,
|
||||
0x0808082b1908192b, 0x0808082b19082b19, 0x0808082b19190808, 0x0808082b19191919,
|
||||
0x0808082b2b080808, 0x0808082b2b081919, 0x0808082b2b082b2b, 0x0808082b2b191908,
|
||||
0x0808082b2b2b082b, 0x0808190808080819, 0x0808190808081908, 0x080819080808192b,
|
||||
0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, 0x0808190808191919,
|
||||
0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, 0x08081908082b192b,
|
||||
0x08081908082b2b19, 0x0808190819080808, 0x080819081908082b, 0x0808190819081919,
|
||||
0x0808190819082b08, 0x0808190819082b2b, 0x0808190819190819, 0x0808190819191908,
|
||||
0x080819081919192b, 0x0808190819192b19, 0x08081908192b0808, 0x08081908192b082b,
|
||||
0x08081908192b1919, 0x080819082b080819, 0x080819082b081908, 0x080819082b08192b,
|
||||
0x080819082b082b19, 0x080819082b190808, 0x080819082b191919, 0x080819082b192b08,
|
||||
0x080819082b2b0819, 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b,
|
||||
0x0808191908081919, 0x0808191908082b08, 0x0808191908082b2b, 0x0808191908190819,
|
||||
0x0808191908191908, 0x080819190819192b, 0x0808191908192b19, 0x08081919082b0808,
|
||||
0x08081919082b1919, 0x08081919082b2b08, 0x0808191919080819, 0x0808191919081908,
|
||||
0x080819191908192b, 0x0808191919082b19, 0x0808191919190808, 0x080819191919082b,
|
||||
0x0808191919191919, 0x0808191919192b08, 0x08081919192b0819, 0x08081919192b1908,
|
||||
0x080819192b080808, 0x080819192b08082b, 0x080819192b081919, 0x080819192b082b08,
|
||||
0x080819192b190819, 0x080819192b191908, 0x080819192b2b0808, 0x0808192b08080819,
|
||||
0x0808192b08081908, 0x0808192b0808192b, 0x0808192b08082b19, 0x0808192b08190808,
|
||||
0x0808192b08191919, 0x0808192b19080808, 0x0808192b19081919, 0x0808192b19082b08,
|
||||
0x0808192b19190819, 0x0808192b19191908, 0x0808192b192b0808, 0x0808192b2b080819,
|
||||
0x0808192b2b081908, 0x0808192b2b190808, 0x08082b0808080808, 0x08082b080808082b,
|
||||
0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808190819, 0x08082b0808191908,
|
||||
0x08082b080819192b, 0x08082b0808192b19, 0x08082b08082b0808, 0x08082b08082b1919,
|
||||
0x08082b08082b2b2b, 0x08082b0819080819, 0x08082b0819081908, 0x08082b081908192b,
|
||||
0x08082b0819082b19, 0x08082b0819190808, 0x08082b081919082b, 0x08082b0819191919,
|
||||
0x08082b0819192b08, 0x08082b08192b0819, 0x08082b08192b1908, 0x08082b082b080808,
|
||||
0x08082b082b081919, 0x08082b082b191908, 0x08082b082b2b2b2b, 0x08082b1908080819,
|
||||
0x08082b1908081908, 0x08082b1908190808, 0x08082b190819082b, 0x08082b1908191919,
|
||||
0x08082b1908192b08, 0x08082b19082b0819, 0x08082b1919080808, 0x08082b1919081919,
|
||||
0x08082b1919082b08, 0x08082b1919190819, 0x08082b1919191908, 0x08082b19192b0808,
|
||||
0x08082b192b080819, 0x08082b192b190808, 0x08082b2b08080808, 0x08082b2b08190819,
|
||||
0x08082b2b08191908, 0x08082b2b082b082b, 0x08082b2b082b2b08, 0x08082b2b082b2b2b,
|
||||
0x08082b2b19190808, 0x08082b2b2b192b19, 0x0819080808080819, 0x0819080808081908,
|
||||
0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b,
|
||||
0x0819080808191919, 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908,
|
||||
0x08190808082b192b, 0x0819080819080808, 0x081908081908082b, 0x0819080819081919,
|
||||
0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x081908081919192b,
|
||||
0x0819080819192b19, 0x08190808192b0808, 0x08190808192b082b, 0x08190808192b1919,
|
||||
0x08190808192b2b08, 0x081908082b080819, 0x081908082b081908, 0x081908082b08192b,
|
||||
0x081908082b190808, 0x081908082b191919, 0x081908082b192b08, 0x081908082b2b0819,
|
||||
0x081908082b2b1908, 0x0819081908080808, 0x081908190808082b, 0x0819081908081919,
|
||||
0x0819081908082b08, 0x0819081908082b2b, 0x0819081908190819, 0x0819081908191908,
|
||||
0x081908190819192b, 0x0819081908192b19, 0x08190819082b0808, 0x08190819082b082b,
|
||||
0x08190819082b1919, 0x08190819082b2b08, 0x0819081919080819, 0x0819081919081908,
|
||||
0x081908191908192b, 0x0819081919082b19, 0x0819081919190808, 0x081908191919082b,
|
||||
0x0819081919191919, 0x0819081919192b08, 0x08190819192b0819, 0x08190819192b1908,
|
||||
0x081908192b080808, 0x081908192b08082b, 0x081908192b081919, 0x081908192b082b08,
|
||||
0x081908192b190819, 0x081908192b191908, 0x0819082b08080819, 0x0819082b08081908,
|
||||
0x0819082b08082b19, 0x0819082b08190808, 0x0819082b08191919, 0x0819082b082b0819,
|
||||
0x0819082b082b1908, 0x0819082b19080808, 0x0819082b19081919, 0x0819082b19190819,
|
||||
0x0819082b19191908, 0x0819082b2b080819, 0x0819082b2b081908, 0x0819082b2b190808,
|
||||
0x0819190808080808, 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08,
|
||||
0x0819190808190819, 0x0819190808191908, 0x081919080819192b, 0x0819190808192b19,
|
||||
0x08191908082b0808, 0x08191908082b1919, 0x08191908082b2b08, 0x0819190819080819,
|
||||
0x0819190819081908, 0x081919081908192b, 0x0819190819082b19, 0x0819190819190808,
|
||||
0x081919081919082b, 0x0819190819191919, 0x0819190819192b08, 0x08191908192b0819,
|
||||
0x08191908192b1908, 0x081919082b080808, 0x081919082b08082b, 0x081919082b081919,
|
||||
0x081919082b082b08, 0x081919082b190819, 0x081919082b191908, 0x081919082b2b0808,
|
||||
0x0819191908080819, 0x0819191908081908, 0x081919190808192b, 0x0819191908082b19,
|
||||
0x0819191908190808, 0x081919190819082b, 0x0819191908191919, 0x0819191908192b08,
|
||||
0x08191919082b0819, 0x08191919082b1908, 0x0819191919080808, 0x081919191908082b,
|
||||
0x0819191919081919, 0x0819191919082b08, 0x0819191919190819, 0x0819191919191908,
|
||||
0x08191919192b0808, 0x081919192b080819, 0x081919192b081908, 0x081919192b190808,
|
||||
0x0819192b08080808, 0x0819192b08081919, 0x0819192b08082b08, 0x0819192b08190819,
|
||||
0x0819192b08191908, 0x0819192b082b0808, 0x0819192b19080819, 0x0819192b19081908,
|
||||
0x0819192b19190808, 0x0819192b2b080808, 0x0819192b2b2b2b2b, 0x08192b0808080819,
|
||||
0x08192b0808081908, 0x08192b080808192b, 0x08192b0808082b19, 0x08192b0808190808,
|
||||
0x08192b0808191919, 0x08192b0808192b08, 0x08192b08082b0819, 0x08192b0819080808,
|
||||
0x08192b081908082b, 0x08192b0819081919, 0x08192b0819082b08, 0x08192b0819190819,
|
||||
0x08192b0819191908, 0x08192b08192b0808, 0x08192b082b080819, 0x08192b082b081908,
|
||||
0x08192b1908080808, 0x08192b190808082b, 0x08192b1908081919, 0x08192b1908082b08,
|
||||
0x08192b1908190819, 0x08192b1908191908, 0x08192b19082b0808, 0x08192b1919080819,
|
||||
0x08192b1919081908, 0x08192b1919190808, 0x08192b19192b2b19, 0x08192b192b2b082b,
|
||||
0x08192b2b08081908, 0x08192b2b08190808, 0x08192b2b19080808, 0x08192b2b1919192b,
|
||||
0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, 0x082b080808082b08,
|
||||
0x082b080808190819, 0x082b080808191908, 0x082b08080819192b, 0x082b080808192b19,
|
||||
0x082b0808082b0808, 0x082b0808082b1919, 0x082b0808082b2b2b, 0x082b080819080819,
|
||||
0x082b080819081908, 0x082b080819190808, 0x082b08081919082b, 0x082b080819191919,
|
||||
0x082b0808192b1908, 0x082b08082b080808, 0x082b08082b082b2b, 0x082b08082b191908,
|
||||
0x082b08082b2b2b2b, 0x082b081908080819, 0x082b081908081908, 0x082b081908190808,
|
||||
0x082b08190819082b, 0x082b081908191919, 0x082b0819082b0819, 0x082b081919080808,
|
||||
0x082b08191908082b, 0x082b081919081919, 0x082b081919190819, 0x082b081919191908,
|
||||
0x082b0819192b0808, 0x082b08192b080819, 0x082b08192b081908, 0x082b08192b190808,
|
||||
0x082b082b08080808, 0x082b082b08082b2b, 0x082b082b082b082b, 0x082b082b082b2b08,
|
||||
0x082b082b082b2b2b, 0x082b082b19081908, 0x082b082b19190808, 0x082b082b2b082b08,
|
||||
0x082b082b2b082b2b, 0x082b082b2b2b2b08, 0x082b190808080819, 0x082b190808081908,
|
||||
0x082b19080808192b, 0x082b190808082b19, 0x082b190808190808, 0x082b190808191919,
|
||||
0x082b190808192b08, 0x082b1908082b0819, 0x082b1908082b1908, 0x082b190819080808,
|
||||
0x082b19081908082b, 0x082b190819081919, 0x082b190819082b08, 0x082b190819190819,
|
||||
0x082b190819191908, 0x082b1908192b0808, 0x082b19082b080819, 0x082b19082b081908,
|
||||
0x082b19082b190808, 0x082b191908080808, 0x082b191908081919, 0x082b191908082b08,
|
||||
0x082b191908190819, 0x082b191908191908, 0x082b1919082b0808, 0x082b191919080819,
|
||||
0x082b191919081908, 0x082b191919190808, 0x082b1919192b192b, 0x082b19192b080808,
|
||||
0x082b192b08080819, 0x082b192b08081908, 0x082b192b08190808, 0x082b192b19080808,
|
||||
0x082b192b19192b19, 0x082b2b0808080808, 0x082b2b0808081919, 0x082b2b0808190819,
|
||||
0x082b2b0808191908, 0x082b2b0819080819, 0x082b2b0819081908, 0x082b2b0819190808,
|
||||
0x082b2b082b082b2b, 0x082b2b082b2b2b2b, 0x082b2b1908080819, 0x082b2b1908081908,
|
||||
0x082b2b1908190808, 0x082b2b192b191919, 0x082b2b2b08082b2b, 0x082b2b2b082b082b,
|
||||
0x082b2b2b192b1908, 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819,
|
||||
0x1908080808081908, 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808,
|
||||
0x190808080819082b, 0x1908080808191919, 0x1908080808192b08, 0x1908080808192b2b,
|
||||
0x19080808082b0819, 0x19080808082b1908, 0x19080808082b192b, 0x1908080819080808,
|
||||
0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, 0x1908080819082b2b,
|
||||
0x1908080819190819, 0x1908080819191908, 0x190808081919192b, 0x1908080819192b19,
|
||||
0x19080808192b0808, 0x19080808192b082b, 0x19080808192b1919, 0x190808082b080819,
|
||||
0x190808082b081908, 0x190808082b190808, 0x190808082b191919, 0x190808082b192b08,
|
||||
0x190808082b2b0819, 0x190808082b2b1908, 0x1908081908080808, 0x190808190808082b,
|
||||
0x1908081908081919, 0x1908081908082b08, 0x1908081908190819, 0x1908081908191908,
|
||||
0x190808190819192b, 0x1908081908192b19, 0x19080819082b0808, 0x19080819082b082b,
|
||||
0x19080819082b1919, 0x1908081919080819, 0x1908081919081908, 0x190808191908192b,
|
||||
0x1908081919082b19, 0x1908081919190808, 0x190808191919082b, 0x1908081919191919,
|
||||
0x1908081919192b08, 0x19080819192b0819, 0x19080819192b1908, 0x190808192b080808,
|
||||
0x190808192b08082b, 0x190808192b081919, 0x190808192b082b08, 0x190808192b190819,
|
||||
0x190808192b191908, 0x190808192b2b0808, 0x1908082b08080819, 0x1908082b08081908,
|
||||
0x1908082b08190808, 0x1908082b0819082b, 0x1908082b08191919, 0x1908082b08192b08,
|
||||
0x1908082b082b1908, 0x1908082b19080808, 0x1908082b19081919, 0x1908082b19082b08,
|
||||
0x1908082b19190819, 0x1908082b19191908, 0x1908082b192b0808, 0x1908082b2b080819,
|
||||
0x1908082b2b081908, 0x1908190808080808, 0x190819080808082b, 0x1908190808081919,
|
||||
0x1908190808082b08, 0x1908190808082b2b, 0x1908190808190819, 0x1908190808191908,
|
||||
0x190819080819192b, 0x1908190808192b19, 0x19081908082b0808, 0x19081908082b082b,
|
||||
0x19081908082b1919, 0x19081908082b2b08, 0x1908190819080819, 0x1908190819081908,
|
||||
0x190819081908192b, 0x1908190819082b19, 0x1908190819190808, 0x190819081919082b,
|
||||
0x1908190819191919, 0x1908190819192b08, 0x19081908192b0819, 0x19081908192b1908,
|
||||
0x190819082b080808, 0x190819082b08082b, 0x190819082b081919, 0x190819082b082b08,
|
||||
0x190819082b190819, 0x190819082b191908, 0x190819082b2b0808, 0x1908191908080819,
|
||||
0x1908191908081908, 0x190819190808192b, 0x1908191908082b19, 0x1908191908190808,
|
||||
0x190819190819082b, 0x1908191908191919, 0x1908191908192b08, 0x19081919082b0819,
|
||||
0x19081919082b1908, 0x1908191919080808, 0x190819191908082b, 0x1908191919081919,
|
||||
0x1908191919082b08, 0x1908191919190819, 0x1908191919191908, 0x19081919192b0808,
|
||||
0x19081919192b2b2b, 0x190819192b080819, 0x190819192b081908, 0x190819192b190808,
|
||||
0x1908192b08080808, 0x1908192b0808082b, 0x1908192b08081919, 0x1908192b08082b08,
|
||||
0x1908192b08190819, 0x1908192b08191908, 0x1908192b082b0808, 0x1908192b19080819,
|
||||
0x1908192b19081908, 0x1908192b19190808, 0x1908192b2b080808, 0x1908192b2b2b1919,
|
||||
0x19082b0808080819, 0x19082b0808081908, 0x19082b0808082b19, 0x19082b0808190808,
|
||||
0x19082b080819082b, 0x19082b0808191919, 0x19082b0808192b08, 0x19082b08082b0819,
|
||||
0x19082b08082b1908, 0x19082b0819080808, 0x19082b081908082b, 0x19082b0819081919,
|
||||
0x19082b0819082b08, 0x19082b0819190819, 0x19082b0819191908, 0x19082b08192b0808,
|
||||
0x19082b082b081908, 0x19082b082b190808, 0x19082b1908080808, 0x19082b190808082b,
|
||||
0x19082b1908081919, 0x19082b1908082b08, 0x19082b1908190819, 0x19082b1908191908,
|
||||
0x19082b19082b0808, 0x19082b1919080819, 0x19082b1919081908, 0x19082b1919190808,
|
||||
0x19082b192b080808, 0x19082b192b19192b, 0x19082b2b08080819, 0x19082b2b08081908,
|
||||
0x19082b2b08190808, 0x19082b2b19080808, 0x1919080808080808, 0x191908080808082b,
|
||||
0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, 0x1919080808191908,
|
||||
0x191908080819192b, 0x1919080808192b19, 0x19190808082b0808, 0x19190808082b082b,
|
||||
0x19190808082b1919, 0x19190808082b2b08, 0x1919080819080819, 0x1919080819081908,
|
||||
0x191908081908192b, 0x1919080819082b19, 0x1919080819190808, 0x191908081919082b,
|
||||
0x1919080819191919, 0x1919080819192b08, 0x19190808192b0819, 0x19190808192b1908,
|
||||
0x191908082b080808, 0x191908082b08082b, 0x191908082b081919, 0x191908082b082b08,
|
||||
0x191908082b190819, 0x191908082b191908, 0x1919081908080819, 0x1919081908081908,
|
||||
0x191908190808192b, 0x1919081908082b19, 0x1919081908190808, 0x191908190819082b,
|
||||
0x1919081908191919, 0x1919081908192b08, 0x19190819082b0819, 0x19190819082b1908,
|
||||
0x1919081919080808, 0x191908191908082b, 0x1919081919081919, 0x1919081919082b08,
|
||||
0x1919081919190819, 0x1919081919191908, 0x19190819192b0808, 0x191908192b080819,
|
||||
0x191908192b081908, 0x191908192b190808, 0x1919082b08080808, 0x1919082b08081919,
|
||||
0x1919082b08082b08, 0x1919082b08190819, 0x1919082b08191908, 0x1919082b082b0808,
|
||||
0x1919082b19080819, 0x1919082b19081908, 0x1919082b19190808, 0x1919082b192b2b19,
|
||||
0x1919082b2b080808, 0x1919190808080819, 0x1919190808081908, 0x191919080808192b,
|
||||
0x1919190808082b19, 0x1919190808190808, 0x191919080819082b, 0x1919190808191919,
|
||||
0x1919190808192b08, 0x19191908082b0819, 0x19191908082b1908, 0x1919190819080808,
|
||||
0x191919081908082b, 0x1919190819081919, 0x1919190819082b08, 0x1919190819190819,
|
||||
0x1919190819191908, 0x19191908192b0808, 0x191919082b080819, 0x191919082b081908,
|
||||
0x191919082b190808, 0x1919191908080808, 0x191919190808082b, 0x1919191908081919,
|
||||
0x1919191908082b08, 0x1919191908190819, 0x1919191908191908, 0x19191919082b0808,
|
||||
0x1919191919080819, 0x1919191919081908, 0x1919191919190808, 0x191919192b080808,
|
||||
0x1919192b08080819, 0x1919192b08081908, 0x1919192b08190808, 0x1919192b082b192b,
|
||||
0x1919192b19080808, 0x19192b0808080808, 0x19192b080808082b, 0x19192b0808081919,
|
||||
0x19192b0808082b08, 0x19192b0808190819, 0x19192b0808191908, 0x19192b08082b0808,
|
||||
0x19192b0819080819, 0x19192b0819081908, 0x19192b0819190808, 0x19192b0819192b2b,
|
||||
0x19192b082b080808, 0x19192b1908080819, 0x19192b1908081908, 0x19192b1908190808,
|
||||
0x19192b1919080808, 0x19192b2b08080808, 0x19192b2b08192b19, 0x19192b2b2b081919,
|
||||
0x19192b2b2b2b2b08, 0x192b080808080819, 0x192b080808081908, 0x192b08080808192b,
|
||||
0x192b080808190808, 0x192b08080819082b, 0x192b080808191919, 0x192b080808192b08,
|
||||
0x192b0808082b0819, 0x192b0808082b1908, 0x192b080819080808, 0x192b080819081919,
|
||||
0x192b080819082b08, 0x192b080819190819, 0x192b080819191908, 0x192b0808192b0808,
|
||||
0x192b08082b081908, 0x192b08082b190808, 0x192b081908080808, 0x192b08190808082b,
|
||||
0x192b081908081919, 0x192b081908082b08, 0x192b081908190819, 0x192b081908191908,
|
||||
0x192b0819082b0808, 0x192b081919080819, 0x192b081919081908, 0x192b081919190808,
|
||||
0x192b08192b080808, 0x192b08192b192b19, 0x192b082b08081908, 0x192b082b08190808,
|
||||
0x192b082b19080808, 0x192b082b1919192b, 0x192b082b2b2b0819, 0x192b190808080808,
|
||||
0x192b190808081919, 0x192b190808082b08, 0x192b190808190819, 0x192b190808191908,
|
||||
0x192b1908082b0808, 0x192b190819080819, 0x192b190819081908, 0x192b190819190808,
|
||||
0x192b19082b080808, 0x192b191908080819, 0x192b191908081908, 0x192b191908190808,
|
||||
0x192b191919080808, 0x192b191919082b2b, 0x192b1919192b2b08, 0x192b19192b19082b,
|
||||
0x192b192b08080808, 0x192b192b2b191908, 0x192b2b0808080819, 0x192b2b0808081908,
|
||||
0x192b2b0808190808, 0x192b2b08192b1919, 0x192b2b082b192b08, 0x192b2b1908080808,
|
||||
0x192b2b19082b2b2b, 0x192b2b2b1908082b, 0x192b2b2b2b2b0819, 0x2b08080808080808,
|
||||
0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819,
|
||||
0x2b08080808191908, 0x2b08080808192b19, 0x2b080808082b0808, 0x2b080808082b1919,
|
||||
0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808081919082b,
|
||||
0x2b08080819191919, 0x2b08080819192b08, 0x2b080808192b0819, 0x2b0808082b080808,
|
||||
0x2b0808082b081919, 0x2b0808082b190819, 0x2b0808082b191908, 0x2b08081908080819,
|
||||
0x2b08081908081908, 0x2b08081908082b19, 0x2b08081908190808, 0x2b0808190819082b,
|
||||
0x2b08081908191919, 0x2b08081908192b08, 0x2b080819082b0819, 0x2b080819082b1908,
|
||||
0x2b08081919080808, 0x2b0808191908082b, 0x2b08081919081919, 0x2b08081919082b08,
|
||||
0x2b08081919190819, 0x2b08081919191908, 0x2b0808192b080819, 0x2b0808192b081908,
|
||||
0x2b0808192b190808, 0x2b0808192b2b2b19, 0x2b08082b08080808, 0x2b08082b08081919,
|
||||
0x2b08082b08082b2b, 0x2b08082b08190819, 0x2b08082b08191908, 0x2b08082b19080819,
|
||||
0x2b08082b19081908, 0x2b08082b19190808, 0x2b08190808080819, 0x2b08190808081908,
|
||||
0x2b0819080808192b, 0x2b08190808082b19, 0x2b08190808190808, 0x2b0819080819082b,
|
||||
0x2b08190808191919, 0x2b08190808192b08, 0x2b081908082b0819, 0x2b08190819080808,
|
||||
0x2b0819081908082b, 0x2b08190819081919, 0x2b08190819082b08, 0x2b08190819190819,
|
||||
0x2b08190819191908, 0x2b081908192b0808, 0x2b0819082b080819, 0x2b0819082b081908,
|
||||
0x2b0819082b190808, 0x2b08191908080808, 0x2b0819190808082b, 0x2b08191908081919,
|
||||
0x2b08191908082b08, 0x2b08191908190819, 0x2b08191908191908, 0x2b081919082b0808,
|
||||
0x2b08191919080819, 0x2b08191919081908, 0x2b08191919190808, 0x2b0819192b080808,
|
||||
0x2b0819192b082b2b, 0x2b08192b08080819, 0x2b08192b08081908, 0x2b08192b08190808,
|
||||
0x2b08192b082b2b19, 0x2b08192b19080808, 0x2b082b0808080808, 0x2b082b0808081919,
|
||||
0x2b082b0808190819, 0x2b082b0808191908, 0x2b082b0819080819, 0x2b082b0819081908,
|
||||
0x2b082b0819190808, 0x2b082b082b2b082b, 0x2b082b1908080819, 0x2b082b1908081908,
|
||||
0x2b082b1919080808, 0x2b082b19192b1919, 0x2b082b2b082b082b, 0x2b082b2b19192b08,
|
||||
0x2b082b2b19192b2b, 0x2b082b2b2b08082b, 0x2b082b2b2b2b082b, 0x2b19080808080819,
|
||||
0x2b19080808081908, 0x2b19080808082b19, 0x2b19080808190808, 0x2b1908080819082b,
|
||||
0x2b19080808191919, 0x2b19080808192b08, 0x2b190808082b1908, 0x2b19080819080808,
|
||||
0x2b1908081908082b, 0x2b19080819081919, 0x2b19080819082b08, 0x2b19080819190819,
|
||||
0x2b19080819191908, 0x2b190808192b0808, 0x2b1908082b080819, 0x2b1908082b081908,
|
||||
0x2b1908082b190808, 0x2b19081908080808, 0x2b19081908081919, 0x2b19081908190819,
|
||||
0x2b19081908191908, 0x2b19081919080819, 0x2b19081919081908, 0x2b19081919190808,
|
||||
0x2b19081919192b2b, 0x2b19082b08080819, 0x2b19082b08081908, 0x2b19082b08190808,
|
||||
0x2b19082b19080808, 0x2b19082b2b2b192b, 0x2b19190808080808, 0x2b1919080808082b,
|
||||
0x2b19190808081919, 0x2b19190808082b08, 0x2b19190808190819, 0x2b19190808191908,
|
||||
0x2b191908082b0808, 0x2b19190819080819, 0x2b19190819081908, 0x2b19190819190808,
|
||||
0x2b1919082b080808, 0x2b1919082b19192b, 0x2b19191908080819, 0x2b19191908081908,
|
||||
0x2b19191908190808, 0x2b19191919080808, 0x2b1919192b192b08, 0x2b1919192b2b0819,
|
||||
0x2b19192b08080808, 0x2b19192b1908192b, 0x2b19192b192b1908, 0x2b192b0808080819,
|
||||
0x2b192b0808081908, 0x2b192b0808190808, 0x2b192b08082b192b, 0x2b192b0819080808,
|
||||
0x2b192b082b2b2b19, 0x2b192b1908080808, 0x2b192b1919082b19, 0x2b192b191919082b,
|
||||
0x2b192b2b2b190808, 0x2b2b080808080808, 0x2b2b080808081919, 0x2b2b080808082b2b,
|
||||
0x2b2b080808191908, 0x2b2b0808082b082b, 0x2b2b0808082b2b2b, 0x2b2b080819080819,
|
||||
0x2b2b080819081908, 0x2b2b080819190808, 0x2b2b08082b2b082b, 0x2b2b08082b2b2b2b,
|
||||
0x2b2b081919080808, 0x2b2b0819192b1919, 0x2b2b082b0808082b, 0x2b2b082b08082b2b,
|
||||
0x2b2b082b082b082b, 0x2b2b082b082b2b08, 0x2b2b082b082b2b2b, 0x2b2b082b2b08082b,
|
||||
0x2b2b082b2b082b08, 0x2b2b082b2b082b2b, 0x2b2b082b2b2b2b08, 0x2b2b190808080819,
|
||||
0x2b2b190808081908, 0x2b2b190808190808, 0x2b2b190819080808, 0x2b2b19082b082b19,
|
||||
0x2b2b19082b2b1908, 0x2b2b191908080808, 0x2b2b191908192b19, 0x2b2b192b19190819,
|
||||
0x2b2b2b0808082b2b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b082b, 0x2b2b2b1919191908,
|
||||
0x2b2b2b192b08192b, 0x2b2b2b2b08082b08, 0x2b2b2b2b08082b2b, 0x2b2b2b2b082b0808,
|
||||
0x2b2b2b2b082b082b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b2b082b08, 0x2b2b2b2b2b2b2b2b,
|
||||
};
|
||||
|
||||
constexpr constant static uint32_t iq3xxs_grid[256] = {
|
||||
0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
|
||||
0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
|
||||
|
@ -4572,6 +4846,139 @@ kernel void kernel_mul_mv_iq3_s_f32(
|
|||
kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
void kernel_mul_mv_iq2_s_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
const int nb = ne00/QK_K;
|
||||
const int r0 = tgpig.x;
|
||||
const int r1 = tgpig.y;
|
||||
const int im = tgpig.z;
|
||||
|
||||
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
||||
const int ib_row = first_row * nb;
|
||||
|
||||
const uint i12 = im%ne12;
|
||||
const uint i13 = im/ne12;
|
||||
|
||||
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
||||
|
||||
device const block_iq2_s * x = (device const block_iq2_s *) src0 + ib_row + offset0;
|
||||
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
float yl[32];
|
||||
float sumf[N_DST]={0.f}, all_sum;
|
||||
|
||||
const int nb32 = nb * (QK_K / 32);
|
||||
|
||||
//threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
|
||||
//{
|
||||
// int nval = 32;
|
||||
// int pos = (32*sgitg + tiisg)*nval;
|
||||
// for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i];
|
||||
// threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
//}
|
||||
|
||||
const int ix = tiisg;
|
||||
|
||||
device const float * y4 = y + 32 * ix;
|
||||
|
||||
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
||||
|
||||
for (int i = 0; i < 32; ++i) {
|
||||
yl[i] = y4[i];
|
||||
}
|
||||
|
||||
const int ibl = ib32 / (QK_K / 32);
|
||||
const int ib = ib32 % (QK_K / 32);
|
||||
|
||||
device const block_iq2_s * xr = x + ibl;
|
||||
device const uint8_t * qs = xr->qs + 4 * ib;
|
||||
device const uint8_t * qh = xr->qh + ib;
|
||||
device const uint8_t * sc = xr->scales + ib;
|
||||
device const uint8_t * signs = qs + QK_K/8;
|
||||
device const half * dh = &xr->d;
|
||||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
|
||||
const float db = dh[0];
|
||||
const float d1 = db * (0.5f + (sc[0] & 0xf));
|
||||
const float d2 = db * (0.5f + (sc[0] >> 4));
|
||||
|
||||
float2 sum = {0};
|
||||
for (int l = 0; l < 2; ++l) {
|
||||
//const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
|
||||
//const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
|
||||
constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
|
||||
constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
|
||||
sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
|
||||
}
|
||||
}
|
||||
sumf[row] += d1 * sum[0] + d2 * sum[1];
|
||||
|
||||
dh += nb*sizeof(block_iq2_s)/2;
|
||||
qs += nb*sizeof(block_iq2_s);
|
||||
qh += nb*sizeof(block_iq2_s);
|
||||
sc += nb*sizeof(block_iq2_s);
|
||||
signs += nb*sizeof(block_iq2_s);
|
||||
}
|
||||
|
||||
y4 += 32 * 32;
|
||||
}
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[[host_name("kernel_mul_mv_iq2_s_f32")]]
|
||||
kernel void kernel_mul_mv_iq2_s_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
void kernel_mul_mv_iq1_s_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
|
@ -4760,6 +5167,100 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|||
}
|
||||
}
|
||||
|
||||
void kernel_mul_mv_iq4_xs_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
threadgroup float * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
const int nb = ne00/QK_K;
|
||||
const int r0 = tgpig.x;
|
||||
const int r1 = tgpig.y;
|
||||
const int im = tgpig.z;
|
||||
const int first_row = (r0 * 2 + sgitg) * 2;
|
||||
const int ib_row = first_row * nb;
|
||||
|
||||
const uint i12 = im%ne12;
|
||||
const uint i13 = im/ne12;
|
||||
|
||||
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
||||
device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0;
|
||||
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
const int ix = tiisg/16; // 0 or 1
|
||||
const int it = tiisg%16; // 0...15
|
||||
const int ib = it/2;
|
||||
const int il = it%2;
|
||||
|
||||
shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float4 yl[4];
|
||||
float sumf[2]={0.f}, all_sum;
|
||||
|
||||
device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
|
||||
|
||||
uint32_t aux32[2];
|
||||
thread const uint8_t * q8 = (thread const uint8_t *)aux32;
|
||||
|
||||
float4 qf1, qf2;
|
||||
|
||||
for (int ibl = ix; ibl < nb; ibl += 2) {
|
||||
|
||||
device const float4 * y4 = (device const float4 *)yb;
|
||||
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
|
||||
|
||||
for (int row = 0; row < 2; ++row) {
|
||||
|
||||
device const block_iq4_xs & xb = x[row*nb + ibl];
|
||||
device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
|
||||
|
||||
float4 acc1 = {0.f}, acc2 = {0.f};
|
||||
|
||||
aux32[0] = q4[0] & 0x0f0f0f0f;
|
||||
aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
|
||||
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
|
||||
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
|
||||
acc1 += yl[0] * qf1;
|
||||
acc2 += yl[1] * qf2;
|
||||
|
||||
aux32[0] = q4[1] & 0x0f0f0f0f;
|
||||
aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
|
||||
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
|
||||
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
|
||||
acc1 += yl[2] * qf1;
|
||||
acc2 += yl[3] * qf2;
|
||||
|
||||
acc1 += acc2;
|
||||
|
||||
const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
|
||||
sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
|
||||
|
||||
}
|
||||
|
||||
yb += 2 * QK_K;
|
||||
}
|
||||
|
||||
for (int row = 0; row < 2; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[[host_name("kernel_mul_mv_iq1_s_f32")]]
|
||||
kernel void kernel_mul_mv_iq1_s_f32(
|
||||
device const void * src0,
|
||||
|
@ -4817,6 +5318,35 @@ kernel void kernel_mul_mv_iq4_nl_f32(
|
|||
kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
[[host_name("kernel_mul_mv_iq4_xs_f32")]]
|
||||
kernel void kernel_mul_mv_iq4_xs_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
threadgroup float * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
//============================= templates and their specializations =============================
|
||||
|
||||
// NOTE: this is not dequantizing - we are simply fitting the template
|
||||
|
@ -5188,6 +5718,25 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 &
|
|||
}
|
||||
}

template <typename type4x4>
void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const float d = xb->d;
    const int ib32 = il/2;
    il = il%2;
    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
    device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
    device const uint8_t * signs = qs + QK_K/8;
    const uint8_t qh = xb->qh[ib32] >> 4*il;
    const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
    constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300)));
    constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300)));
    for (int i = 0; i < 8; ++i) {
        reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]);
        reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]);
    }
}
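The grid lookup above is the core of iq2_s: each byte of qs supplies the low 8 bits of a codebook index, two bits of qh supply the high bits, giving a 10-bit index into iq2s_grid[1024], and one sign byte flips individual weights. A rough standalone restatement in plain C (helper name hypothetical; grid and kmask stand for the iq2s_grid and kmask_iq2xs tables declared earlier in this file):

    /* sketch: decode one group of 8 iq2_s weights, mirroring dequantize_iq2_s above */
    static void iq2s_decode8(uint8_t qs0, uint8_t qh_bits, uint8_t sign_byte, float dl,
                             const uint64_t * grid, const uint8_t * kmask, float out[8]) {
        const int idx = qs0 | ((qh_bits << 8) & 0x300);      /* 10-bit grid index, 0..1023 */
        const uint8_t * g = (const uint8_t *) &grid[idx];    /* 8 packed magnitudes        */
        for (int j = 0; j < 8; ++j) {
            const float sign = (sign_byte & kmask[j]) ? -1.0f : 1.0f;
            out[j] = dl * g[j] * sign;                       /* dl already carries the 0.25 factor */
        }
    }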
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
|
||||
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
||||
|
@ -5219,6 +5768,26 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
|
|||
}
|
||||
}

template <typename type4x4>
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const int ib32 = il/2;
    il = il%2;
    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
    device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
    const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);
    const float d = (float)xb->d * (ls - 32);
    uint32_t aux32;
    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
    for (int i = 0; i < 4; ++i) {
        aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;
        reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
        reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
        reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
        reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
    }
}
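Here the sub-block scale is a 6-bit value split across two fields: a low nibble from scales_l and two high bits from scales_h, biased by 32 so it can go negative. A plain-C restatement of that decode (helper name hypothetical, field layout as in block_iq4_xs):

    /* sketch: recover the signed 6-bit scale of sub-block ib32 (0..7) of an iq4_xs block */
    static int iq4_xs_sub_scale(const uint8_t * scales_l, uint16_t scales_h, int ib32) {
        const int lo = (scales_l[ib32/2] >> 4*(ib32%2)) & 0xf; /* 4 low bits, two sub-blocks per byte */
        const int hi = (scales_h >> 2*ib32) & 3;               /* 2 high bits, packed into one uint16 */
        return ((hi << 4) | lo) - 32;                          /* value in [-32, 31]                  */
    }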
|
||||
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
||||
kernel void kernel_get_rows(
|
||||
device const void * src0,
|
||||
|
@ -5762,8 +6331,10 @@ template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_r
|
|||
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
||||
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
||||
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
||||
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||
|
||||
//
|
||||
// matrix-matrix multiplication
|
||||
|
@ -5804,8 +6375,10 @@ template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_m
|
|||
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
||||
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
||||
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
||||
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||
|
||||
//
|
||||
// indirect matrix-matrix multiplication
|
||||
|
@ -5858,8 +6431,10 @@ template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel
|
|||
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||
|
||||
//
|
||||
// matrix-vector multiplication
|
||||
|
@ -6893,6 +7468,71 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
|
|||
sgitg);
|
||||
}
|
||||
|
||||
[[host_name("kernel_mul_mv_id_iq2_s_f32")]]
|
||||
kernel void kernel_mul_mv_id_iq2_s_f32(
|
||||
device const char * ids,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
constant uint64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & ne13,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant uint64_t & nb1,
|
||||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
device const char * src00,
|
||||
device const char * src01,
|
||||
device const char * src02,
|
||||
device const char * src03,
|
||||
device const char * src04,
|
||||
device const char * src05,
|
||||
device const char * src06,
|
||||
device const char * src07,
|
||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||
|
||||
const int64_t bid = tgpig.z/(ne12*ne13);
|
||||
|
||||
tgpig.z = tgpig.z%(ne12*ne13);
|
||||
|
||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||
|
||||
kernel_mul_mv_iq2_s_f32_impl(
|
||||
src0[id],
|
||||
(device const float *) (src1 + bid*nb11),
|
||||
dst + bid*ne0,
|
||||
ne00,
|
||||
ne01,
|
||||
ne02,
|
||||
ne10,
|
||||
ne12,
|
||||
ne0,
|
||||
ne1,
|
||||
r2,
|
||||
r3,
|
||||
shared_values,
|
||||
tgpig,
|
||||
tiisg,
|
||||
sgitg);
|
||||
}
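As with the other _id kernels, this wrapper only resolves which expert matrix to use: the idx-th entry of the ids row for the current batch selects one of the eight bound src0 buffers, and the shared _impl routine does the actual mat-vec. A minimal sketch of that selection step in plain C (function name hypothetical):

    /* sketch: how the _id wrappers pick the expert weight matrix */
    static const void * select_expert(const char * ids, const void * const src0[8],
                                      int64_t bid, uint64_t nbi1, int idx) {
        const int32_t id = ((const int32_t *)(ids + bid*nbi1))[idx]; /* expert index for this row */
        return src0[id];                                             /* forwarded to the impl     */
    }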
|
||||
|
||||
[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
|
||||
kernel void kernel_mul_mv_id_iq1_s_f32(
|
||||
device const char * ids,
|
||||
|
@ -7020,3 +7660,68 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
|
|||
tiisg,
|
||||
sgitg);
|
||||
}
|
||||
|
||||
[[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
|
||||
kernel void kernel_mul_mv_id_iq4_xs_f32(
|
||||
device const char * ids,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
constant uint64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & ne13,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant uint64_t & nb1,
|
||||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
device const char * src00,
|
||||
device const char * src01,
|
||||
device const char * src02,
|
||||
device const char * src03,
|
||||
device const char * src04,
|
||||
device const char * src05,
|
||||
device const char * src06,
|
||||
device const char * src07,
|
||||
threadgroup float * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||
|
||||
const int64_t bid = tgpig.z/(ne12*ne13);
|
||||
|
||||
tgpig.z = tgpig.z%(ne12*ne13);
|
||||
|
||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||
|
||||
kernel_mul_mv_iq4_xs_f32_impl(
|
||||
src0[id],
|
||||
(device const float *) (src1 + bid*nb11),
|
||||
dst + bid*ne0,
|
||||
ne00,
|
||||
ne01,
|
||||
ne02,
|
||||
ne10,
|
||||
ne12,
|
||||
ne0,
|
||||
ne1,
|
||||
r2,
|
||||
r3,
|
||||
shared_values,
|
||||
tgpig,
|
||||
tiisg,
|
||||
sgitg);
|
||||
}
1042 ggml-quants.c
File diff suppressed because it is too large
@ -182,6 +182,15 @@ typedef struct {
} block_iq2_xs;
static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");

// 2.5625 bpw quants
typedef struct {
    ggml_fp16_t d;
    uint8_t qs[QK_K/4];
    uint8_t qh[QK_K/32];
    uint8_t scales[QK_K/32];
} block_iq2_s;
static_assert(sizeof(block_iq2_s) == sizeof(ggml_fp16_t) + QK_K/4 + QK_K/16, "wrong iq2_s block size/padding");

// (Almost) "true" 3-bit quantization.
// Due to the need to use blocks as per ggml design, it ends up using
// 3.0625 bpw because of the 16-bit scale for each block of 256.
@ -221,6 +230,14 @@ typedef struct {
} block_iq4_nl;
static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");

typedef struct {
    ggml_fp16_t d;
    uint16_t scales_h;
    uint8_t scales_l[QK_K/64];
    uint8_t qs[QK_K/2];
} block_iq4_xs;
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");

#ifdef __cplusplus
extern "C" {
#endif
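The 2.5625 bpw figure quoted above, and the 4.25 bpw that IQ4_XS is described with later in this diff, follow directly from the two struct layouts. A minimal sketch, assuming QK_K == 256 and a 2-byte ggml_fp16_t as used elsewhere in ggml; the check is illustrative and not part of the patch:

#include <stdio.h>

int main(void) {
    const int QK_K = 256;
    // block_iq2_s: d, qs, qh, scales
    const int iq2_s_bits  = 16 + 8*(QK_K/4) + 8*(QK_K/32) + 8*(QK_K/32);
    // block_iq4_xs: d, scales_h, scales_l, qs
    const int iq4_xs_bits = 16 + 16 + 8*(QK_K/64) + 8*(QK_K/2);
    printf("iq2_s : %.4f bpw\n", (double) iq2_s_bits  / QK_K); // 2.5625
    printf("iq4_xs: %.4f bpw\n", (double) iq4_xs_bits / QK_K); // 4.2500
    return 0;
}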
@ -241,7 +258,9 @@ void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGM
|
|||
void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int k);
|
||||
void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int k);
|
||||
void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int k);
|
||||
void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int k);
|
||||
void quantize_row_iq3_s_reference (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int k);
|
||||
void quantize_row_iq2_s_reference (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int k);
|
||||
|
||||
void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
||||
void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
||||
|
@ -258,7 +277,9 @@ void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
|
|||
void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
||||
void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
||||
void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
||||
void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
||||
void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
||||
void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
|
||||
|
||||
// Dequantization
|
||||
void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
||||
|
@ -276,9 +297,11 @@ void dequantize_row_q6_K(const block_q6_K * GGML_RESTRICT x, float * GGML_RESTRI
|
|||
void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
||||
void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
||||
void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
||||
void dequantize_row_iq2_s (const block_iq2_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
||||
void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
||||
void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
||||
void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
||||
void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
||||
void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
|
||||
|
||||
// Dot product
|
||||
|
@ -295,9 +318,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
|||
void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq2_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
|
||||
//
|
||||
|
@ -305,9 +330,11 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
|
|||
//
|
||||
size_t quantize_iq2_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
||||
size_t quantize_iq2_xs (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
||||
size_t quantize_iq2_s (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
||||
size_t quantize_iq3_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
||||
size_t quantize_iq1_s (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
||||
size_t quantize_iq4_nl (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
||||
size_t quantize_iq4_xs (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
||||
size_t quantize_iq3_s (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
||||
size_t quantize_q2_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
||||
size_t quantize_q3_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
|
||||
|
|
248 ggml-sycl.cpp
@ -8126,23 +8126,51 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
|
|||
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
|
||||
}
|
||||
|
||||
static void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale,
|
||||
const sycl::nd_item<3> &item_ct1, float *buf) {
|
||||
|
||||
template <bool vals_smem, int ncols_template, int block_size_template>
|
||||
static void soft_max_f32(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
|
||||
const int nrows_y, const float scale, const float max_bias, const float m0,
|
||||
const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
|
||||
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
||||
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int rowx = item_ct1.get_group(2);
|
||||
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
|
||||
|
||||
const int block_size = item_ct1.get_local_range(2);
|
||||
const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
|
||||
|
||||
const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
|
||||
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
|
||||
|
||||
float slope = 0.0f;
|
||||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
const uint32_t h = rowx/nrows_y; // head index
|
||||
|
||||
const float base = h < n_head_log2 ? m0 : m1;
|
||||
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
|
||||
slope = sycl::pow(base, float(exp));
|
||||
}
|
||||
|
||||
float * vals = vals_smem ? buf + WARP_SIZE : dst + rowx*ncols;
|
||||
float max_val = -INFINITY;
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
||||
const int col = col0 + tid;
|
||||
|
||||
if (ncols_template == 0 && col >= ncols) {
|
||||
break;
|
||||
}
|
||||
|
||||
const int ix = rowx*ncols + col;
|
||||
const int iy = rowy*ncols + col;
|
||||
max_val = sycl::max(max_val, x[ix] * scale + (y ? y[iy] : 0.0f));
|
||||
|
||||
const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
|
||||
|
||||
vals[col] = val;
|
||||
max_val = sycl::max(max_val, val);
|
||||
}
|
||||
|
||||
// find the max value in the block
|
||||
|
@ -8151,30 +8179,12 @@ static void soft_max_f32(const float * x, const float * y, float * dst, const in
|
|||
if (warp_id == 0) {
|
||||
buf[lane_id] = -INFINITY;
|
||||
}
|
||||
/*
|
||||
DPCT1118:12: SYCL group functions and algorithms must be encountered in
|
||||
converged control flow. You may need to adjust the code.
|
||||
*/
|
||||
/*
|
||||
DPCT1065:60: Consider replacing sycl::nd_item::barrier() with
|
||||
sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
|
||||
better performance if there is no access to global memory.
|
||||
*/
|
||||
item_ct1.barrier();
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
if (lane_id == 0) {
|
||||
buf[warp_id] = max_val;
|
||||
}
|
||||
/*
|
||||
DPCT1118:13: SYCL group functions and algorithms must be encountered in
|
||||
converged control flow. You may need to adjust the code.
|
||||
*/
|
||||
/*
|
||||
DPCT1065:61: Consider replacing sycl::nd_item::barrier() with
|
||||
sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
|
||||
better performance if there is no access to global memory.
|
||||
*/
|
||||
item_ct1.barrier();
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
max_val = buf[lane_id];
|
||||
max_val = warp_reduce_max(max_val, item_ct1);
|
||||
|
@ -8182,13 +8192,16 @@ static void soft_max_f32(const float * x, const float * y, float * dst, const in
|
|||
|
||||
float tmp = 0.f;
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const int ix = rowx*ncols + col;
|
||||
const int iy = rowy*ncols + col;
|
||||
const float val =
|
||||
sycl::native::exp((x[ix] * scale + (y ? y[iy] : 0.0f)) - max_val);
|
||||
#pragma unroll
|
||||
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
||||
const int col = col0 + tid;
|
||||
if (ncols_template == 0 && col >= ncols) {
|
||||
break;
|
||||
}
|
||||
|
||||
const float val = sycl::native::exp(vals[col] - max_val);
|
||||
tmp += val;
|
||||
dst[ix] = val;
|
||||
vals[col] = val;
|
||||
}
|
||||
|
||||
// find the sum of exps in the block
|
||||
|
@ -8197,40 +8210,29 @@ static void soft_max_f32(const float * x, const float * y, float * dst, const in
|
|||
if (warp_id == 0) {
|
||||
buf[lane_id] = 0.f;
|
||||
}
|
||||
/*
|
||||
DPCT1118:14: SYCL group functions and algorithms must be encountered in
|
||||
converged control flow. You may need to adjust the code.
|
||||
*/
|
||||
/*
|
||||
DPCT1065:62: Consider replacing sycl::nd_item::barrier() with
|
||||
sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
|
||||
better performance if there is no access to global memory.
|
||||
*/
|
||||
item_ct1.barrier();
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
if (lane_id == 0) {
|
||||
buf[warp_id] = tmp;
|
||||
}
|
||||
/*
|
||||
DPCT1118:15: SYCL group functions and algorithms must be encountered in
|
||||
converged control flow. You may need to adjust the code.
|
||||
*/
|
||||
/*
|
||||
DPCT1065:63: Consider replacing sycl::nd_item::barrier() with
|
||||
sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
|
||||
better performance if there is no access to global memory.
|
||||
*/
|
||||
item_ct1.barrier();
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
tmp = buf[lane_id];
|
||||
tmp = warp_reduce_sum(tmp, item_ct1);
|
||||
}
|
||||
|
||||
const float inv_tmp = 1.f / tmp;
|
||||
const float inv_sum = 1.f / tmp;
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const int i = rowx*ncols + col;
|
||||
dst[i] *= inv_tmp;
|
||||
#pragma unroll
|
||||
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
||||
const int col = col0 + tid;
|
||||
|
||||
if (ncols_template == 0 && col >= ncols) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int idst = rowx*ncols + col;
|
||||
dst[idst] = vals[col] * inv_sum;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -10867,35 +10869,96 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst,
|
|||
});
|
||||
}
|
||||
|
||||
static void soft_max_f32_sycl(const float *x, const float *y, float *dst,
|
||||
const int ncols_x, const int nrows_x,
|
||||
const int nrows_y, const float scale,
|
||||
template <bool vals_smem, int ncols_template, int block_size_template>
|
||||
static void soft_max_f32_submitter(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
|
||||
const int nrows_y, const float scale, const float max_bias, const float m0,
|
||||
const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
|
||||
const size_t n_local_scratch, dpct::queue_ptr stream) {
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
||||
soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, pos, dst, ncols_par,
|
||||
nrows_y, scale, max_bias, m0,
|
||||
m1, n_head_log2, item_ct1,
|
||||
local_buf_acc.get_pointer());
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
static void soft_max_f32_sycl(const float * x, const float * mask, const float * pos,
|
||||
float * dst, const int ncols_x, const int nrows_x,
|
||||
const int nrows_y, const float scale, const float max_bias,
|
||||
dpct::queue_ptr stream) {
|
||||
int nth = WARP_SIZE;
|
||||
while (nth < ncols_x && nth < SYCL_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
||||
const sycl::range<3> block_dims(1, 1, nth);
|
||||
const sycl::range<3> block_nums(1, 1, nrows_x);
|
||||
/*
|
||||
DPCT1049:46: The work-group size passed to the SYCL kernel may exceed the
|
||||
limit. To get the device limit, query info::device::max_work_group_size.
|
||||
Adjust the work-group size if needed.
|
||||
*/
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
/*
|
||||
DPCT1101:96: 'SYCL_SOFT_MAX_BLOCK_SIZE/WARP_SIZE' expression was
|
||||
replaced with a value. Modify the code to use the original expression,
|
||||
provided in comments, if it is correct.
|
||||
*/
|
||||
sycl::local_accessor<float, 1> buf_acc_ct1(
|
||||
sycl::range<1>(32 /*SYCL_SOFT_MAX_BLOCK_SIZE/WARP_SIZE*/), cgh);
|
||||
const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE);
|
||||
static_assert(SYCL_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
||||
soft_max_f32(x, y, dst, ncols_x, nrows_y, scale, item_ct1,
|
||||
buf_acc_ct1.get_pointer());
|
||||
});
|
||||
});
|
||||
const uint32_t n_head_kv = nrows_x/nrows_y;
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
|
||||
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
|
||||
if (n_local_scratch*sizeof(float) < local_mem_size) {
|
||||
switch (ncols_x) {
|
||||
case 32:
|
||||
soft_max_f32_submitter<true, 32, 32>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 64:
|
||||
soft_max_f32_submitter<true, 64, 64>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 128:
|
||||
soft_max_f32_submitter<true, 128, 128>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 256:
|
||||
soft_max_f32_submitter<true, 256, 256>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 512:
|
||||
soft_max_f32_submitter<true, 512, 512>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 1024:
|
||||
soft_max_f32_submitter<true, 1024, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 2048:
|
||||
soft_max_f32_submitter<true, 2048, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 4096:
|
||||
soft_max_f32_submitter<true, 4096, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
default:
|
||||
soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
soft_max_f32_submitter<false, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, WARP_SIZE, stream);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -12435,14 +12498,35 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
|
|||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t nrows_x = ggml_nrows(src0);
|
||||
const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
|
||||
const int64_t nrows_y = src0->ne[1];
|
||||
|
||||
float scale = 1.0f;
|
||||
memcpy(&scale, dst->op_params, sizeof(float));
|
||||
float max_bias = 0.0f;
|
||||
|
||||
soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
|
||||
memcpy(&scale, dst->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, dst->op_params + 1, sizeof(float));
|
||||
|
||||
(void) dst;
|
||||
// positions tensor
|
||||
float * src2_dd = nullptr;
|
||||
sycl_pool_alloc<float> src2_f;
|
||||
|
||||
ggml_tensor * src2 = dst->src[2];
|
||||
const bool use_src2 = src2 != nullptr;
|
||||
|
||||
if (use_src2) {
|
||||
const bool src2_on_device = src2->backend == GGML_BACKEND_TYPE_GPU;
|
||||
|
||||
if (src2_on_device) {
|
||||
ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra;
|
||||
src2_dd = (float *) src2_extra->data_device[g_main_device];
|
||||
} else {
|
||||
src2_dd = src2_f.alloc(ggml_nelements(src2));
|
||||
SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream));
|
||||
}
|
||||
}
|
||||
|
||||
soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00,
|
||||
nrows_x, nrows_y, scale, max_bias, main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_scale(const ggml_tensor *src0, const ggml_tensor *src1,
|
||||
|
|
61 ggml.c
@ -690,6 +690,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_IQ2_S] = {
|
||||
.type_name = "iq2_s",
|
||||
.blck_size = QK_K,
|
||||
.type_size = sizeof(block_iq2_s),
|
||||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_iq2_s,
|
||||
.from_float = quantize_row_iq2_s,
|
||||
.from_float_reference = (ggml_from_float_t)quantize_row_iq2_s_reference,
|
||||
.vec_dot = ggml_vec_dot_iq2_s_q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_IQ1_S] = {
|
||||
.type_name = "iq1_s",
|
||||
.blck_size = QK_K,
|
||||
|
@ -714,6 +726,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_IQ4_XS] = {
|
||||
.type_name = "iq4_xs",
|
||||
.blck_size = QK_K,
|
||||
.type_size = sizeof(block_iq4_xs),
|
||||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_iq4_xs,
|
||||
.from_float = quantize_row_iq4_xs,
|
||||
.from_float_reference = (ggml_from_float_t)quantize_row_iq4_xs_reference,
|
||||
.vec_dot = ggml_vec_dot_iq4_xs_q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_Q8_K] = {
|
||||
.type_name = "q8_K",
|
||||
.blck_size = QK_K,
|
||||
|
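With .blck_size = QK_K and .type_size = sizeof(block_iq2_s), the storage for one quantized row follows as n_per_row / blck_size * type_size, which is what ggml_row_size() reports for these entries. A small sketch with an assumed row length of 4096; the 82-byte block size matches the static_assert for block_iq2_s earlier in this diff (illustrative only):

#include <stdio.h>

int main(void) {
    const int QK_K = 256;
    const int block_iq2_s_size = 2 + QK_K/4 + QK_K/32 + QK_K/32; // 82 bytes per 256 weights
    const int n_per_row = 4096;                                  // assumed embedding width
    printf("iq2_s row: %d bytes\n", n_per_row / QK_K * block_iq2_s_size); // 16 * 82 = 1312
    return 0;
}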
@ -2316,7 +2340,9 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
|
|||
case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break;
|
||||
case GGML_FTYPE_MOSTLY_IQ1_S: wtype = GGML_TYPE_IQ1_S; break;
|
||||
case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break;
|
||||
case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break;
|
||||
case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break;
|
||||
case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break;
|
||||
case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
|
||||
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
|
||||
}
|
||||
|
@ -7751,7 +7777,9 @@ static void ggml_compute_forward_add(
|
|||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
{
|
||||
ggml_compute_forward_add_q_f32(params, dst);
|
||||
} break;
|
||||
|
@ -8031,7 +8059,9 @@ static void ggml_compute_forward_add1(
|
|||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
{
|
||||
ggml_compute_forward_add1_q_f32(params, dst);
|
||||
} break;
|
||||
|
@ -8156,7 +8186,9 @@ static void ggml_compute_forward_acc(
|
|||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
default:
|
||||
{
|
||||
GGML_ASSERT(false);
|
||||
|
@ -11055,7 +11087,9 @@ static void ggml_compute_forward_out_prod(
|
|||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
{
|
||||
ggml_compute_forward_out_prod_q_f32(params, dst);
|
||||
} break;
|
||||
|
@ -11244,7 +11278,9 @@ static void ggml_compute_forward_set(
|
|||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
default:
|
||||
{
|
||||
GGML_ASSERT(false);
|
||||
|
@ -11447,7 +11483,9 @@ static void ggml_compute_forward_get_rows(
|
|||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
{
|
||||
ggml_compute_forward_get_rows_q(params, dst);
|
||||
} break;
|
||||
|
@ -12148,7 +12186,9 @@ static void ggml_compute_forward_alibi(
|
|||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_Q8_K:
|
||||
case GGML_TYPE_I8:
|
||||
case GGML_TYPE_I16:
|
||||
|
@ -12232,7 +12272,9 @@ static void ggml_compute_forward_clamp(
|
|||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_Q8_K:
|
||||
case GGML_TYPE_I8:
|
||||
case GGML_TYPE_I16:
|
||||
|
@ -19482,6 +19524,7 @@ void ggml_quantize_init(enum ggml_type type) {
|
|||
switch (type) {
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ1_S: iq2xs_init_impl(type); break;
|
||||
case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break;
|
||||
case GGML_TYPE_IQ3_S: iq3xs_init_impl(512); break;
|
||||
|
@ -19768,6 +19811,15 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
|
|||
result = quantize_iq3_s(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
|
||||
GGML_ASSERT(result == row_size * nrows);
|
||||
} break;
|
||||
case GGML_TYPE_IQ2_S:
|
||||
{
|
||||
GGML_ASSERT(start % QK_K == 0);
|
||||
GGML_ASSERT(start % n_per_row == 0);
|
||||
size_t start_row = start / n_per_row;
|
||||
size_t row_size = ggml_row_size(type, n_per_row);
|
||||
result = quantize_iq2_s(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
|
||||
GGML_ASSERT(result == row_size * nrows);
|
||||
} break;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
{
|
||||
GGML_ASSERT(start % QK_K == 0);
|
||||
|
@ -19786,6 +19838,15 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
|
|||
result = quantize_iq4_nl(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
|
||||
GGML_ASSERT(result == row_size * nrows);
|
||||
} break;
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
{
|
||||
GGML_ASSERT(start % QK4_NL == 0);
|
||||
GGML_ASSERT(start % n_per_row == 0);
|
||||
size_t start_row = start / n_per_row;
|
||||
size_t row_size = ggml_row_size(type, n_per_row);
|
||||
result = quantize_iq4_xs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
|
||||
GGML_ASSERT(result == row_size * nrows);
|
||||
} break;
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
size_t elemsize = sizeof(ggml_fp16_t);
|
||||
|
|
4 ggml.h

@ -351,6 +351,8 @@ extern "C" {
        GGML_TYPE_IQ1_S  = 19,
        GGML_TYPE_IQ4_NL = 20,
        GGML_TYPE_IQ3_S  = 21,
        GGML_TYPE_IQ2_S  = 22,
        GGML_TYPE_IQ4_XS = 23,
        GGML_TYPE_I8,
        GGML_TYPE_I16,
        GGML_TYPE_I32,

@ -391,6 +393,8 @@ extern "C" {
        GGML_FTYPE_MOSTLY_IQ1_S  = 18, // except 1d tensors
        GGML_FTYPE_MOSTLY_IQ4_NL = 19, // except 1d tensors
        GGML_FTYPE_MOSTLY_IQ3_S  = 20, // except 1d tensors
        GGML_FTYPE_MOSTLY_IQ2_S  = 21, // except 1d tensors
        GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
    };

    // available tensor operations:
194 llama.cpp
@ -1641,6 +1641,7 @@ struct llama_cparams {
|
|||
float yarn_attn_factor;
|
||||
float yarn_beta_fast;
|
||||
float yarn_beta_slow;
|
||||
float defrag_thold;
|
||||
|
||||
bool mul_mat_q;
|
||||
bool offload_kqv;
|
||||
|
@ -2579,9 +2580,11 @@ struct llama_model_loader {
|
|||
case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break;
|
||||
case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break;
|
||||
case GGML_TYPE_IQ2_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS; break;
|
||||
case GGML_TYPE_IQ2_S: ftype = LLAMA_FTYPE_MOSTLY_IQ2_S; break;
|
||||
case GGML_TYPE_IQ3_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS; break;
|
||||
case GGML_TYPE_IQ1_S: ftype = LLAMA_FTYPE_MOSTLY_IQ1_S; break;
|
||||
case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break;
|
||||
case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break;
|
||||
case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break;
|
||||
default:
|
||||
{
|
||||
|
@ -2933,10 +2936,13 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
|
|||
case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_XXS:return "IQ2_XXS - 2.0625 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_Q3_K_XS:return "Q3_K - Extra small";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_S: return "IQ2_S - 2.5 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_M: return "IQ2_M - 2.7 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_XS: return "IQ3_XS - 3.3 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_XXS:return "IQ3_XXS - 3.0625 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ1_S :return "IQ1_S - 1.5625 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw";
|
||||
|
||||
|
@ -4894,8 +4900,8 @@ static struct ggml_tensor * llm_build_kqv(
|
|||
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
|
||||
}
|
||||
|
||||
#if defined(GGML_USE_VULKAN) || defined(GGML_USE_KOMPUTE) || defined(GGML_USE_SYCL)
|
||||
#pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Vulkan, Kompute, and SYCL")
|
||||
#if defined(GGML_USE_VULKAN) || defined(GGML_USE_KOMPUTE)
|
||||
#pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Vulkan, and Kompute")
|
||||
#pragma message(" Falling back to ggml_alibi(). Will become an error in Mar 2024")
|
||||
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5488")
|
||||
if (hparams.f_max_alibi_bias > 0.0f) {
|
||||
|
@ -5114,16 +5120,16 @@ struct llm_build_context {
|
|||
struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
||||
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
const int id = ids[i];
|
||||
for (uint32_t i = 0; i < ids.size(); ++i) {
|
||||
const uint32_t id = ids[i];
|
||||
|
||||
if (i == id || id == n_kv) {
|
||||
if (i == id || id == ids.size()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int nm = 1;
|
||||
uint32_t nm = 1;
|
||||
|
||||
while (i + nm < n_kv && (int) ids[i + nm] == id + nm) {
|
||||
while (i + nm < ids.size() && ids[i + nm] == id + nm) {
|
||||
nm++;
|
||||
}
|
||||
|
||||
|
@ -5155,6 +5161,8 @@ struct llm_build_context {
|
|||
i += nm - 1;
|
||||
}
|
||||
|
||||
//LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
|
@ -7935,6 +7943,8 @@ static int llama_decode_internal(
|
|||
batch.seq_id = seq_id_arr.data();
|
||||
}
|
||||
|
||||
llama_kv_cache_update(&lctx);
|
||||
|
||||
// if we have enough unused cells before the current head ->
|
||||
// better to start searching from the beginning of the cache, hoping to fill it
|
||||
if (kv_self.head > kv_self.used + 2*n_tokens) {
|
||||
|
@ -7953,8 +7963,6 @@ static int llama_decode_internal(
|
|||
|
||||
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
|
||||
|
||||
llama_kv_cache_update(&lctx);
|
||||
|
||||
ggml_backend_sched_reset(lctx.sched);
|
||||
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
|
||||
|
||||
|
@ -8004,6 +8012,18 @@ static int llama_decode_internal(
|
|||
}
|
||||
}
|
||||
|
||||
// decide if we need to defrag the kv cache
|
||||
if (cparams.defrag_thold >= 0.0f) {
|
||||
const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used + n_tokens)/float(kv_self.n) : 0.0f;
|
||||
|
||||
// queue defragmentation for next llama_kv_cache_update
|
||||
if (fragmentation > cparams.defrag_thold) {
|
||||
//LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
|
||||
|
||||
llama_kv_cache_defrag(kv_self);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef GGML_PERF
|
||||
// print timing information per ggml operation (for debugging purposes)
|
||||
// requires GGML_PERF to be defined
|
||||
|
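To make the new defrag trigger concrete: fragmentation measures the share of the current KV window that will remain unoccupied after this batch, and a defrag is only queued (for the next llama_kv_cache_update) once it exceeds defrag_thold. A worked example with made-up cache numbers, illustrative only:

#include <stdio.h>

int main(void) {
    const int   kv_n         = 512;  // current KV cache window (kv_self.n)
    const int   kv_used      = 350;  // occupied cells (kv_self.used)
    const int   n_tokens     = 34;   // tokens in the current batch
    const float defrag_thold = 0.1f; // assumed threshold

    const float fragmentation = kv_n >= 128 ? 1.0f - (float)(kv_used + n_tokens)/kv_n : 0.0f;
    if (fragmentation > defrag_thold) {
        printf("fragmentation %.2f > %.2f -> queue defrag\n", fragmentation, defrag_thold); // 0.25 > 0.10
    }
    return 0;
}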
@ -8095,12 +8115,16 @@ static int llama_decode_internal(
|
|||
static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
||||
auto & kv_self = lctx.kv_self;
|
||||
|
||||
const auto & hparams = lctx.model.hparams;
|
||||
|
||||
const uint32_t n_layer = hparams.n_layer;
|
||||
|
||||
const uint32_t n_kv = llama_kv_cache_cell_max(kv_self);
|
||||
const uint32_t n_used = kv_self.used;
|
||||
|
||||
assert(n_used <= n_kv);
|
||||
|
||||
const int64_t t_start = ggml_time_us();
|
||||
//const int64_t t_start = ggml_time_us();
|
||||
|
||||
// number of cells moved
|
||||
uint32_t n_moves = 0;
|
||||
|
@ -8124,15 +8148,26 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
|||
|
||||
// found a hole - fill it with data from the end of the cache
|
||||
|
||||
// determine the size of the hole
|
||||
uint32_t nh = 1;
|
||||
|
||||
// determine the size of the hole
|
||||
while (i0 + nh < n_used && kv_self.cells[i0 + nh].is_empty()) {
|
||||
nh++;
|
||||
}
|
||||
|
||||
// starting from the end, find nh non-empty cells
|
||||
// each move requires 6*n_layer tensors (see build_defrag)
|
||||
// - source view, destination view, copy operation
|
||||
// - x2 for keys and values
|
||||
//
|
||||
if (6*(n_moves + nh)*n_layer >= LLAMA_MAX_NODES) {
|
||||
// the graph is too big, we cannot move more cells
|
||||
break;
|
||||
}
|
||||
|
||||
uint32_t nf = 0;
|
||||
uint32_t is = n_kv - 1;
|
||||
|
||||
// starting from the end, find nh non-empty cells
|
||||
for (; is > i0; --is) {
|
||||
const auto & cell1 = kv_self.cells[is];
|
||||
|
||||
|
@ -8153,11 +8188,17 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
|||
|
||||
nf = 0;
|
||||
|
||||
uint32_t i1 = is;
|
||||
|
||||
// are we moving a continuous block of memory?
|
||||
bool cont = false;
|
||||
|
||||
// go back and move the nf cells to the hole
|
||||
for (uint32_t i1 = is; i1 < n_kv; ++i1) {
|
||||
const auto & cell1 = kv_self.cells[i1];
|
||||
for (; i1 < n_kv; ++i1) {
|
||||
auto & cell1 = kv_self.cells[i1];
|
||||
|
||||
if (cell1.is_empty() || ids[i1] != n_kv) {
|
||||
cont = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -8167,11 +8208,23 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
|||
// move the cell meta data
|
||||
kv_self.cells[i0 + nf] = cell1;
|
||||
|
||||
n_moves++;
|
||||
// clear the old cell and move the head there
|
||||
cell1 = llama_kv_cell();
|
||||
kv_self.head = n_used;
|
||||
|
||||
if (!cont) {
|
||||
n_moves++;
|
||||
cont = true;
|
||||
}
|
||||
|
||||
nf++;
|
||||
|
||||
if (nf == nh) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, n_kv, i0, i0 + nh);
|
||||
//LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
|
||||
|
||||
i0 += nh - 1;
|
||||
}
|
||||
|
@ -8180,15 +8233,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
|||
return;
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
|
||||
//LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
|
||||
|
||||
kv_self.head = n_used;
|
||||
kv_self.used = n_used;
|
||||
|
||||
// zero the rest of the cells
|
||||
for (uint32_t i = n_used; i < n_kv; ++i) {
|
||||
kv_self.cells[i] = llama_kv_cell();
|
||||
}
|
||||
//LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*n_moves*n_layer);
|
||||
|
||||
#if 0
|
||||
// CPU defrag
|
||||
|
@ -8200,9 +8247,6 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
|||
// likely not worth the effort, as we have ggml_graph based defrag
|
||||
//
|
||||
|
||||
const auto & hparams = lctx.model.hparams;
|
||||
|
||||
const uint32_t n_layer = hparams.n_layer;
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
|
@ -8271,9 +8315,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
|||
llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
|
||||
#endif
|
||||
|
||||
const int64_t t_end = ggml_time_us();
|
||||
//const int64_t t_end = ggml_time_us();
|
||||
|
||||
LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
|
||||
//LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
|
||||
}
|
||||
|
||||
static void llama_kv_cache_update_internal(struct llama_context & lctx) {
|
||||
|
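The defragmentation pass above walks [0, n_used) looking for holes, fills each hole with occupied cells taken from the back of the cache, batches contiguous runs into single moves, and stops early so the resulting graph stays below LLAMA_MAX_NODES. A deliberately simplified CPU-side sketch of just the move-planning idea (single-cell moves, no run batching and no node cap); illustrative only, not the ggml-graph implementation:

#include <stdbool.h>
#include <stdio.h>

// print the moves needed so that all occupied cells end up in [0, n_used)
static void plan_defrag(const bool * occupied, int n_kv, int n_used) {
    int is = n_kv - 1; // next candidate source cell, scanned from the end
    for (int i0 = 0; i0 < n_used; ++i0) {
        if (occupied[i0]) continue;              // not a hole
        while (is > i0 && !occupied[is]) --is;   // find an occupied cell further back
        if (is <= i0) break;                     // nothing left to move
        printf("move cell %d -> %d\n", is, i0);
        --is;
    }
}

int main(void) {
    //                   0     1      2     3      4     5     6      7
    const bool occ[8] = {true, false, true, false, true, true, false, true};
    plan_defrag(occ, 8, /*n_used =*/ 5); // prints: 7 -> 1, then 5 -> 3
    return 0;
}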
@ -10761,31 +10805,47 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
|
|||
if (arch == LLM_ARCH_FALCON || nx % QK_K != 0) {
|
||||
new_type = GGML_TYPE_Q8_0;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) {
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
|
||||
ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) {
|
||||
new_type = GGML_TYPE_Q5_K;
|
||||
}
|
||||
else if (new_type != GGML_TYPE_Q8_0) {
|
||||
new_type = GGML_TYPE_Q6_K;
|
||||
}
|
||||
} else if (name == "token_embd.weight") {
|
||||
if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) {
|
||||
if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS ||
|
||||
ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) {
|
||||
new_type = GGML_TYPE_Q2_K;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
|
||||
new_type = GGML_TYPE_Q4_K;
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) {
|
||||
new_type = GGML_TYPE_IQ3_S;
|
||||
}
|
||||
} else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) {
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
|
||||
new_type = GGML_TYPE_IQ3_S;
|
||||
}
|
||||
} else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S ||
|
||||
ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) {
|
||||
if (name.find("attn_v.weight") != std::string::npos) {
|
||||
if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_Q4_K;
|
||||
else new_type = GGML_TYPE_Q2_K;
|
||||
else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
|
||||
++qs.i_attention_wv;
|
||||
}
|
||||
else if (qs.model.hparams.n_expert == 8 && name.find("attn_k.weight") != std::string::npos) {
|
||||
new_type = GGML_TYPE_Q4_K;
|
||||
}
|
||||
else if (name.find("ffn_down") != std::string::npos) {
|
||||
if (qs.i_ffn_down < qs.n_ffn_down/8) new_type = GGML_TYPE_Q2_K;
|
||||
if (qs.i_ffn_down < qs.n_ffn_down/8) {
|
||||
new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
|
||||
}
|
||||
++qs.i_ffn_down;
|
||||
}
|
||||
else if (name.find("attn_output.weight") != std::string::npos) {
|
||||
if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) new_type = GGML_TYPE_IQ2_XXS;
|
||||
if (qs.model.hparams.n_expert == 8) {
|
||||
new_type = GGML_TYPE_Q5_K;
|
||||
} else {
|
||||
if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) new_type = GGML_TYPE_IQ2_XXS;
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) new_type = GGML_TYPE_IQ3_S;
|
||||
}
|
||||
}
|
||||
} else if (name.find("attn_v.weight") != std::string::npos) {
|
||||
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) {
|
||||
|
@ -10795,7 +10855,13 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
|
|||
new_type = GGML_TYPE_Q4_K;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
|
||||
new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : !qs.has_imatrix ? GGML_TYPE_Q3_K : GGML_TYPE_IQ3_XXS;
|
||||
new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : !qs.has_imatrix ? GGML_TYPE_IQ3_S : GGML_TYPE_IQ3_XXS;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_S && qs.model.hparams.n_gqa() >= 4) {
|
||||
new_type = GGML_TYPE_Q4_K;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
|
||||
new_type = GGML_TYPE_Q4_K;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_S && qs.model.hparams.n_gqa() >= 4) {
|
||||
new_type = GGML_TYPE_Q4_K;
|
||||
|
@ -10807,7 +10873,7 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
|
|||
new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL && qs.model.hparams.n_gqa() >= 4) {
|
||||
else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && qs.model.hparams.n_gqa() >= 4) {
|
||||
new_type = GGML_TYPE_Q5_K;
|
||||
}
|
||||
else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
|
||||
|
@ -10833,13 +10899,19 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
|
|||
// TODO: explore better strategies
|
||||
new_type = GGML_TYPE_Q8_0;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS) {
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) {
|
||||
new_type = GGML_TYPE_IQ3_XXS;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
|
||||
new_type = GGML_TYPE_IQ2_S;
|
||||
}
|
||||
} else if (name.find("attn_q.weight") != std::string::npos) {
|
||||
if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS) {
|
||||
if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) {
|
||||
new_type = GGML_TYPE_IQ3_XXS;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
|
||||
new_type = GGML_TYPE_IQ2_S;
|
||||
}
|
||||
} else if (name.find("ffn_down") != std::string::npos) {
|
||||
auto info = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str());
|
||||
int i_layer = info.first, n_layer = info.second;
|
||||
|
@ -10870,8 +10942,8 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
|
|||
if (use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
|
||||
}
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL && !qs.has_imatrix) {
|
||||
if (i_layer < n_layer/8) new_type = GGML_TYPE_Q5_K;
|
||||
else if (i_layer < n_layer/8 && (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && !qs.has_imatrix) {
|
||||
new_type = GGML_TYPE_Q5_K;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && arch != LLM_ARCH_FALCON && i_layer < n_layer/8) {
|
||||
|
@ -10888,15 +10960,15 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
|
|||
} else if (name.find("attn_output.weight") != std::string::npos) {
|
||||
if (arch != LLM_ARCH_FALCON) {
|
||||
if (qs.model.hparams.n_expert == 8) {
|
||||
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
|
||||
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
|
||||
ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL ||
|
||||
ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S ||
|
||||
ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
|
||||
ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) {
|
||||
new_type = GGML_TYPE_Q5_K;
|
||||
}
|
||||
} else {
|
||||
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K ) new_type = GGML_TYPE_Q3_K;
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) new_type = GGML_TYPE_Q3_K;
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) new_type = GGML_TYPE_IQ3_S;
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M ) new_type = GGML_TYPE_Q4_K;
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L ) new_type = GGML_TYPE_Q5_K;
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M ) new_type = GGML_TYPE_Q4_K;
|
||||
|
@ -10915,7 +10987,7 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
|
|||
else if (name.find("ffn_gate") != std::string::npos) {
|
||||
auto info = layer_info(qs.i_ffn_gate, qs.n_ffn_gate, name.c_str());
|
||||
int i_layer = info.first, n_layer = info.second;
|
||||
if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
|
||||
if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
|
||||
new_type = GGML_TYPE_IQ3_XXS;
|
||||
}
|
||||
++qs.i_ffn_gate;
|
||||
|
@ -10923,7 +10995,7 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
|
|||
else if (name.find("ffn_up") != std::string::npos) {
|
||||
auto info = layer_info(qs.i_ffn_up, qs.n_ffn_up, name.c_str());
|
||||
int i_layer = info.first, n_layer = info.second;
|
||||
if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
|
||||
if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
|
||||
new_type = GGML_TYPE_IQ3_XXS;
|
||||
}
|
||||
++qs.i_ffn_up;
|
||||
|
@ -10942,8 +11014,8 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
|
|||
//}
|
||||
bool convert_incompatible_tensor = false;
|
||||
if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K ||
|
||||
new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K ||
|
||||
new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS ||
|
||||
new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K || new_type == GGML_TYPE_IQ4_XS ||
|
||||
new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S ||
|
||||
new_type == GGML_TYPE_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || new_type == GGML_TYPE_IQ3_S) {
|
||||
int nx = tensor->ne[0];
|
||||
int ny = tensor->ne[1];
|
||||
|
@ -10958,14 +11030,16 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
|
|||
switch (new_type) {
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K: new_type = GGML_TYPE_IQ4_NL; break;
|
||||
case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break;
|
||||
case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break;
|
||||
case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break;
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break;
|
||||
case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break;
|
||||
case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break;
|
||||
case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break;
|
||||
default: throw std::runtime_error("\nUnsupported tensor size encountered\n");
|
||||
}
|
||||
LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type));
|
||||
|
@ -10991,7 +11065,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
|||
// K-quants
|
||||
case LLAMA_FTYPE_MOSTLY_Q2_K_S:
|
||||
case LLAMA_FTYPE_MOSTLY_Q2_K: quantized_type = GGML_TYPE_Q2_K; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q3_K_XS: quantized_type = GGML_TYPE_IQ3_S; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_XS: quantized_type = GGML_TYPE_IQ3_S; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q3_K_S:
|
||||
case LLAMA_FTYPE_MOSTLY_Q3_K_M:
|
||||
case LLAMA_FTYPE_MOSTLY_Q3_K_L: quantized_type = GGML_TYPE_Q3_K; break;
|
||||
|
@ -11002,9 +11076,12 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
|||
case LLAMA_FTYPE_MOSTLY_Q6_K: quantized_type = GGML_TYPE_Q6_K; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_XXS: quantized_type = GGML_TYPE_IQ2_XXS; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_XS: quantized_type = GGML_TYPE_IQ2_XS; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_S: quantized_type = GGML_TYPE_IQ2_XS; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_M: quantized_type = GGML_TYPE_IQ2_S; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_XXS: quantized_type = GGML_TYPE_IQ3_XXS; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ1_S: quantized_type = GGML_TYPE_IQ1_S; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ4_NL: quantized_type = GGML_TYPE_IQ4_NL; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ4_XS: quantized_type = GGML_TYPE_IQ4_XS; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_S: quantized_type = GGML_TYPE_IQ3_S; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_M: quantized_type = GGML_TYPE_IQ3_S; break;
|
||||
|
||||
|
@ -11180,6 +11257,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
|||
}
|
||||
if ((new_type == GGML_TYPE_IQ2_XXS ||
|
||||
new_type == GGML_TYPE_IQ2_XS ||
|
||||
new_type == GGML_TYPE_IQ2_S ||
|
||||
new_type == GGML_TYPE_IQ1_S ||
|
||||
(new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0)) && !imatrix) {
|
||||
LLAMA_LOG_ERROR("\n\n============================================================\n");
|
||||
|
@ -11635,6 +11713,7 @@ struct llama_context_params llama_context_default_params() {
|
|||
/*.yarn_beta_fast =*/ 32.0f,
|
||||
/*.yarn_beta_slow =*/ 1.0f,
|
||||
/*.yarn_orig_ctx =*/ 0,
|
||||
/*.defrag_thold =*/ -1.0f,
|
||||
/*.cb_eval =*/ nullptr,
|
||||
/*.cb_eval_user_data =*/ nullptr,
|
||||
/*.type_k =*/ GGML_TYPE_F16,
|
||||
|
@ -11799,6 +11878,7 @@ struct llama_context * llama_new_context_with_model(
|
|||
cparams.yarn_attn_factor = params.yarn_attn_factor;
|
||||
cparams.yarn_beta_fast = params.yarn_beta_fast;
|
||||
cparams.yarn_beta_slow = params.yarn_beta_slow;
|
||||
cparams.defrag_thold = params.defrag_thold;
|
||||
cparams.mul_mat_q = params.mul_mat_q;
|
||||
cparams.offload_kqv = params.offload_kqv;
|
||||
cparams.do_pooling = params.do_pooling;
|
||||
|
@ -12000,7 +12080,7 @@ struct llama_context * llama_new_context_with_model(
|
|||
}
|
||||
|
||||
// buffer used to store the computation graph and the tensor meta data
|
||||
ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead());
|
||||
ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead_custom(LLAMA_MAX_NODES, false));
|
||||
|
||||
ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), LLAMA_MAX_NODES);
|
||||
|
||||
|
|
6 llama.h

@ -107,12 +107,15 @@ extern "C" {
        LLAMA_FTYPE_MOSTLY_IQ2_XXS = 19, // except 1d tensors
        LLAMA_FTYPE_MOSTLY_IQ2_XS  = 20, // except 1d tensors
        LLAMA_FTYPE_MOSTLY_Q2_K_S  = 21, // except 1d tensors
        LLAMA_FTYPE_MOSTLY_Q3_K_XS = 22, // except 1d tensors
        LLAMA_FTYPE_MOSTLY_IQ3_XS  = 22, // except 1d tensors
        LLAMA_FTYPE_MOSTLY_IQ3_XXS = 23, // except 1d tensors
        LLAMA_FTYPE_MOSTLY_IQ1_S   = 24, // except 1d tensors
        LLAMA_FTYPE_MOSTLY_IQ4_NL  = 25, // except 1d tensors
        LLAMA_FTYPE_MOSTLY_IQ3_S   = 26, // except 1d tensors
        LLAMA_FTYPE_MOSTLY_IQ3_M   = 27, // except 1d tensors
        LLAMA_FTYPE_MOSTLY_IQ2_S   = 28, // except 1d tensors
        LLAMA_FTYPE_MOSTLY_IQ2_M   = 29, // except 1d tensors
        LLAMA_FTYPE_MOSTLY_IQ4_XS  = 30, // except 1d tensors

        LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
    };

@ -243,6 +246,7 @@ extern "C" {
        float yarn_beta_fast;   // YaRN low correction dim
        float yarn_beta_slow;   // YaRN high correction dim
        uint32_t yarn_orig_ctx; // YaRN original context size
        float defrag_thold;     // defragment the KV cache if holes/size > thold, < 0 disabled (default)

        ggml_backend_sched_eval_callback cb_eval;
        void * cb_eval_user_data;
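From the API side the new field is opt-in: llama_context_default_params() leaves it at -1.0f, i.e. disabled, as shown earlier in this diff. A hedged usage sketch; the 0.1f threshold is an arbitrary example value:

#include "llama.h"

// request KV-cache defragmentation once more than 10% of the cache consists of holes
static struct llama_context_params my_params(void) {
    struct llama_context_params cparams = llama_context_default_params();
    cparams.defrag_thold = 0.1f; // default -1.0f keeps defragmentation disabled
    return cparams;
}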
@ -1916,9 +1916,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
        GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
        GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
        GGML_TYPE_Q6_K,
        GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS,
        GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
        GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S,
        GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S,
        GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
    };

    // unary ops
@ -150,6 +150,7 @@ int main(int argc, char * argv[]) {
        const float total_error = total_quantization_error(qfns, test_size, test_data.data());
        const float max_quantization_error =
            type == GGML_TYPE_Q2_K   ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :
            type == GGML_TYPE_IQ2_S  ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :
            type == GGML_TYPE_Q3_K   ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
            type == GGML_TYPE_IQ3_S  ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
            type == GGML_TYPE_IQ3_XXS ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS : MAX_QUANTIZATION_TOTAL_ERROR;

@ -168,7 +169,8 @@ int main(int argc, char * argv[]) {
        const float vec_dot_error = dot_product_error(qfns, test_size, test_data.data(), test_data2.data());
        const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS ||
            type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S ? MAX_DOT_PRODUCT_ERROR_LOWBIT
            type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S
            ? MAX_DOT_PRODUCT_ERROR_LOWBIT
            : MAX_DOT_PRODUCT_ERROR;
        failed = !(vec_dot_error < max_allowed_error);
        num_failed += failed;