Multi GPU support, CUDA refactor, CUDA scratch buffer (#1703)

* CUDA multi GPU + scratch

ggml_cuda_compute_forward

Tensor parallelism

ggml_cuda_add

ggml_cuda_rms_norm

ggml_cuda_silu

CUDA scratch buffer

--main-gpu CLI option
This commit is contained in:
Johannes Gäßler 2023-06-06 21:33:23 +02:00 committed by GitHub
parent 44f906e853
commit 17366df842
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 1221 additions and 544 deletions

View file

@ -9,6 +9,7 @@
#include <algorithm>
#include <sstream>
#include <unordered_set>
#include <regex>
#if defined(__APPLE__) && defined(__MACH__)
#include <sys/types.h>
@ -295,6 +296,40 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
#endif
} else if (arg == "--main-gpu" || arg == "-mg") {
if (++i >= argc) {
invalid_param = true;
break;
}
#ifdef GGML_USE_CUBLAS
params.main_gpu = std::stoi(argv[i]);
#else
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n");
#endif
} else if (arg == "--tensor-split" || arg == "-ts") {
if (++i >= argc) {
invalid_param = true;
break;
}
#ifdef GGML_USE_CUBLAS
std::string arg_next = argv[i];
// split string by , and /
const std::regex regex{R"([,/]+)"};
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
std::vector<std::string> split_arg{it, {}};
GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES);
for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i) {
if (i < split_arg.size()) {
params.tensor_split[i] = std::stof(split_arg[i]);
} else {
params.tensor_split[i] = 0.0f;
}
}
#else
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n");
#endif // GGML_USE_CUBLAS
} else if (arg == "--no-mmap") {
params.use_mmap = false;
} else if (arg == "--mtest") {
@ -438,6 +473,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
fprintf(stderr, " -ngl N, --n-gpu-layers N\n");
fprintf(stderr, " number of layers to store in VRAM\n");
fprintf(stderr, " -ts SPLIT --tensor-split SPLIT\n");
fprintf(stderr, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
fprintf(stderr, " -mg i, --main-gpu i the GPU to use for scratch and small tensors\n" );
#endif
fprintf(stderr, " --mtest compute maximum memory usage\n");
fprintf(stderr, " --export export the computation graph to 'llama.ggml'\n");
@ -483,7 +521,10 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
auto lparams = llama_context_default_params();
lparams.n_ctx = params.n_ctx;
lparams.n_batch = params.n_batch;
lparams.n_gpu_layers = params.n_gpu_layers;
lparams.main_gpu = params.main_gpu;
memcpy(lparams.tensor_split, params.tensor_split, LLAMA_MAX_DEVICES*sizeof(float));
lparams.seed = params.seed;
lparams.f16_kv = params.memory_f16;
lparams.use_mmap = params.use_mmap;