diff --git a/common/arg.cpp b/common/arg.cpp index 7c5c5e5cd..61efe2126 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2047,6 +2047,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex common_log_set_timestamps(common_log_main(), true); } ).set_env("LLAMA_LOG_TIMESTAMPS")); + add_opt(common_arg( + {"-rtrp", "--runtime-repack"}, + string_format("Allow runtime requantization and repacking of Q4_0 to enable optimized GEMM and GEMV kernels (default: %d)", params.runtime_repack), + [](common_params & params) { + params.runtime_repack = true; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); return ctx_arg; } diff --git a/common/common.cpp b/common/common.cpp index 19674af15..07d9928d3 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -983,7 +983,7 @@ struct llama_model_params common_model_params_to_llama(const common_params & par mparams.main_gpu = params.main_gpu; mparams.split_mode = params.split_mode; mparams.tensor_split = params.tensor_split; - mparams.use_mmap = params.use_mmap; + mparams.use_mmap = params.use_mmap && !params.runtime_repack; mparams.use_mlock = params.use_mlock; mparams.check_tensors = params.check_tensors; if (params.kv_overrides.empty()) { @@ -1056,6 +1056,7 @@ struct llama_context_params common_context_params_to_llama(const common_params & cparams.offload_kqv = !params.no_kv_offload; cparams.flash_attn = params.flash_attn; cparams.no_perf = params.no_perf; + cparams.runtime_repack = params.runtime_repack; if (params.reranking) { cparams.embeddings = true; diff --git a/common/common.h b/common/common.h index 727f85baa..71fd47fcb 100644 --- a/common/common.h +++ b/common/common.h @@ -271,6 +271,8 @@ struct common_params { bool warmup = true; // warmup run bool check_tensors = false; // validate tensor data + bool runtime_repack = false; // runtime repack weight for optimized kernels + std::string cache_type_k = "f16"; // KV cache data type for the K std::string cache_type_v = "f16"; // KV cache data type for the V diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 1eddfd0db..732ff1d2a 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -170,6 +170,7 @@ struct cmd_params { std::vector> tensor_split; std::vector use_mmap; std::vector embeddings; + std::vector runtime_repack; ggml_numa_strategy numa; int reps; ggml_sched_priority prio; @@ -202,6 +203,7 @@ static const cmd_params cmd_params_defaults = { /* tensor_split */ {std::vector(llama_max_devices(), 0.0f)}, /* use_mmap */ {true}, /* embeddings */ {false}, + /* runtime_repack */ {false}, /* numa */ GGML_NUMA_STRATEGY_DISABLED, /* reps */ 5, /* prio */ GGML_SCHED_PRIO_NORMAL, @@ -240,6 +242,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str()); printf(" --numa (default: disabled)\n"); printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str()); + printf(" -rtrp, --runtime_repack <0|1> (default: %s)\n", join(cmd_params_defaults.runtime_repack, ",").c_str()); printf(" -ts, --tensor-split (default: 0)\n"); printf(" -r, --repetitions (default: %d)\n", cmd_params_defaults.reps); printf(" --prio <0|1|2|3> (default: %d)\n", cmd_params_defaults.prio); @@ -502,6 +505,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = string_split(argv[i], split_delim); params.embeddings.insert(params.embeddings.end(), p.begin(), p.end()); + } else 
if (arg == "-rtrp" || arg == "--runtime_repack") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + params.runtime_repack.insert(params.runtime_repack.end(), p.begin(), p.end()); } else if (arg == "-ts" || arg == "--tensor-split") { if (++i >= argc) { invalid_param = true; @@ -570,27 +580,28 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } // set defaults - if (params.model.empty()) { params.model = cmd_params_defaults.model; } - if (params.n_prompt.empty()) { params.n_prompt = cmd_params_defaults.n_prompt; } - if (params.n_gen.empty()) { params.n_gen = cmd_params_defaults.n_gen; } - if (params.n_pg.empty()) { params.n_pg = cmd_params_defaults.n_pg; } - if (params.n_batch.empty()) { params.n_batch = cmd_params_defaults.n_batch; } - if (params.n_ubatch.empty()) { params.n_ubatch = cmd_params_defaults.n_ubatch; } - if (params.type_k.empty()) { params.type_k = cmd_params_defaults.type_k; } - if (params.type_v.empty()) { params.type_v = cmd_params_defaults.type_v; } - if (params.n_gpu_layers.empty()) { params.n_gpu_layers = cmd_params_defaults.n_gpu_layers; } - if (params.rpc_servers.empty()) { params.rpc_servers = cmd_params_defaults.rpc_servers; } - if (params.split_mode.empty()) { params.split_mode = cmd_params_defaults.split_mode; } - if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; } - if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; } - if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; } - if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; } - if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; } - if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; } - if (params.n_threads.empty()) { params.n_threads = cmd_params_defaults.n_threads; } - if (params.cpu_mask.empty()) { params.cpu_mask = cmd_params_defaults.cpu_mask; } - if (params.cpu_strict.empty()) { params.cpu_strict = cmd_params_defaults.cpu_strict; } - if (params.poll.empty()) { params.poll = cmd_params_defaults.poll; } + if (params.model.empty()) { params.model = cmd_params_defaults.model; } + if (params.n_prompt.empty()) { params.n_prompt = cmd_params_defaults.n_prompt; } + if (params.n_gen.empty()) { params.n_gen = cmd_params_defaults.n_gen; } + if (params.n_pg.empty()) { params.n_pg = cmd_params_defaults.n_pg; } + if (params.n_batch.empty()) { params.n_batch = cmd_params_defaults.n_batch; } + if (params.n_ubatch.empty()) { params.n_ubatch = cmd_params_defaults.n_ubatch; } + if (params.type_k.empty()) { params.type_k = cmd_params_defaults.type_k; } + if (params.type_v.empty()) { params.type_v = cmd_params_defaults.type_v; } + if (params.n_gpu_layers.empty()) { params.n_gpu_layers = cmd_params_defaults.n_gpu_layers; } + if (params.rpc_servers.empty()) { params.rpc_servers = cmd_params_defaults.rpc_servers; } + if (params.split_mode.empty()) { params.split_mode = cmd_params_defaults.split_mode; } + if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; } + if (params.no_kv_offload.empty()) { params.no_kv_offload = cmd_params_defaults.no_kv_offload; } + if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; } + if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; } + if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; } + if 
(params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; } + if (params.runtime_repack.empty()){ params.runtime_repack = cmd_params_defaults.runtime_repack; } + if (params.n_threads.empty()) { params.n_threads = cmd_params_defaults.n_threads; } + if (params.cpu_mask.empty()) { params.cpu_mask = cmd_params_defaults.cpu_mask; } + if (params.cpu_strict.empty()) { params.cpu_strict = cmd_params_defaults.cpu_strict; } + if (params.poll.empty()) { params.poll = cmd_params_defaults.poll; } return params; } @@ -616,6 +627,7 @@ struct cmd_params_instance { std::vector tensor_split; bool use_mmap; bool embeddings; + bool runtime_repack; llama_model_params to_llama_mparams() const { llama_model_params mparams = llama_model_default_params(); @@ -653,6 +665,7 @@ struct cmd_params_instance { cparams.offload_kqv = !no_kv_offload; cparams.flash_attn = flash_attn; cparams.embeddings = embeddings; + cparams.runtime_repack = runtime_repack; return cparams; } @@ -670,6 +683,7 @@ static std::vector get_cmd_params_instances(const cmd_param for (const auto & ts : params.tensor_split) for (const auto & mmp : params.use_mmap) for (const auto & embd : params.embeddings) + for (const auto & rtrp : params.runtime_repack) for (const auto & nb : params.n_batch) for (const auto & nub : params.n_ubatch) for (const auto & tk : params.type_k) @@ -685,26 +699,27 @@ static std::vector get_cmd_params_instances(const cmd_param continue; } cmd_params_instance instance = { - /* .model = */ m, - /* .n_prompt = */ n_prompt, - /* .n_gen = */ 0, - /* .n_batch = */ nb, - /* .n_ubatch = */ nub, - /* .type_k = */ tk, - /* .type_v = */ tv, - /* .n_threads = */ nt, - /* .cpu_mask = */ cm, - /* .cpu_strict = */ cs, - /* .poll = */ pl, - /* .n_gpu_layers = */ nl, - /* .rpc_servers = */ rpc, - /* .split_mode = */ sm, - /* .main_gpu = */ mg, - /* .no_kv_offload= */ nkvo, - /* .flash_attn = */ fa, - /* .tensor_split = */ ts, - /* .use_mmap = */ mmp, - /* .embeddings = */ embd, + /* .model = */ m, + /* .n_prompt = */ n_prompt, + /* .n_gen = */ 0, + /* .n_batch = */ nb, + /* .n_ubatch = */ nub, + /* .type_k = */ tk, + /* .type_v = */ tv, + /* .n_threads = */ nt, + /* .cpu_mask = */ cm, + /* .cpu_strict = */ cs, + /* .poll = */ pl, + /* .n_gpu_layers = */ nl, + /* .rpc_servers = */ rpc, + /* .split_mode = */ sm, + /* .main_gpu = */ mg, + /* .no_kv_offload = */ nkvo, + /* .flash_attn = */ fa, + /* .tensor_split = */ ts, + /* .use_mmap = */ static_cast(mmp) && !static_cast(rtrp), + /* .embeddings = */ embd, + /* .runtime_repack= */ rtrp, }; instances.push_back(instance); } @@ -714,26 +729,27 @@ static std::vector get_cmd_params_instances(const cmd_param continue; } cmd_params_instance instance = { - /* .model = */ m, - /* .n_prompt = */ 0, - /* .n_gen = */ n_gen, - /* .n_batch = */ nb, - /* .n_ubatch = */ nub, - /* .type_k = */ tk, - /* .type_v = */ tv, - /* .n_threads = */ nt, - /* .cpu_mask = */ cm, - /* .cpu_strict = */ cs, - /* .poll = */ pl, - /* .n_gpu_layers = */ nl, - /* .rpc_servers = */ rpc, - /* .split_mode = */ sm, - /* .main_gpu = */ mg, - /* .no_kv_offload= */ nkvo, - /* .flash_attn = */ fa, - /* .tensor_split = */ ts, - /* .use_mmap = */ mmp, - /* .embeddings = */ embd, + /* .model = */ m, + /* .n_prompt = */ 0, + /* .n_gen = */ n_gen, + /* .n_batch = */ nb, + /* .n_ubatch = */ nub, + /* .type_k = */ tk, + /* .type_v = */ tv, + /* .n_threads = */ nt, + /* .cpu_mask = */ cm, + /* .cpu_strict = */ cs, + /* .poll = */ pl, + /* .n_gpu_layers = */ nl, + /* .rpc_servers = */ rpc, + /* .split_mode = */ sm, + /* 
.main_gpu = */ mg, + /* .no_kv_offload = */ nkvo, + /* .flash_attn = */ fa, + /* .tensor_split = */ ts, + /* .use_mmap = */ static_cast(mmp) && !static_cast(rtrp), + /* .embeddings = */ embd, + /* .runtime_repack= */ rtrp, }; instances.push_back(instance); } @@ -743,26 +759,27 @@ static std::vector get_cmd_params_instances(const cmd_param continue; } cmd_params_instance instance = { - /* .model = */ m, - /* .n_prompt = */ n_pg.first, - /* .n_gen = */ n_pg.second, - /* .n_batch = */ nb, - /* .n_ubatch = */ nub, - /* .type_k = */ tk, - /* .type_v = */ tv, - /* .n_threads = */ nt, - /* .cpu_mask = */ cm, - /* .cpu_strict = */ cs, - /* .poll = */ pl, - /* .n_gpu_layers = */ nl, - /* .rpc_servers = */ rpc, - /* .split_mode = */ sm, - /* .main_gpu = */ mg, - /* .no_kv_offload= */ nkvo, - /* .flash_attn = */ fa, - /* .tensor_split = */ ts, - /* .use_mmap = */ mmp, - /* .embeddings = */ embd, + /* .model = */ m, + /* .n_prompt = */ n_pg.first, + /* .n_gen = */ n_pg.second, + /* .n_batch = */ nb, + /* .n_ubatch = */ nub, + /* .type_k = */ tk, + /* .type_v = */ tv, + /* .n_threads = */ nt, + /* .cpu_mask = */ cm, + /* .cpu_strict = */ cs, + /* .poll = */ pl, + /* .n_gpu_layers = */ nl, + /* .rpc_servers = */ rpc, + /* .split_mode = */ sm, + /* .main_gpu = */ mg, + /* .no_kv_offload = */ nkvo, + /* .flash_attn = */ fa, + /* .tensor_split = */ ts, + /* .use_mmap = */ static_cast(mmp) && !static_cast(rtrp), + /* .embeddings = */ embd, + /* .runtime_repack= */ rtrp, }; instances.push_back(instance); } @@ -804,6 +821,7 @@ struct test { std::vector tensor_split; bool use_mmap; bool embeddings; + bool runtime_repack; int n_prompt; int n_gen; std::string test_time; @@ -833,6 +851,7 @@ struct test { tensor_split = inst.tensor_split; use_mmap = inst.use_mmap; embeddings = inst.embeddings; + runtime_repack = inst.runtime_repack; n_prompt = inst.n_prompt; n_gen = inst.n_gen; // RFC 3339 date-time format @@ -889,7 +908,7 @@ struct test { "type_k", "type_v", "n_gpu_layers", "split_mode", "main_gpu", "no_kv_offload", "flash_attn", - "tensor_split", "use_mmap", "embeddings", + "tensor_split", "use_mmap", "embeddings", "runtime_repack", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts", @@ -911,7 +930,7 @@ struct test { if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" || field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" || field == "cpu_strict" || - field == "flash_attn" || field == "use_mmap" || field == "embeddings") { + field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "runtime_repack") { return BOOL; } if (field == "avg_ts" || field == "stddev_ts") { @@ -947,7 +966,7 @@ struct test { ggml_type_name(type_k), ggml_type_name(type_v), std::to_string(n_gpu_layers), split_mode_str(split_mode), std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), - tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), + tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(runtime_repack), std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), std::to_string(avg_ts()), std::to_string(stdev_ts()) @@ -1135,6 +1154,9 @@ struct markdown_printer : public printer { if (field == "test") { return 13; } + if (field == "runtime_repack") { + return 6; + } int width = std::max((int)field.length(), 10); @@ -1169,6 +1191,9 @@ struct markdown_printer : public 
printer { if (field == "tensor_split") { return "ts"; } + if (field == "runtime_repack") { + return "repack"; + } return field; } @@ -1227,6 +1252,9 @@ struct markdown_printer : public printer { if (params.embeddings.size() > 1 || params.embeddings != cmd_params_defaults.embeddings) { fields.emplace_back("embeddings"); } + if (params.runtime_repack.size() > 1 || params.runtime_repack != cmd_params_defaults.runtime_repack) { + fields.emplace_back("runtime_repack"); + } fields.emplace_back("test"); fields.emplace_back("t/s"); diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 125413d1b..20e9e6e34 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -305,7 +305,19 @@ extern "C" { GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr); GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor); - // CPU buffer types are always available + // + // CPU backend + // + + GGML_API ggml_backend_t ggml_backend_cpu_init(void); + + GGML_API bool ggml_backend_is_cpu (ggml_backend_t backend); + GGML_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads); + GGML_API void ggml_backend_cpu_set_threadpool (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool); + GGML_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data); + GGML_API void ggml_backend_cpu_set_runtime_repack(ggml_backend_t backend_cpu, bool runtime_repack); + + // Create a backend buffer from an existing pointer GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size); GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void); diff --git a/ggml/src/ggml-aarch64.c b/ggml/src/ggml-aarch64.c index 81f62ff4f..c2eb35791 100644 --- a/ggml/src/ggml-aarch64.c +++ b/ggml/src/ggml-aarch64.c @@ -3476,3 +3476,102 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * } } } + +static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor *t, int interleave_block, uint8_t **pmem, size_t *psize) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(t->ne[0] % 8 == 0); + GGML_ASSERT(interleave_block == 4 || interleave_block == 8); + + // Do in-place transformation. Allocate scratch buffer + size_t size = sizeof(block_q4_0x4) * t->ne[0] / QK4_0; + if (size > *psize) { + uint8_t *new_mem = realloc(*pmem, size); + if (!new_mem) { + return -1; + } + *pmem = new_mem; + *psize = size; + } + block_q4_0x4 *dst = (block_q4_0x4*) *pmem; + block_q4_0 *src = (block_q4_0*) t->data; + block_q4_0 dst_tmp[4]; + int n = t->ne[0]; + int nrow = t->ne[1]; // Number of rows + int nrows_interleaved = 4; + int nblocks = t->ne[0] / QK4_0; + for (int b = 0; b < (nrow * n); b += nrows_interleaved * n) { + int cnt = 0; + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++ ) { + dst_tmp[i] = src[x + i * nblocks]; + } + dst[cnt++] = make_block_q4_0x4(dst_tmp, interleave_block, 0x88); + } + memcpy(src, dst, size); + src += cnt * 4; + } + return 0; +} + +static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor *t, int interleave_block, uint8_t **pmem, size_t *psize) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(t->ne[0] % 8 == 0); + GGML_ASSERT(interleave_block == 8); + + // Do in-place transformation. 
Allocate scratch buffer
+    size_t size = sizeof(block_q4_0x8) * t->ne[0] / QK4_0;
+    if (size > *psize) {
+        uint8_t *new_mem = realloc(*pmem, size);
+        if (!new_mem) {
+            return -1;
+        }
+        *pmem = new_mem;
+        *psize = size;
+    }
+    block_q4_0x8 *dst = (block_q4_0x8*) *pmem;
+    block_q4_0 *src = (block_q4_0*) t->data;
+    block_q4_0 dst_tmp[8];
+    int n = t->ne[0];
+    int nrow = t->ne[1]; // Number of rows
+    int nrows_interleaved = 8;
+    int nblocks = t->ne[0] / QK4_0;
+    for (int b = 0; b < (nrow * n); b += nrows_interleaved * n) {
+        int cnt = 0;
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++ ) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            dst[cnt++] = make_block_q4_0x8(dst_tmp, interleave_block, 0x88);
+        }
+        memcpy(src, dst, size);
+        src += cnt * 8; // advance past the 8 rows that were just interleaved and written back
+    }
+    return 0;
+}
+
+// Prepare for optimized kernels if applicable
+void ggml_prepare_optimal_kernel(struct ggml_tensor *cur, uint8_t **pmem, size_t *psize) {
+    UNUSED(cur);
+    UNUSED(pmem);
+    UNUSED(psize);
+
+#if defined(__ARM_ARCH)
+    if (cur->type == GGML_TYPE_Q4_0) {
+        if (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) {
+            if (repack_q4_0_to_q4_0_8_bl(cur, 8, pmem, psize) == 0) {
+                cur->type = GGML_TYPE_Q4_0_8_8;
+            }
+        }
+        else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+            if (repack_q4_0_to_q4_0_4_bl(cur, 8, pmem, psize) == 0) {
+                cur->type = GGML_TYPE_Q4_0_4_8;
+            }
+        }
+        else if (ggml_cpu_has_neon()) {
+            if (repack_q4_0_to_q4_0_4_bl(cur, 4, pmem, psize) == 0) {
+                cur->type = GGML_TYPE_Q4_0_4_4;
+            }
+        }
+    }
+#endif
+}
diff --git a/ggml/src/ggml-aarch64.h b/ggml/src/ggml-aarch64.h
index 517babaf1..f68d66f6d 100644
--- a/ggml/src/ggml-aarch64.h
+++ b/ggml/src/ggml-aarch64.h
@@ -33,6 +33,8 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
 void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_prepare_optimal_kernel(struct ggml_tensor *cur, uint8_t **pmem, size_t *psize);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index 0b8ebac53..6e2389c16 100644
--- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp
@@ -12,6 +12,7 @@
 #include "ggml-backend-impl.h"
 #include "ggml-alloc.h"
 #include "ggml-impl.h"
+#include "ggml-aarch64.h"
 #include
 #include
@@ -716,6 +717,628 @@ ggml_backend_t ggml_backend_init_best(void) {
     return ggml_backend_dev_init(dev, NULL);
 }
+// backend CPU
+
+static const char * ggml_backend_cpu_buffer_get_name(ggml_backend_buffer_t buffer) {
+    return "CPU";
+
+    GGML_UNUSED(buffer);
+}
+
+static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
+    uintptr_t data = (uintptr_t)buffer->context;
+
+    // align the buffer
+    if (data % TENSOR_ALIGNMENT != 0) {
+        data = GGML_PAD(data, TENSOR_ALIGNMENT);
+    }
+
+    return (void *)data;
+}
+
+static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    ggml_aligned_free(buffer->context, buffer->size);
+}
+
+static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    memset((char *)tensor->data + offset, value, size);
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct 
ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + memcpy((char *)tensor->data + offset, data, size); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + memcpy(data, (const char *)tensor->data + offset, size); + + GGML_UNUSED(buffer); +} + +static bool ggml_backend_cpu_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) { + if (ggml_backend_buffer_is_host(src->buffer)) { + memcpy(dst->data, src->data, ggml_nbytes(src)); + return true; + } + return false; + + GGML_UNUSED(buffer); +} + +static void ggml_backend_cpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + memset(buffer->context, value, buffer->size); +} + +static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = { + /* .get_name = */ ggml_backend_cpu_buffer_get_name, + /* .free_buffer = */ ggml_backend_cpu_buffer_free_buffer, + /* .get_base = */ ggml_backend_cpu_buffer_get_base, + /* .init_tensor = */ NULL, // no initialization required + /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor, + /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor, + /* .clear = */ ggml_backend_cpu_buffer_clear, + /* .reset = */ NULL, +}; + +static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_from_ptr_i = { + /* .get_name = */ ggml_backend_cpu_buffer_get_name, + /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed + /* .get_base = */ ggml_backend_cpu_buffer_get_base, + /* .init_tensor = */ NULL, // no initialization required + /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor, + /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor, + /* .clear = */ ggml_backend_cpu_buffer_clear, + /* .reset = */ NULL, +}; + +static const char * ggml_backend_cpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + return "CPU"; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_t ggml_backend_cpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + auto alloc_size = size; + if (alloc_size == 0) { + alloc_size = 1; + } + + void * data = ggml_aligned_malloc(alloc_size); + + if (data == NULL) { + GGML_LOG_ERROR("%s: failed to allocate buffer of size %zu\n", __func__, alloc_size); + return NULL; + } + + return ggml_backend_buffer_init(buft, ggml_backend_cpu_buffer_i, data, alloc_size); +} + +static size_t ggml_backend_cpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return TENSOR_ALIGNMENT; + + GGML_UNUSED(buft); +} + +static bool ggml_backend_cpu_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + return true; + + GGML_UNUSED(buft); +} + +ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void) { + static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type = { + /* .iface = */ { + /* .get_name = */ ggml_backend_cpu_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_cpu_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment, + /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes + /* .is_host = */ ggml_backend_cpu_buffer_type_is_host, + }, + /* .device = */ 
ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), + /* .context = */ NULL, + }; + + return &ggml_backend_cpu_buffer_type; +} + +#ifdef GGML_USE_CPU_HBM + +// buffer type HBM + +#include + +static const char * ggml_backend_cpu_hbm_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + return "CPU_HBM"; + + GGML_UNUSED(buft); +} + +static const char * ggml_backend_cpu_hbm_buffer_get_name(ggml_backend_buffer_t buf) { + return "CPU_HBM"; + + GGML_UNUSED(buf); +} + +static void ggml_backend_cpu_hbm_buffer_free_buffer(ggml_backend_buffer_t buffer) { + hbw_free(buffer->context); +} + +static ggml_backend_buffer_t ggml_backend_cpu_hbm_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + //void * ptr = hbw_malloc(size); + void * ptr; + int result = hbw_posix_memalign(&ptr, ggml_backend_cpu_buffer_type_get_alignment(buft), size); + if (result != 0) { + GGML_LOG_ERROR("failed to allocate HBM buffer of size %zu\n", size); + return NULL; + } + + ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size); + buffer->buft = buft; + buffer->iface.get_name = ggml_backend_cpu_hbm_buffer_get_name; + buffer->iface.free_buffer = ggml_backend_cpu_hbm_buffer_free_buffer; + + return buffer; +} + +ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) { + static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_hbm = { + /* .iface = */ { + /* .get_name = */ ggml_backend_cpu_hbm_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_cpu_hbm_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment, + /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes + /* .is_host = */ ggml_backend_cpu_buffer_type_is_host, + }, + /* .context = */ NULL, + }; + + return &ggml_backend_cpu_buffer_type_hbm; +} +#endif + +struct ggml_backend_cpu_context { + int n_threads; + ggml_threadpool_t threadpool; + + uint8_t * work_data; + size_t work_size; + + bool runtime_repack; + uint8_t * scratch_memory; + size_t scratch_size; + + ggml_abort_callback abort_callback; + void * abort_callback_data; +}; + +static const char * ggml_backend_cpu_get_name(ggml_backend_t backend) { + return "CPU"; + + GGML_UNUSED(backend); +} + +static void ggml_backend_cpu_free(ggml_backend_t backend) { + struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; + delete[] cpu_ctx->work_data; + free(cpu_ctx->scratch_memory); // free the scratch memory allocated by C module + delete cpu_ctx; + delete backend; +} + +static ggml_backend_buffer_type_t ggml_backend_cpu_get_default_buffer_type(ggml_backend_t backend) { + return ggml_backend_cpu_buffer_type(); + + GGML_UNUSED(backend); +} + +struct ggml_backend_plan_cpu { + struct ggml_cplan cplan; + struct ggml_cgraph cgraph; +}; + +static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph * cgraph) { + struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; + + struct ggml_backend_plan_cpu * cpu_plan = new ggml_backend_plan_cpu; + + cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool); + cpu_plan->cgraph = *cgraph; // FIXME: deep copy + + if (cpu_plan->cplan.work_size > 0) { + cpu_plan->cplan.work_data = new uint8_t[cpu_plan->cplan.work_size]; + if (cpu_plan->cplan.work_data == NULL) { + delete cpu_plan; + return NULL; + } + } + + cpu_plan->cplan.abort_callback = cpu_ctx->abort_callback; + 
cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data; + + return cpu_plan; +} + +static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan; + + delete[] cpu_plan->cplan.work_data; + delete cpu_plan; + + GGML_UNUSED(backend); +} + +static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan; + + return ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan); + + GGML_UNUSED(backend); +} + +static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; + + if (cpu_ctx->runtime_repack) { + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * node = cgraph->nodes[i]; + if (node->op == GGML_OP_MUL_MAT && node->src[0]->type == GGML_TYPE_Q4_0) { + // Prepare for optimized kernels if applicable. + ggml_prepare_optimal_kernel(node->src[0], &cpu_ctx->scratch_memory, &cpu_ctx->scratch_size); + } + } + } + + struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool); + + if (cpu_ctx->work_size < cplan.work_size) { + delete[] cpu_ctx->work_data; + cpu_ctx->work_data = new uint8_t[cplan.work_size]; + if (cpu_ctx->work_data == NULL) { + cpu_ctx->work_size = 0; + return GGML_STATUS_ALLOC_FAILED; + } + cpu_ctx->work_size = cplan.work_size; + } + cplan.work_data = (uint8_t *)cpu_ctx->work_data; + + cplan.abort_callback = cpu_ctx->abort_callback; + cplan.abort_callback_data = cpu_ctx->abort_callback_data; + + return ggml_graph_compute(cgraph, &cplan); +} + +static const struct ggml_backend_i ggml_backend_cpu_i = { + /* .get_name = */ ggml_backend_cpu_get_name, + /* .free = */ ggml_backend_cpu_free, + /* .get_default_buffer_type = */ ggml_backend_cpu_get_default_buffer_type, + /* .set_tensor_async = */ NULL, + /* .get_tensor_async = */ NULL, + /* .cpy_tensor_async = */ NULL, + /* .synchronize = */ NULL, + /* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create, + /* .graph_plan_free = */ ggml_backend_cpu_graph_plan_free, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ ggml_backend_cpu_graph_plan_compute, + /* .graph_compute = */ ggml_backend_cpu_graph_compute, + /* .supports_op = */ NULL, + /* .supports_buft = */ NULL, + /* .offload_op = */ NULL, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, +}; + +static ggml_guid_t ggml_backend_cpu_guid(void) { + static ggml_guid guid = { 0xaa, 0x67, 0xc7, 0x43, 0x96, 0xe6, 0xa3, 0x8a, 0xe3, 0xaf, 0xea, 0x92, 0x36, 0xbc, 0xfc, 0x89 }; + return &guid; +} + +ggml_backend_t ggml_backend_cpu_init(void) { + struct ggml_backend_cpu_context * ctx = new ggml_backend_cpu_context; + if (ctx == NULL) { + return NULL; + } + + ctx->n_threads = GGML_DEFAULT_N_THREADS; + ctx->threadpool = NULL; + ctx->work_data = NULL; + ctx->work_size = 0; + ctx->abort_callback = NULL; + ctx->abort_callback_data = NULL; + ctx->runtime_repack = false; + ctx->scratch_memory = NULL; + ctx->scratch_size = 0; + + ggml_backend_t cpu_backend = new ggml_backend { + /* .guid = */ ggml_backend_cpu_guid(), + /* .interface = */ ggml_backend_cpu_i, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), + /* .context = */ ctx, + }; + + if (cpu_backend == NULL) { + delete ctx; + return NULL; + } + + return 
cpu_backend; +} + +bool ggml_backend_is_cpu(ggml_backend_t backend) { + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cpu_guid()); +} + +void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) { + GGML_ASSERT(ggml_backend_is_cpu(backend_cpu)); + + struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context; + ctx->n_threads = n_threads; +} + +void ggml_backend_cpu_set_threadpool(ggml_backend_t backend_cpu, ggml_threadpool_t threadpool) { + GGML_ASSERT(ggml_backend_is_cpu(backend_cpu)); + + struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context; + + if (ctx->threadpool && ctx->threadpool != threadpool) { + // already had a different threadpool, pause/suspend it before switching + ggml_threadpool_pause(ctx->threadpool); + } + ctx->threadpool = threadpool; +} + +void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data) { + GGML_ASSERT(ggml_backend_is_cpu(backend_cpu)); + + struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context; + ctx->abort_callback = abort_callback; + ctx->abort_callback_data = abort_callback_data; +} + +void ggml_backend_cpu_set_runtime_repack(ggml_backend_t backend_cpu, bool runtime_repack) { + GGML_ASSERT(ggml_backend_is_cpu(backend_cpu)); + + struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context; + ctx->runtime_repack = runtime_repack; +} + +ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) { + GGML_ASSERT((uintptr_t)ptr % TENSOR_ALIGNMENT == 0 && "buffer pointer must be aligned"); + return ggml_backend_buffer_init(ggml_backend_cpu_buffer_type(), ggml_backend_cpu_buffer_from_ptr_i, ptr, size); +} + +//////////////////////// + +struct ggml_backend_cpu_device_context { + std::string description = "CPU"; + + ggml_backend_cpu_device_context() { +#ifdef __APPLE__ + size_t len = 0; + if (!sysctlbyname("machdep.cpu.brand_string", NULL, &len, NULL, 0)) { + description.resize(len); + sysctlbyname("machdep.cpu.brand_string", &description[0], &len, NULL, 0); // NOLINT + } +#elif defined(__linux__) + FILE * f = fopen("/proc/cpuinfo", "r"); + if (f) { + char buf[1024]; + while (fgets(buf, sizeof(buf), f)) { + if (strncmp(buf, "model name", 10) == 0) { + char * p = strchr(buf, ':'); + if (p) { + p++; + while (std::isspace(*p)) { + p++; + } + while (std::isspace(p[strlen(p) - 1])) { + p[strlen(p) - 1] = '\0'; + } + description = p; + break; + } + } + } + fclose(f); + } +#elif defined(_WIN32) + HKEY hKey; + if (RegOpenKeyEx(HKEY_LOCAL_MACHINE, + TEXT("HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0"), + 0, + KEY_READ, + &hKey) == ERROR_SUCCESS) { + DWORD cpu_brand_size = 0; + if (RegQueryValueExA(hKey, + TEXT("ProcessorNameString"), + NULL, + NULL, + NULL, + &cpu_brand_size) == ERROR_SUCCESS) { + description.resize(cpu_brand_size); + if (RegQueryValueExA(hKey, + TEXT("ProcessorNameString"), + NULL, + NULL, + (LPBYTE)&description[0], // NOLINT + &cpu_brand_size) == ERROR_SUCCESS) { + if (description.find('\0') != std::string::npos) { + description.resize(description.find('\0')); + } + } + } + RegCloseKey(hKey); + } +#endif + } +}; + +static const char * ggml_backend_cpu_device_get_name(ggml_backend_dev_t dev) { + return "CPU"; + + GGML_UNUSED(dev); +} + +static const char * ggml_backend_cpu_device_get_description(ggml_backend_dev_t dev) { + struct 
ggml_backend_cpu_device_context * ctx = (struct ggml_backend_cpu_device_context *)dev->context; + + return ctx->description.c_str(); +} + +static void ggml_backend_cpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + // TODO + *free = 0; + *total = 0; + + GGML_UNUSED(dev); +} + +static enum ggml_backend_dev_type ggml_backend_cpu_device_get_type(ggml_backend_dev_t dev) { + return GGML_BACKEND_DEVICE_TYPE_CPU_FULL; + + GGML_UNUSED(dev); +} + +static void ggml_backend_cpu_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + props->name = ggml_backend_cpu_device_get_name(dev); + props->description = ggml_backend_cpu_device_get_description(dev); + props->type = ggml_backend_cpu_device_get_type(dev); + ggml_backend_cpu_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = { + /* .async = */ false, + /* .host_buffer = */ false, + /* .buffer_from_host_ptr = */ true, + /* .events = */ false, + }; +} + +static ggml_backend_t ggml_backend_cpu_device_init(ggml_backend_dev_t dev, const char * params) { + return ggml_backend_cpu_init(); + + GGML_UNUSED(dev); + GGML_UNUSED(params); +} + +static ggml_backend_buffer_type_t ggml_backend_cpu_device_get_buffer_type(ggml_backend_dev_t dev) { + return ggml_backend_cpu_buffer_type(); + + GGML_UNUSED(dev); +} + +static ggml_backend_buffer_t ggml_backend_cpu_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { + return ggml_backend_cpu_buffer_from_ptr(ptr, size); + + GGML_UNUSED(dev); + GGML_UNUSED(max_tensor_size); +} + +static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { + switch (op->op) { + case GGML_OP_CPY: + return + op->type != GGML_TYPE_IQ2_XXS && + op->type != GGML_TYPE_IQ2_XS && + op->type != GGML_TYPE_IQ1_S && + op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float + case GGML_OP_MUL_MAT: + return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_get_type_traits(op->src[0]->type)->vec_dot_type; + case GGML_OP_ROPE_BACK: + return op->src[2] == NULL && (op->op_params[2] & 4) == 0; + case GGML_OP_IM2COL_BACK: + return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; + case GGML_OP_OUT_PROD: + return (op->src[0]->type == GGML_TYPE_F32 || ggml_is_quantized(op->src[0]->type)) && op->src[1]->type == GGML_TYPE_F32; + default: + return true; + } + + GGML_UNUSED(dev); +} + +static bool ggml_backend_cpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + return ggml_backend_buft_is_host(buft); + + GGML_UNUSED(dev); +} + +static const struct ggml_backend_device_i ggml_backend_cpu_device_i = { + /* .get_name = */ ggml_backend_cpu_device_get_name, + /* .get_description = */ ggml_backend_cpu_device_get_description, + /* .get_memory = */ ggml_backend_cpu_device_get_memory, + /* .get_type = */ ggml_backend_cpu_device_get_type, + /* .get_props = */ ggml_backend_cpu_device_get_props, + /* .init_backend = */ ggml_backend_cpu_device_init, + /* .get_buffer_type = */ ggml_backend_cpu_device_get_buffer_type, + /* .get_host_buffer_type = */ NULL, + /* .buffer_from_host_ptr = */ ggml_backend_cpu_device_buffer_from_ptr, + /* .supports_op = */ ggml_backend_cpu_device_supports_op, + /* .supports_buft = */ ggml_backend_cpu_device_supports_buft, + /* .offload_op = */ NULL, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; + +//////////////////////// + +static const char * 
ggml_backend_cpu_reg_get_name(ggml_backend_reg_t reg) { + return "CPU"; + + GGML_UNUSED(reg); +} + +static size_t ggml_backend_cpu_reg_get_device_count(ggml_backend_reg_t reg) { + return 1; + + GGML_UNUSED(reg); +} + +static ggml_backend_dev_t ggml_backend_cpu_reg_get_device(ggml_backend_reg_t reg, size_t index) { + GGML_ASSERT(index == 0); + + static ggml_backend_cpu_device_context ctx; + static ggml_backend_device ggml_backend_cpu_device = { + /* .iface = */ ggml_backend_cpu_device_i, + /* .reg = */ reg, + /* .context = */ &ctx, + }; + + return &ggml_backend_cpu_device; +} + +static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const char * name) { + if (strcmp(name, "ggml_backend_set_n_threads") == 0) { + return (void *)ggml_backend_cpu_set_n_threads; + } + return NULL; + + GGML_UNUSED(reg); +} + +static const struct ggml_backend_reg_i ggml_backend_cpu_reg_i = { + /* .get_name = */ ggml_backend_cpu_reg_get_name, + /* .get_device_count = */ ggml_backend_cpu_reg_get_device_count, + /* .get_device = */ ggml_backend_cpu_reg_get_device, + /* .get_proc_address = */ ggml_backend_cpu_get_proc_address, +}; + +ggml_backend_reg_t ggml_backend_cpu_reg(void) { + static struct ggml_backend_reg ggml_backend_cpu_reg = { + /* .iface = */ ggml_backend_cpu_reg_i, + /* .context = */ NULL, + }; + + return &ggml_backend_cpu_reg; +} + // multi-buffer buffer struct ggml_backend_multi_buffer_context { diff --git a/include/llama.h b/include/llama.h index ccb48f73c..6e5193cdf 100644 --- a/include/llama.h +++ b/include/llama.h @@ -334,11 +334,12 @@ extern "C" { // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value. // TODO: move at the end of the struct - bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) - bool embeddings; // if true, extract embeddings (together with logits) - bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU - bool flash_attn; // whether to use flash attention [EXPERIMENTAL] - bool no_perf; // whether to measure performance timings + bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) + bool embeddings; // if true, extract embeddings (together with logits) + bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU + bool flash_attn; // whether to use flash attention [EXPERIMENTAL] + bool no_perf; // whether to measure performance timings + bool runtime_repack; // runtime repack weight for optimized kernels // Abort callback // if it returns true, execution of llama_decode() will be aborted diff --git a/src/llama.cpp b/src/llama.cpp index 034441e1f..3c618e909 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2575,6 +2575,7 @@ struct llama_cparams { bool offload_kqv; bool flash_attn; bool no_perf; + bool runtime_repack; enum llama_pooling_type pooling_type; @@ -17185,6 +17186,7 @@ static void llama_graph_compute( ggml_threadpool * threadpool) { if (lctx.backend_cpu != nullptr) { ggml_backend_cpu_set_threadpool(lctx.backend_cpu, threadpool); + ggml_backend_cpu_set_runtime_repack(lctx.backend_cpu, lctx.cparams.runtime_repack); ggml_backend_cpu_set_abort_callback(lctx.backend_cpu, lctx.abort_callback, lctx.abort_callback_data); } @@ -19120,6 +19122,7 @@ struct llama_context_params llama_context_default_params() { /*.offload_kqv =*/ true, /*.flash_attn =*/ false, /*.no_perf =*/ true, + /*.runtime_repack =*/ false, 
/*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, }; @@ -19383,6 +19386,7 @@ struct llama_context * llama_new_context_with_model( cparams.flash_attn = params.flash_attn; cparams.no_perf = params.no_perf; cparams.pooling_type = params.pooling_type; + cparams.runtime_repack = params.runtime_repack; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
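
Usage note (not part of the patch): a minimal sketch of how an application would opt in to the new runtime_repack context flag through the public llama.h API changed above. The model path is hypothetical and error handling is trimmed to early returns; mmap is disabled explicitly, mirroring common_model_params_to_llama(), because the repack rewrites the Q4_0 tensor data in place.

#include "llama.h"

#include <cstdio>

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    mparams.use_mmap = false; // repacking modifies weights in place, so do not rely on a read-only mapping

    llama_model * model = llama_load_model_from_file("model-q4_0.gguf", mparams); // hypothetical path
    if (model == NULL) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.runtime_repack = true; // new field introduced by this patch

    llama_context * ctx = llama_new_context_with_model(model, cparams);
    if (ctx == NULL) {
        llama_free_model(model);
        return 1;
    }

    // ... call llama_decode() as usual; on AArch64 the CPU backend repacks Q4_0
    // weights to Q4_0_4_4 / Q4_0_4_8 / Q4_0_8_8 before computing the graph ...

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}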
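
A second sketch (also not part of the patch) of the CPU backend entry points that the patch exposes in ggml-backend.h. The tensors are plain F32, so nothing is actually repacked here (repacking only applies to GGML_OP_MUL_MAT nodes whose src[0] is Q4_0); the point is only the call sequence around ggml_backend_cpu_set_runtime_repack().

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

int main() {
    // create the CPU backend and opt in to runtime repacking
    ggml_backend_t backend = ggml_backend_cpu_init();
    ggml_backend_cpu_set_n_threads(backend, 4);
    ggml_backend_cpu_set_runtime_repack(backend, true);

    // build a tiny graph; no_alloc so tensor data lives in the backend buffer
    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead() * 8 + ggml_graph_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 4);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 2);
    struct ggml_tensor * c = ggml_mul_mat(ctx, a, b);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, c);

    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

    // ... fill a and b with ggml_backend_tensor_set() ...

    ggml_backend_graph_compute(backend, gf);

    ggml_backend_buffer_free(buf);
    ggml_free(ctx);
    ggml_backend_free(backend);
    return 0;
}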