llama : add cparam (split_mode) and command line argument (--split-mode, -sm) to configure the split mode (none, layer or row)
This commit is contained in:
parent
87c8207a04
commit
5e879c9977
5 changed files with 62 additions and 21 deletions
|
@ -556,6 +556,25 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
||||||
#else
|
#else
|
||||||
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n");
|
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n");
|
||||||
#endif
|
#endif
|
||||||
|
} else if (arg == "--split-mode" || arg == "-sm") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
std::string arg_next = argv[i];
|
||||||
|
if (arg_next == "none") {
|
||||||
|
params.split_mode = LLAMA_SPLIT_NONE;
|
||||||
|
} else if (arg_next == "layer") {
|
||||||
|
params.split_mode = LLAMA_SPLIT_LAYER;
|
||||||
|
} else if (arg_next == "row") {
|
||||||
|
params.split_mode = LLAMA_SPLIT_ROW;
|
||||||
|
} else {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
#ifndef GGML_USE_CUBLAS
|
||||||
|
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Setting the split mode has no effect.\n");
|
||||||
|
#endif // GGML_USE_CUBLAS
|
||||||
} else if (arg == "--tensor-split" || arg == "-ts") {
|
} else if (arg == "--tensor-split" || arg == "-ts") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
|
@ -895,14 +914,15 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
||||||
printf(" number of layers to store in VRAM\n");
|
printf(" number of layers to store in VRAM\n");
|
||||||
printf(" -ngld N, --n-gpu-layers-draft N\n");
|
printf(" -ngld N, --n-gpu-layers-draft N\n");
|
||||||
printf(" number of layers to store in VRAM for the draft model\n");
|
printf(" number of layers to store in VRAM for the draft model\n");
|
||||||
|
printf(" -sm SPLIT_MODE, --split-mode SPLIT_MODE\n");
|
||||||
|
printf(" how to split the model across multiple GPUs, one of:\n");
|
||||||
|
printf(" - none: use one GPU only\n");
|
||||||
|
printf(" - layer (default): split layers and KV across GPUs\n");
|
||||||
|
printf(" - row: split rows across GPUs\n");
|
||||||
printf(" -ts SPLIT --tensor-split SPLIT\n");
|
printf(" -ts SPLIT --tensor-split SPLIT\n");
|
||||||
printf(" how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
|
printf(" fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n");
|
||||||
printf(" -mg i, --main-gpu i the GPU to use for scratch and small tensors\n");
|
printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n");
|
||||||
#ifdef GGML_USE_CUBLAS
|
printf(" or for intermediate results and KV (with split-mode = row) (default: %d)\n", params.main_gpu);
|
||||||
printf(" -nommq, --no-mul-mat-q\n");
|
|
||||||
printf(" use " GGML_CUBLAS_NAME " instead of custom mul_mat_q " GGML_CUDA_NAME " kernels.\n");
|
|
||||||
printf(" Not recommended since this is both slower and uses more VRAM.\n");
|
|
||||||
#endif // GGML_USE_CUBLAS
|
|
||||||
#endif
|
#endif
|
||||||
printf(" --verbose-prompt print prompt before generation\n");
|
printf(" --verbose-prompt print prompt before generation\n");
|
||||||
printf(" -dkvc, --dump-kv-cache\n");
|
printf(" -dkvc, --dump-kv-cache\n");
|
||||||
|
@ -1015,6 +1035,7 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
|
||||||
mparams.n_gpu_layers = params.n_gpu_layers;
|
mparams.n_gpu_layers = params.n_gpu_layers;
|
||||||
}
|
}
|
||||||
mparams.main_gpu = params.main_gpu;
|
mparams.main_gpu = params.main_gpu;
|
||||||
|
mparams.split_mode = params.split_mode;
|
||||||
mparams.tensor_split = params.tensor_split;
|
mparams.tensor_split = params.tensor_split;
|
||||||
mparams.use_mmap = params.use_mmap;
|
mparams.use_mmap = params.use_mmap;
|
||||||
mparams.use_mlock = params.use_mlock;
|
mparams.use_mlock = params.use_mlock;
|
||||||
|
|
|
@ -59,6 +59,7 @@ struct gpt_params {
|
||||||
float p_split = 0.1f; // speculative decoding split probability
|
float p_split = 0.1f; // speculative decoding split probability
|
||||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
|
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
|
||||||
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
||||||
|
llama_split_mode split_mode = LLAMA_SPLIT_LAYER; // how to split the model across GPUs
|
||||||
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
||||||
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
|
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
|
||||||
int32_t n_beams = 0; // if non-zero then use beam search of given width.
|
int32_t n_beams = 0; // if non-zero then use beam search of given width.
|
||||||
|
|
|
@ -1394,6 +1394,8 @@ ggml_backend_buffer_t ggml_backend_sched_get_buffer(ggml_backend_sched_t sched,
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
|
void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
|
||||||
|
// FIXME: node_allocr is cleared when splitting the graph, so all user assignments are lost
|
||||||
|
// to avoid this, we need to clear node_allocr after compute rather than before split
|
||||||
int backend_index = sched_backend_prio(sched, backend);
|
int backend_index = sched_backend_prio(sched, backend);
|
||||||
GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
|
GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
|
||||||
node_allocr(node) = sched->tallocs[backend_index];
|
node_allocr(node) = sched->tallocs[backend_index];
|
||||||
|
|
30
llama.cpp
30
llama.cpp
|
@ -3130,6 +3130,7 @@ static bool llm_load_tensors(
|
||||||
llama_model_loader & ml,
|
llama_model_loader & ml,
|
||||||
llama_model & model,
|
llama_model & model,
|
||||||
int n_gpu_layers,
|
int n_gpu_layers,
|
||||||
|
enum llama_split_mode split_mode,
|
||||||
int main_gpu,
|
int main_gpu,
|
||||||
const float * tensor_split,
|
const float * tensor_split,
|
||||||
bool use_mlock,
|
bool use_mlock,
|
||||||
|
@ -3144,14 +3145,6 @@ static bool llm_load_tensors(
|
||||||
|
|
||||||
size_t ctx_size = ggml_tensor_overhead()*ml.n_tensors;
|
size_t ctx_size = ggml_tensor_overhead()*ml.n_tensors;
|
||||||
|
|
||||||
// TODO: user configurable
|
|
||||||
enum gpu_split_mode {
|
|
||||||
LLAMA_SPLIT_NONE, // single GPU
|
|
||||||
LLAMA_SPLIT_LAYER, // offload layers to different GPUs
|
|
||||||
LLAMA_SPLIT_ROW // split matrix rows across GPUs
|
|
||||||
};
|
|
||||||
|
|
||||||
gpu_split_mode split_mode = LLAMA_SPLIT_LAYER;
|
|
||||||
const int64_t n_layer = hparams.n_layer;
|
const int64_t n_layer = hparams.n_layer;
|
||||||
const int64_t i_gpu_start = std::max((int64_t) hparams.n_layer - n_gpu_layers, (int64_t) 0);
|
const int64_t i_gpu_start = std::max((int64_t) hparams.n_layer - n_gpu_layers, (int64_t) 0);
|
||||||
|
|
||||||
|
@ -3207,13 +3200,25 @@ static bool llm_load_tensors(
|
||||||
} else
|
} else
|
||||||
#endif
|
#endif
|
||||||
{
|
{
|
||||||
// offload layers
|
ggml_backend_buffer_type_t split_buft;
|
||||||
|
if (split_mode == LLAMA_SPLIT_ROW) {
|
||||||
|
split_buft = llama_default_buffer_type_split(main_gpu, tensor_split);
|
||||||
|
} else {
|
||||||
|
split_buft = llama_default_buffer_type_offload(main_gpu);
|
||||||
|
}
|
||||||
|
// repeating layers
|
||||||
for (int64_t i = i_gpu_start; i < n_layer; ++i) {
|
for (int64_t i = i_gpu_start; i < n_layer; ++i) {
|
||||||
model.buft_layer[i] = { llama_default_buffer_type_split(main_gpu, tensor_split), llama_default_buffer_type_offload(main_gpu) };
|
model.buft_layer[i] = {
|
||||||
|
split_buft,
|
||||||
|
llama_default_buffer_type_offload(main_gpu)
|
||||||
|
};
|
||||||
}
|
}
|
||||||
// output layer
|
// output layer
|
||||||
if (n_gpu_layers > n_layer) {
|
if (n_gpu_layers > n_layer) {
|
||||||
model.buft_output = { llama_default_buffer_type_split(main_gpu, tensor_split), llama_default_buffer_type_offload(main_gpu) };
|
model.buft_output = {
|
||||||
|
split_buft,
|
||||||
|
llama_default_buffer_type_offload(main_gpu)
|
||||||
|
};
|
||||||
} else {
|
} else {
|
||||||
model.buft_output = llama_default_buffer_type_cpu(true);
|
model.buft_output = llama_default_buffer_type_cpu(true);
|
||||||
}
|
}
|
||||||
|
@ -3804,7 +3809,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, cons
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!llm_load_tensors(
|
if (!llm_load_tensors(
|
||||||
ml, model, params.n_gpu_layers, params.main_gpu, params.tensor_split, params.use_mlock,
|
ml, model, params.n_gpu_layers, params.split_mode, params.main_gpu, params.tensor_split, params.use_mlock,
|
||||||
params.progress_callback, params.progress_callback_user_data
|
params.progress_callback, params.progress_callback_user_data
|
||||||
)) {
|
)) {
|
||||||
return -2;
|
return -2;
|
||||||
|
@ -8964,6 +8969,7 @@ static int llama_apply_lora_from_file_internal(
|
||||||
struct llama_model_params llama_model_default_params() {
|
struct llama_model_params llama_model_default_params() {
|
||||||
struct llama_model_params result = {
|
struct llama_model_params result = {
|
||||||
/*.n_gpu_layers =*/ 0,
|
/*.n_gpu_layers =*/ 0,
|
||||||
|
/*.split_mode =*/ LLAMA_SPLIT_LAYER,
|
||||||
/*.main_gpu =*/ 0,
|
/*.main_gpu =*/ 0,
|
||||||
/*.tensor_split =*/ nullptr,
|
/*.tensor_split =*/ nullptr,
|
||||||
/*.progress_callback =*/ nullptr,
|
/*.progress_callback =*/ nullptr,
|
||||||
|
|
15
llama.h
15
llama.h
|
@ -115,6 +115,12 @@ extern "C" {
|
||||||
LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN,
|
LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum llama_split_mode {
|
||||||
|
LLAMA_SPLIT_NONE = 0, // single GPU
|
||||||
|
LLAMA_SPLIT_LAYER = 1, // split layers and KV to different GPUs
|
||||||
|
LLAMA_SPLIT_ROW = 2, // split rows across GPUs
|
||||||
|
};
|
||||||
|
|
||||||
typedef struct llama_token_data {
|
typedef struct llama_token_data {
|
||||||
llama_token id; // token id
|
llama_token id; // token id
|
||||||
float logit; // log-odds of the token
|
float logit; // log-odds of the token
|
||||||
|
@ -177,8 +183,13 @@ extern "C" {
|
||||||
|
|
||||||
struct llama_model_params {
|
struct llama_model_params {
|
||||||
int32_t n_gpu_layers; // number of layers to store in VRAM
|
int32_t n_gpu_layers; // number of layers to store in VRAM
|
||||||
int32_t main_gpu; // the GPU that is used for scratch and small tensors
|
enum llama_split_mode split_mode; // how to split the model across multiple GPUs
|
||||||
const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
|
// the GPU that is used for the model (LLAMA_SPLIT_NONE),
|
||||||
|
// for small tensors and intermediate results (LLAMA_SPLIT_ROW)
|
||||||
|
// ignored for LLAMA_SPLIT_LAYER
|
||||||
|
int32_t main_gpu;
|
||||||
|
// fraction of the model (layers or rows) to offload to each GPU, size: LLAMA_MAX_DEVICES
|
||||||
|
const float * tensor_split;
|
||||||
|
|
||||||
// Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
|
// Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
|
||||||
// If the provided progress_callback returns true, model loading continues.
|
// If the provided progress_callback returns true, model loading continues.
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue