ggml : add backend registry / device interfaces to BLAS backend (#9752)
* ggml : add backend registry / device interfaces to BLAS backend
* fix mmap usage when using host buffers
commit 6374743747
parent f1af42fa8c
8 changed files with 293 additions and 99 deletions
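For context, a minimal sketch (not part of this commit) of what the new registry/device path enables for the BLAS backend: once ggml-blas registers itself with the ggml backend registry, llama.cpp can discover and initialize it through the generic device API instead of a compile-time #ifdef. The calls used are the public ggml-backend API; matching the device by the name "BLAS" is an assumption for illustration.

// Sketch: enumerate registered devices and initialize a backend through the
// device interface (public ggml-backend API; the "BLAS" device name is assumed).
#include "ggml-backend.h"
#include <cstring>

static ggml_backend_t init_blas_backend_via_registry(void) {
    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        if (std::strcmp(ggml_backend_dev_name(dev), "BLAS") == 0) {
            return ggml_backend_dev_init(dev, /*params =*/ nullptr);
        }
    }
    return nullptr; // BLAS backend not built in or not registered
}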
@@ -22,10 +22,6 @@
 # include "ggml-cann.h"
 #endif
 
-#ifdef GGML_USE_BLAS
-# include "ggml-blas.h"
-#endif
-
 // TODO: replace with ggml API call
 #define QK_K 256
 
@@ -3288,9 +3284,8 @@ struct llama_context {
     std::unordered_map<struct llama_lora_adapter *, float> lora_adapters;
 
     std::vector<ggml_backend_t> backends;
-#ifdef GGML_USE_BLAS
-    ggml_backend_t backend_blas = nullptr;
-#endif
+    std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
+
     ggml_backend_t backend_cpu = nullptr;
 
     ggml_threadpool_t threadpool = nullptr;
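The new set_n_threads_fns member pairs each backend with an optional thread-count setter resolved at context creation (see the lookup loop added to llama_new_context_with_model further down). A sketch of the function-pointer type involved, assumed to match the typedef exposed by ggml-backend.h:

// Assumed declaration of the setter type cached in set_n_threads_fns:
typedef void (*ggml_backend_set_n_threads_t)(ggml_backend_t backend, int n_threads);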
@@ -8908,7 +8903,8 @@ static bool llm_load_tensors(
             bufs.reserve(n_max_backend_buffer);
 
             // check if this backend device supports buffer_from_host_ptr
-            ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft);
+            // when using a host buffer as the CPU backend buffer, use the CPU device to prioritize using buffer_from_host_ptr over the host buffer
+            ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft == llama_default_buffer_type_cpu(model, true) ? ggml_backend_cpu_buffer_type() : buft);
             bool buffer_from_host_ptr_supported = false;
             if (dev) {
                 ggml_backend_dev_props props;
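The rewritten device lookup above feeds the buffer_from_host_ptr capability check that follows it. A hedged sketch of how that capability can be used to map an mmap'd model file directly as a backend buffer; field and function names are assumed from ggml-backend.h:

// Sketch (not part of the diff): use buffer_from_host_ptr when the device supports it.
static ggml_backend_buffer_t try_buffer_from_host_ptr(ggml_backend_dev_t dev, void * addr, size_t size, size_t max_tensor_size) {
    ggml_backend_dev_props props;
    ggml_backend_dev_get_props(dev, &props);
    if (!props.caps.buffer_from_host_ptr) {
        return nullptr; // caller falls back to allocating a buffer and copying
    }
    // wraps the existing host allocation (e.g. the mmap'd file) without a copy
    return ggml_backend_dev_buffer_from_host_ptr(dev, addr, size, max_tensor_size);
}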
@@ -17048,17 +17044,19 @@ static void llama_graph_compute(
                     int   n_threads,
         ggml_threadpool * threadpool) {
     if (lctx.backend_cpu != nullptr) {
-        ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads);
         ggml_backend_cpu_set_threadpool(lctx.backend_cpu, threadpool);
         ggml_backend_cpu_set_abort_callback(lctx.backend_cpu, lctx.abort_callback, lctx.abort_callback_data);
     }
-#ifdef GGML_USE_BLAS
-    if (lctx.backend_blas != nullptr) {
-        ggml_backend_blas_set_n_threads(lctx.backend_blas, n_threads);
-    }
-#endif
 
-    ggml_backend_sched_graph_compute_async(lctx.sched, gf);
+    // set the number of threads for all the backends
+    for (const auto & set_n_threads_fn : lctx.set_n_threads_fns) {
+        set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
+    }
+
+    auto err = ggml_backend_sched_graph_compute_async(lctx.sched, gf);
+    if (err != GGML_STATUS_SUCCESS) {
+        LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, err);
+    }
 
     // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
 }
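The setters applied in the loop above are resolved by name through each backend registry's get_proc_address hook (the lookup itself appears in llama_new_context_with_model below). A rough sketch of the backend side; the hook signature is assumed from ggml's internal registry interface, and ggml_backend_blas_set_n_threads is the setter declared in ggml-blas.h:

// Sketch: a backend registration exposing an optional, name-based entry point
// that callers resolve with ggml_backend_reg_get_proc_address().
static void * example_blas_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
    (void) reg; // unused in this sketch
    if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
        return (void *) ggml_backend_blas_set_n_threads;
    }
    return nullptr; // entry point not provided by this backend
}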
@@ -19110,9 +19108,16 @@ struct llama_model * llama_load_model_from_file(
     // TODO: rework API to give user more control over device selection
     for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
         ggml_backend_dev_t dev = ggml_backend_dev_get(i);
-        // skip the CPU backend since it is handled separately
-        if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_CPU_FULL) {
-            model->devices.push_back(dev);
+        switch (ggml_backend_dev_type(dev)) {
+            case GGML_BACKEND_DEVICE_TYPE_CPU:
+            case GGML_BACKEND_DEVICE_TYPE_CPU_FULL:
+                // skip CPU backends since they are handled separately
+                break;
+
+            case GGML_BACKEND_DEVICE_TYPE_GPU:
+            case GGML_BACKEND_DEVICE_TYPE_GPU_FULL:
+                model->devices.push_back(dev);
+                break;
         }
     }
 
@@ -19407,14 +19412,19 @@ struct llama_context * llama_new_context_with_model(
         }
 #endif
 
-#ifdef GGML_USE_BLAS
-        ctx->backend_blas = ggml_backend_blas_init();
-        if (ctx->backend_blas == nullptr) {
-            LLAMA_LOG_WARN("%s: failed to initialize BLAS backend\n", __func__);
-        } else {
-            ctx->backends.push_back(ctx->backend_blas);
-        }
-#endif
+        // add other backends (such as BLAS)
+        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+            if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU) {
+                ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
+                if (backend == nullptr) {
+                    LLAMA_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
+                    llama_free(ctx);
+                    return nullptr;
+                }
+                ctx->backends.push_back(backend);
+            }
+        }
 
         ctx->backend_cpu = ggml_backend_cpu_init();
         if (ctx->backend_cpu == nullptr) {
@@ -19424,6 +19434,18 @@ struct llama_context * llama_new_context_with_model(
         }
         ctx->backends.push_back(ctx->backend_cpu);
 
+        // create a list of the set_n_threads functions in the backends
+        for (auto * backend : ctx->backends) {
+            ggml_backend_dev_t dev = ggml_backend_get_device(backend);
+            ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
+            if (reg) {
+                auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+                if (ggml_backend_set_n_threads_fn) {
+                    ctx->set_n_threads_fns.emplace_back(backend, ggml_backend_set_n_threads_fn);
+                }
+            }
+        }
+
         if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) {
             LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
             llama_free(ctx);