llama : add option to override tensor buffers

This commit is contained in:
slaren 2025-01-24 20:56:09 +01:00
parent 9fbadaef4f
commit f07c2ec505
9 changed files with 87 additions and 8 deletions

View file

@ -1,5 +1,6 @@
#include "arg.h"
#include "common.h"
#include "log.h"
#include "sampling.h"
@ -321,6 +322,10 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
params.kv_overrides.back().key[0] = 0;
}
if (!params.tensor_buft_overrides.empty()) {
params.tensor_buft_overrides.push_back({nullptr, nullptr});
}
if (params.reranking && params.embedding) {
throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
}
@ -1477,6 +1482,39 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
exit(0);
}
));
add_opt(common_arg(
{"--override-tensor", "-ot"}, "<tensor name pattern>=<buffer type>,...",
"override tensor buffer type", [](common_params & params, const std::string & value) {
static std::map<std::string, ggml_backend_buffer_type_t> buft_list;
if (buft_list.empty()) {
// enumerate all the devices and add their buffer types to the list
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
auto * dev = ggml_backend_dev_get(i);
auto * buft = ggml_backend_dev_buffer_type(dev);
buft_list[ggml_backend_buft_name(buft)] = buft;
}
}
for (const auto & override : string_split<std::string>(value, ',')) {
std::string::size_type pos = override.find('=');
if (pos == std::string::npos) {
throw std::invalid_argument("invalid value");
}
std::string tensor_name = override.substr(0, pos);
std::string buffer_type = override.substr(pos + 1);
if (buft_list.find(buffer_type) == buft_list.end()) {
printf("Available buffer types:\n");
for (const auto & it : buft_list) {
printf(" %s\n", ggml_backend_buft_name(it.second));
}
throw std::invalid_argument("unknown buffer type");
}
// FIXME: this leaks memory
params.tensor_buft_overrides.push_back({strdup(tensor_name.c_str()), buft_list.at(buffer_type)});
}
}
));
add_opt(common_arg(
{"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
"number of layers to store in VRAM",