diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index d9aca71ae..19881a505 100644
--- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h
@@ -222,8 +222,10 @@ extern "C" {
     // = ggml_backend_dev_init(ggml_backend_dev_by_type(GPU) OR ggml_backend_dev_by_type(CPU), NULL)
     GGML_API ggml_backend_t ggml_backend_init_best(void);
 
-    // Load a backend from a dynamic library
+    // Load a backend from a dynamic library and register it
     GGML_API ggml_backend_reg_t ggml_backend_load(const char * path);
+    // Unload a backend if loaded dynamically and unregister it
+    GGML_API void ggml_backend_unload(ggml_backend_reg_t reg);
     // Load all known backends from dynamic libraries
     GGML_API void ggml_backend_load_all(void);
 
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index 6b68c956c..78096af18 100644
--- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp
@@ -1,6 +1,7 @@
 #include "ggml-backend-impl.h"
 #include "ggml-backend.h"
 #include "ggml-impl.h"
+#include <algorithm>
 #include <cstring>
 #include <vector>
 
@@ -45,8 +46,13 @@
 #include "ggml-kompute.h"
 #endif
 
+struct ggml_backend_reg_entry {
+    ggml_backend_reg_t reg;    // backend registration
+    void * handle;             // dynamic library handle, nullptr when not registered via ggml_backend_load
+};
+
 struct ggml_backend_registry {
-    std::vector<ggml_backend_reg_t> backends;
+    std::vector<ggml_backend_reg_entry> backends;
     std::vector<ggml_backend_dev_t> devices;
 
     ggml_backend_registry() {
@@ -82,7 +88,13 @@ struct ggml_backend_registry {
 #endif
     }
 
-    void register_backend(ggml_backend_reg_t reg) {
+    ~ggml_backend_registry() {
+        while (!backends.empty()) { // each ggml_backend_unload call erases the entry it is given
+            ggml_backend_unload(backends.back().reg);
+        }
+    }
+
+    void register_backend(ggml_backend_reg_t reg, void * handle = nullptr) {
         if (!reg) {
             return;
         }
@@ -91,7 +103,7 @@ struct ggml_backend_registry {
         GGML_LOG_DEBUG("%s: registered backend %s (%zu devices)\n", __func__,
             ggml_backend_reg_name(reg), ggml_backend_reg_dev_count(reg));
 #endif
-        backends.push_back(reg);
+        backends.push_back({ reg, handle });
         for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
             register_device(ggml_backend_reg_dev_get(reg, i));
         }
@@ -126,7 +138,7 @@ size_t ggml_backend_reg_count() {
 
 ggml_backend_reg_t ggml_backend_reg_get(size_t index) {
     GGML_ASSERT(index < ggml_backend_reg_count());
-    return get_reg().backends[index];
+    return get_reg().backends[index].reg;
 }
 
 ggml_backend_reg_t ggml_backend_reg_by_name(const char * name) {
@@ -136,7 +148,7 @@ ggml_backend_reg_t ggml_backend_reg_by_name(const char * name) {
             return reg;
         }
     }
-    return NULL;
+    return nullptr;
 }
 
 // Device enumeration
@@ -156,7 +168,7 @@ ggml_backend_dev_t ggml_backend_dev_by_name(const char * name) {
             return dev;
         }
     }
-    return NULL;
+    return nullptr;
 }
 
 ggml_backend_dev_t ggml_backend_dev_by_type(enum ggml_backend_dev_type type) {
@@ -166,14 +178,14 @@ ggml_backend_dev_t ggml_backend_dev_by_type(enum ggml_backend_dev_type type) {
             return dev;
         }
     }
-    return NULL;
+    return nullptr;
 }
 
 // Convenience functions
 ggml_backend_t ggml_backend_init_by_name(const char * name, const char * params) {
     ggml_backend_dev_t dev = ggml_backend_dev_by_name(name);
     if (!dev) {
-        return NULL;
+        return nullptr;
     }
     return ggml_backend_dev_init(dev, params);
 }
@@ -181,7 +193,7 @@ ggml_backend_t ggml_backend_init_by_name(const char * name, const char * params)
 ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const char * params) {
     ggml_backend_dev_t dev = ggml_backend_dev_by_type(type);
     if (!dev) {
-        return NULL;
+        return nullptr;
     }
     return ggml_backend_dev_init(dev, params);
 }
@@ -192,9 +204,9 @@ ggml_backend_t ggml_backend_init_best(void) {
         dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
     }
     if (!dev) {
-        return NULL;
+        return nullptr;
     }
-    return ggml_backend_dev_init(dev, NULL);
+    return ggml_backend_dev_init(dev, nullptr);
 }
 
 #ifdef _WIN32
