rebuild buft list on every call

This commit is contained in:
slaren 2025-02-09 00:32:52 +01:00
parent 538f60934a
commit 8770ffa60c

View file

@ -1485,13 +1485,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
add_opt(common_arg( add_opt(common_arg(
{"--override-tensor", "-ot"}, "<tensor name pattern>=<buffer type>,...", {"--override-tensor", "-ot"}, "<tensor name pattern>=<buffer type>,...",
"override tensor buffer type", [](common_params & params, const std::string & value) { "override tensor buffer type", [](common_params & params, const std::string & value) {
static std::map<std::string, ggml_backend_buffer_type_t> buft_list; /* static */ std::map<std::string, ggml_backend_buffer_type_t> buft_list;
if (buft_list.empty()) { if (buft_list.empty()) {
// enumerate all the devices and add their buffer types to the list // enumerate all the devices and add their buffer types to the list
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
auto * dev = ggml_backend_dev_get(i); auto * dev = ggml_backend_dev_get(i);
auto * buft = ggml_backend_dev_buffer_type(dev); auto * buft = ggml_backend_dev_buffer_type(dev);
buft_list[ggml_backend_buft_name(buft)] = buft; if (buft) {
buft_list[ggml_backend_buft_name(buft)] = buft;
}
} }
} }