add --rpc-layers flag to explicitly set RPC layers

The current setup does not allow precise control over how many layers are placed on the local GPUs versus the remote RPC-connected server(s). This commit adds an --rpc-layers flag (-nrl) that lets the user explicitly set the number of layers to offload to the RPC end.
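
For illustration only (not part of the patch): a minimal sketch of how a caller might use the new field through the C API. The helper name and the concrete layer counts are made up, and llama_model_load_from_file is assumed to be the current model-loading entry point; the CLI equivalent would be something like -ngl 24 -nrl 8.

```cpp
// Illustration only: set the new n_rpc_layers field alongside n_gpu_layers.
// Assumes the field names added in this commit and llama_model_load_from_file.
#include "llama.h"

llama_model * load_with_rpc_split(const char * path_model) {
    llama_model_params mparams = llama_model_default_params();
    mparams.n_gpu_layers = 24; // layers kept on the local GPU(s)      (-ngl 24)
    mparams.n_rpc_layers = 8;  // layers offloaded to the RPC end      (-nrl 8, added here)
    return llama_model_load_from_file(path_model, mparams);
}
```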
Karl-Johan Alm 2025-02-03 11:24:29 +09:00
parent 325afb370a
commit 7e3e0d98a0
5 changed files with 30 additions and 1 deletion

common/arg.cpp

@@ -1489,6 +1489,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
).set_env("LLAMA_ARG_N_GPU_LAYERS"));
add_opt(common_arg(
{"-nrl", "--rpc-layers", "--n-rpc-layers"}, "N",
"number of layers to store on remote RPC devices",
[](common_params & params, int value) {
params.n_rpc_layers = value;
}
).set_env("LLAMA_ARG_N_RPC_LAYERS"));
add_opt(common_arg(
{"-sm", "--split-mode"}, "{none,layer,row}",
"how to split the model across multiple GPUs, one of:\n"

common/common.cpp

@@ -1086,6 +1086,9 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
if (params.n_gpu_layers != -1) {
mparams.n_gpu_layers = params.n_gpu_layers;
}
if (params.n_rpc_layers != -1) {
mparams.n_rpc_layers = params.n_rpc_layers;
}
mparams.main_gpu = params.main_gpu;
mparams.split_mode = params.split_mode;
mparams.tensor_split = params.tensor_split;

common/common.h

@@ -217,6 +217,7 @@ struct common_params {
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t n_rpc_layers = -1; // number of layers to store on RPC devices (-1 - use default)
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs

include/llama.h

@@ -280,6 +280,7 @@ extern "C" {
ggml_backend_dev_t * devices;
int32_t n_gpu_layers; // number of layers to store in VRAM
int32_t n_rpc_layers; // number of layers to delegate to RPC connected devices
enum llama_split_mode split_mode; // how to split the model across multiple GPUs
// the GPU that is used for the entire model when split_mode is LLAMA_SPLIT_MODE_NONE

src/llama-model.cpp

@@ -1256,6 +1256,7 @@ void llama_model::load_vocab(llama_model_loader & ml) {
bool llama_model::load_tensors(llama_model_loader & ml) {
const auto & split_mode = params.split_mode;
const auto & n_gpu_layers = params.n_gpu_layers;
const auto & n_rpc_layers = params.n_rpc_layers;
const auto & use_mlock = params.use_mlock;
const auto & tensor_split = params.tensor_split;
@@ -1263,9 +1264,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
const bool use_mmap_buffer = true;
ggml_backend_dev_t rpc_dev = nullptr;
// build a list of buffer types for the CPU and GPU devices
pimpl->cpu_buft_list = make_cpu_buft_list(devices);
for (auto * dev : devices) {
if (n_rpc_layers > 0 && rpc_dev == nullptr && std::string::npos != std::string(ggml_backend_dev_name(dev)).find("RPC[")) {
rpc_dev = dev;
}
buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split);
// add CPU buffer types as a fallback
buft_list.insert(buft_list.end(), pimpl->cpu_buft_list.begin(), pimpl->cpu_buft_list.end());
@@ -1279,6 +1285,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
// default split, by free memory
for (size_t i = 0; i < n_devices(); ++i) {
ggml_backend_dev_t dev = devices[i];
if (dev == rpc_dev) {
// handled separately
splits[i] = 0;
continue;
}
size_t total;
size_t free;
ggml_backend_dev_memory(dev, &free, &total);
@@ -1300,12 +1311,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0);
const int i_rpc_start = std::max(i_gpu_start - n_rpc_layers, (int) 0);
const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1);
auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev {
if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) {
if (il < i_rpc_start || (il - i_gpu_start) >= act_gpu_layers) {
LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s\n", il, ggml_backend_dev_name(cpu_dev));
return {cpu_dev, &pimpl->cpu_buft_list};
}
if (il < i_gpu_start) {
LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s\n", il, ggml_backend_dev_name(rpc_dev));
return {rpc_dev, &pimpl->gpu_buft_list.at(rpc_dev)};
}
const int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + n_devices(), float(il - i_gpu_start)/act_gpu_layers) - splits.begin();
auto * dev = devices.at(layer_gpu);
LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s\n", il, ggml_backend_dev_name(dev));
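
To make the split above concrete, here is a small standalone sketch (not part of the commit) of how i_gpu_start and i_rpc_start partition the layers between CPU, the RPC device, and the local GPUs. The layer count and flag values are made up, and the output-layer handling and act_gpu_layers clamp are omitted for brevity.

```cpp
// Standalone illustration of the layer split computed in load_tensors above.
// Hypothetical values: a 40-layer model run with -ngl 24 -nrl 8.
#include <algorithm>
#include <cstdio>

int main() {
    const int n_layer      = 40; // hypothetical model depth
    const int n_gpu_layers = 24; // --n-gpu-layers
    const int n_rpc_layers = 8;  // --rpc-layers (added by this commit)

    const int i_gpu_start = std::max(n_layer - n_gpu_layers, 0);     // 16
    const int i_rpc_start = std::max(i_gpu_start - n_rpc_layers, 0); //  8

    for (int il = 0; il < n_layer; ++il) {
        const char * where = il < i_rpc_start ? "CPU"           // layers  0..7
                           : il < i_gpu_start ? "RPC"           // layers  8..15
                           :                    "local GPU(s)"; // layers 16..39
        printf("layer %2d -> %s\n", il, where);
    }
    return 0;
}
```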
@@ -3760,6 +3776,7 @@ struct llama_model_params llama_model_default_params() {
struct llama_model_params result = {
/*.devices =*/ nullptr,
/*.n_gpu_layers =*/ 0,
/*.n_rpc_layers =*/ 0,
/*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER,
/*.main_gpu =*/ 0,
/*.tensor_split =*/ nullptr,