add ggml_backend_unload

This commit is contained in:
slaren 2024-11-24 00:59:39 +01:00
parent d5a3beb0e0
commit ccd8df8a9d
2 changed files with 62 additions and 27 deletions

View file

@ -222,8 +222,10 @@ extern "C" {
// = ggml_backend_dev_init(ggml_backend_dev_by_type(GPU) OR ggml_backend_dev_by_type(CPU), NULL) // = ggml_backend_dev_init(ggml_backend_dev_by_type(GPU) OR ggml_backend_dev_by_type(CPU), NULL)
GGML_API ggml_backend_t ggml_backend_init_best(void); GGML_API ggml_backend_t ggml_backend_init_best(void);
// Load a backend from a dynamic library // Load a backend from a dynamic library and register it
GGML_API ggml_backend_reg_t ggml_backend_load(const char * path); GGML_API ggml_backend_reg_t ggml_backend_load(const char * path);
// Unload a backend if loaded dynamically and unregister it
GGML_API void ggml_backend_unload(ggml_backend_reg_t reg);
// Load all known backends from dynamic libraries // Load all known backends from dynamic libraries
GGML_API void ggml_backend_load_all(void); GGML_API void ggml_backend_load_all(void);

View file

@ -1,6 +1,7 @@
#include "ggml-backend-impl.h" #include "ggml-backend-impl.h"
#include "ggml-backend.h" #include "ggml-backend.h"
#include "ggml-impl.h" #include "ggml-impl.h"
#include <algorithm>
#include <cstring> #include <cstring>
#include <vector> #include <vector>
@ -45,8 +46,13 @@
#include "ggml-kompute.h" #include "ggml-kompute.h"
#endif #endif
struct ggml_backend_reg_entry {
ggml_backend_reg_t reg;
void * handle;
};
struct ggml_backend_registry { struct ggml_backend_registry {
std::vector<ggml_backend_reg_t> backends; std::vector<ggml_backend_reg_entry> backends;
std::vector<ggml_backend_dev_t> devices; std::vector<ggml_backend_dev_t> devices;
ggml_backend_registry() { ggml_backend_registry() {
@ -82,7 +88,13 @@ struct ggml_backend_registry {
#endif #endif
} }
void register_backend(ggml_backend_reg_t reg) { ~ggml_backend_registry() {
while (!backends.empty()) {
ggml_backend_unload(backends.back().reg);
}
}
void register_backend(ggml_backend_reg_t reg, void * handle = nullptr) {
if (!reg) { if (!reg) {
return; return;
} }
@ -91,7 +103,7 @@ struct ggml_backend_registry {
GGML_LOG_DEBUG("%s: registered backend %s (%zu devices)\n", GGML_LOG_DEBUG("%s: registered backend %s (%zu devices)\n",
__func__, ggml_backend_reg_name(reg), ggml_backend_reg_dev_count(reg)); __func__, ggml_backend_reg_name(reg), ggml_backend_reg_dev_count(reg));
#endif #endif
backends.push_back(reg); backends.push_back({ reg, handle });
for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) { for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
register_device(ggml_backend_reg_dev_get(reg, i)); register_device(ggml_backend_reg_dev_get(reg, i));
} }
@ -126,7 +138,7 @@ size_t ggml_backend_reg_count() {
ggml_backend_reg_t ggml_backend_reg_get(size_t index) { ggml_backend_reg_t ggml_backend_reg_get(size_t index) {
GGML_ASSERT(index < ggml_backend_reg_count()); GGML_ASSERT(index < ggml_backend_reg_count());
return get_reg().backends[index]; return get_reg().backends[index].reg;
} }
ggml_backend_reg_t ggml_backend_reg_by_name(const char * name) { ggml_backend_reg_t ggml_backend_reg_by_name(const char * name) {
@ -136,7 +148,7 @@ ggml_backend_reg_t ggml_backend_reg_by_name(const char * name) {
return reg; return reg;
} }
} }
return NULL; return nullptr;
} }
// Device enumeration // Device enumeration
@ -156,7 +168,7 @@ ggml_backend_dev_t ggml_backend_dev_by_name(const char * name) {
return dev; return dev;
} }
} }
return NULL; return nullptr;
} }
ggml_backend_dev_t ggml_backend_dev_by_type(enum ggml_backend_dev_type type) { ggml_backend_dev_t ggml_backend_dev_by_type(enum ggml_backend_dev_type type) {
@ -166,14 +178,14 @@ ggml_backend_dev_t ggml_backend_dev_by_type(enum ggml_backend_dev_type type) {
return dev; return dev;
} }
} }
return NULL; return nullptr;
} }
// Convenience functions // Convenience functions
ggml_backend_t ggml_backend_init_by_name(const char * name, const char * params) { ggml_backend_t ggml_backend_init_by_name(const char * name, const char * params) {
ggml_backend_dev_t dev = ggml_backend_dev_by_name(name); ggml_backend_dev_t dev = ggml_backend_dev_by_name(name);
if (!dev) { if (!dev) {
return NULL; return nullptr;
} }
return ggml_backend_dev_init(dev, params); return ggml_backend_dev_init(dev, params);
} }
@ -181,7 +193,7 @@ ggml_backend_t ggml_backend_init_by_name(const char * name, const char * params)
ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const char * params) { ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const char * params) {
ggml_backend_dev_t dev = ggml_backend_dev_by_type(type); ggml_backend_dev_t dev = ggml_backend_dev_by_type(type);
if (!dev) { if (!dev) {
return NULL; return nullptr;
} }
return ggml_backend_dev_init(dev, params); return ggml_backend_dev_init(dev, params);
} }
@ -192,9 +204,9 @@ ggml_backend_t ggml_backend_init_best(void) {
dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
} }
if (!dev) { if (!dev) {
return NULL; return nullptr;
} }
return ggml_backend_dev_init(dev, NULL); return ggml_backend_dev_init(dev, nullptr);
} }
#ifdef _WIN32 #ifdef _WIN32
@ -214,45 +226,66 @@ ggml_backend_reg_t ggml_backend_load(const char * path) {
HMODULE handle = LoadLibraryA(path); HMODULE handle = LoadLibraryA(path);
if (!handle) { if (!handle) {
GGML_LOG_ERROR("%s: failed to load %s: %lu\n", __func__, path, GetLastError()); GGML_LOG_ERROR("%s: failed to load %s: %lu\n", __func__, path, GetLastError());
return NULL; return nullptr;
} }
ggml_backend_init_t backend_init = (ggml_backend_init_t) GetProcAddress(handle, "ggml_backend_init"); ggml_backend_init_t backend_init = (ggml_backend_init_t) GetProcAddress(handle, "ggml_backend_init");
if (!backend_init) { if (!backend_init) {
GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s: %lu\n", __func__, path, GetLastError()); GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s: %lu\n", __func__, path, GetLastError());
FreeLibrary(handle); FreeLibrary(handle);
return NULL; return nullptr;
} }
ggml_backend_reg_t reg = backend_init();
if (!reg) {
GGML_LOG_ERROR("%s: failed to initialize backend from %s\n", __func__, path);
FreeLibrary(handle);
return NULL;
}
GGML_LOG_DEBUG("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path);
ggml_backend_register(reg);
return reg;
#else #else
void * handle = dlopen(path, RTLD_NOW | RTLD_LOCAL); void * handle = dlopen(path, RTLD_NOW | RTLD_LOCAL);
if (!handle) { if (!handle) {
GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path, dlerror()); GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path, dlerror());
return NULL; return nullptr;
} }
auto * backend_init = (ggml_backend_init_t) dlsym(handle, "ggml_backend_init"); auto * backend_init = (ggml_backend_init_t) dlsym(handle, "ggml_backend_init");
if (!backend_init) { if (!backend_init) {
GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s: %s\n", __func__, path, dlerror()); GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s: %s\n", __func__, path, dlerror());
dlclose(handle); dlclose(handle);
return NULL; return nullptr;
} }
#endif
ggml_backend_reg_t reg = backend_init(); ggml_backend_reg_t reg = backend_init();
if (!reg) { if (!reg) {
GGML_LOG_ERROR("%s: failed to initialize backend from %s\n", __func__, path); GGML_LOG_ERROR("%s: failed to initialize backend from %s\n", __func__, path);
dlclose(handle); dlclose(handle);
return NULL; return nullptr;
} }
GGML_LOG_DEBUG("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path); GGML_LOG_DEBUG("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path);
ggml_backend_register(reg); get_reg().register_backend(reg, handle);
return reg; return reg;
}
void ggml_backend_unload(ggml_backend_reg_t reg) {
auto it = std::find_if(get_reg().backends.begin(), get_reg().backends.end(),
[reg](ggml_backend_reg_entry entry) { return entry.reg == reg; });
if (it == get_reg().backends.end()) {
GGML_LOG_ERROR("%s: backend not found\n", __func__);
return;
}
GGML_LOG_DEBUG("%s: unloading %s backend\n", __func__, ggml_backend_reg_name(reg));
// remove devices
get_reg().devices.erase(
std::remove_if(get_reg().devices.begin(), get_reg().devices.end(),
[reg](ggml_backend_dev_t dev) { return ggml_backend_dev_backend_reg(dev) == reg; }),
get_reg().devices.end());
// unload library
if (it->handle) {
#ifdef _WIN32
FreeLibrary((HMODULE) it->handle);
#else
dlclose(it->handle);
#endif #endif
}
// remove backend
get_reg().backends.erase(it);
} }
void ggml_backend_load_all() { void ggml_backend_load_all() {