@@ -214,45 +226,71 @@ ggml_backend_reg_t ggml_backend_load(const char * path) {
     HMODULE handle = LoadLibraryA(path);
     if (!handle) {
         GGML_LOG_ERROR("%s: failed to load %s: %lu\n", __func__, path, GetLastError());
-        return NULL;
+        return nullptr;
     }
     ggml_backend_init_t backend_init = (ggml_backend_init_t) GetProcAddress(handle, "ggml_backend_init");
     if (!backend_init) {
         GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s: %lu\n", __func__, path, GetLastError());
         FreeLibrary(handle);
-        return NULL;
+        return nullptr;
     }
-    ggml_backend_reg_t reg = backend_init();
-    if (!reg) {
-        GGML_LOG_ERROR("%s: failed to initialize backend from %s\n", __func__, path);
-        FreeLibrary(handle);
-        return NULL;
-    }
-    GGML_LOG_DEBUG("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path);
-    ggml_backend_register(reg);
-    return reg;
 #else
     void * handle = dlopen(path, RTLD_NOW | RTLD_LOCAL);
     if (!handle) {
         GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path, dlerror());
-        return NULL;
+        return nullptr;
     }
     auto * backend_init = (ggml_backend_init_t) dlsym(handle, "ggml_backend_init");
     if (!backend_init) {
         GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s: %s\n", __func__, path, dlerror());
         dlclose(handle);
-        return NULL;
+        return nullptr;
     }
+#endif
     ggml_backend_reg_t reg = backend_init();
     if (!reg) {
         GGML_LOG_ERROR("%s: failed to initialize backend from %s\n", __func__, path);
+        // this tail is shared by both platforms: release the library with the call matching how it was opened
+#ifdef _WIN32
+        FreeLibrary(handle);
+#else
         dlclose(handle);
-        return NULL;
+#endif
+        return nullptr;
     }
     GGML_LOG_DEBUG("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path);
-    ggml_backend_register(reg);
+    get_reg().register_backend(reg, handle);
     return reg;
+}
+
+void ggml_backend_unload(ggml_backend_reg_t reg) {
+    auto it = std::find_if(get_reg().backends.begin(), get_reg().backends.end(),
+                           [reg](ggml_backend_reg_entry entry) { return entry.reg == reg; });
+
+    if (it == get_reg().backends.end()) {
+        GGML_LOG_ERROR("%s: backend not found\n", __func__);
+        return;
+    }
+
+    GGML_LOG_DEBUG("%s: unloading %s backend\n", __func__, ggml_backend_reg_name(reg));
+
+    // remove devices
+    get_reg().devices.erase(
+        std::remove_if(get_reg().devices.begin(), get_reg().devices.end(),
+                       [reg](ggml_backend_dev_t dev) { return ggml_backend_dev_backend_reg(dev) == reg; }),
+        get_reg().devices.end());
+
+    // unload library
+    if (it->handle) {
+#ifdef _WIN32
+        FreeLibrary((HMODULE) it->handle);
+#else
+        dlclose(it->handle);
 #endif
+    }
+
+    // remove backend
+    get_reg().backends.erase(it);
 }
 
 void ggml_backend_load_all() {