diff --git a/CMakeLists.txt b/CMakeLists.txt index 3333ee1c9..bff18fa10 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -431,7 +431,7 @@ if (LLAMA_MPI) message(STATUS "MPI found") set(GGML_HEADERS_MPI ggml-mpi.h) - set(GGML_SOURCES_MPI ggml-mpi.c) + set(GGML_SOURCES_MPI ggml-mpi.cpp) add_compile_definitions(GGML_USE_MPI) add_compile_definitions(${MPI_C_COMPILE_DEFINITIONS}) diff --git a/Makefile b/Makefile index fa112e708..2be80ccc2 100644 --- a/Makefile +++ b/Makefile @@ -577,8 +577,8 @@ endif endif # LLAMA_METAL ifdef LLAMA_MPI -ggml-mpi.o: ggml-mpi.c ggml-mpi.h - $(CC) $(CFLAGS) -c $< -o $@ +ggml-mpi.o: ggml-mpi.cpp ggml-mpi.h + $(CXX) $(CXXFLAGS) -c $< -o $@ endif # LLAMA_MPI GF_CC := $(CC) diff --git a/common/common.cpp b/common/common.cpp index 0cc4859f1..dd88611b4 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -170,20 +170,39 @@ static bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, invalid_param = true; return true; } - params.n_threads = std::stoi(argv[i]); - if (params.n_threads <= 0) { - params.n_threads = std::thread::hardware_concurrency(); + std::string arg_next = argv[i]; + + // split string by , and / + const std::regex regex{R"([,/]+)"}; + std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1}; + std::vector split_arg{it, {}}; + params.n_threads.resize(split_arg.size()); + for (size_t node = 0; node < split_arg.size(); ++node) { + params.n_threads[node] = std::stoi(split_arg[node]); + if (params.n_threads[node] <= 0) { + params.n_threads[node] = std::thread::hardware_concurrency(); + } } return true; + } if (arg == "-tb" || arg == "--threads-batch") { if (++i >= argc) { invalid_param = true; return true; } - params.n_threads_batch = std::stoi(argv[i]); - if (params.n_threads_batch <= 0) { - params.n_threads_batch = std::thread::hardware_concurrency(); + std::string arg_next = argv[i]; + + // split string by , and / + const std::regex regex{R"([,/]+)"}; + std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1}; + std::vector split_arg{it, {}}; + params.n_threads_batch.resize(split_arg.size()); + for (size_t node = 0; node < split_arg.size(); ++node) { + params.n_threads_batch[node] = std::stoi(split_arg[node]); + if (params.n_threads_batch[node] <= 0) { + params.n_threads_batch[node] = std::thread::hardware_concurrency(); + } } return true; } @@ -192,9 +211,18 @@ static bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, invalid_param = true; return true; } - params.n_threads_draft = std::stoi(argv[i]); - if (params.n_threads_draft <= 0) { - params.n_threads_draft = std::thread::hardware_concurrency(); + std::string arg_next = argv[i]; + + // split string by , and / + const std::regex regex{R"([,/]+)"}; + std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1}; + std::vector split_arg{it, {}}; + params.n_threads_draft.resize(split_arg.size()); + for (size_t node = 0; node < split_arg.size(); ++node) { + params.n_threads_draft[node] = std::stoi(split_arg[node]); + if (params.n_threads_draft[node] <= 0) { + params.n_threads_draft[node] = std::thread::hardware_concurrency(); + } } return true; } @@ -203,9 +231,18 @@ static bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, invalid_param = true; return true; } - params.n_threads_batch_draft = std::stoi(argv[i]); - if (params.n_threads_batch_draft <= 0) { - params.n_threads_batch_draft = std::thread::hardware_concurrency(); + std::string arg_next = argv[i]; + + // split string by , and / 
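+            // e.g. "-tbd 16,8,8" assigns 16 batch-draft threads to node 0 and 8 to nodes 1 and 2; values <= 0 fall back to hardware_concurrency()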
+ const std::regex regex{R"([,/]+)"}; + std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1}; + std::vector split_arg{it, {}}; + params.n_threads_batch_draft.resize(split_arg.size()); + for (size_t node = 0; node < split_arg.size(); ++node) { + params.n_threads_batch_draft[node] = std::stoi(split_arg[node]); + if (params.n_threads_batch_draft[node] <= 0) { + params.n_threads_batch_draft[node] = std::thread::hardware_concurrency(); + } } return true; } @@ -891,6 +928,24 @@ static bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, #endif // GGML_USE_CUBLAS_SYCL return true; } + if (arg == "--mpi-layer-split") { + if (++i >= argc) { + invalid_param = true; + return true; + } + std::string arg_next = argv[i]; + + // split string by , and / + const std::regex regex{R"([,/]+)"}; + std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1}; + std::vector split_arg{it, {}}; + params.mpi_layer_split.resize(split_arg.size()); + for (size_t node = 0; node < split_arg.size(); ++node) { + params.mpi_layer_split[node] = std::stof(split_arg[node]); + } + return true; + } + if (arg == "--tensor-split" || arg == "-ts") { if (++i >= argc) { invalid_param = true; @@ -1275,7 +1330,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" (can be specified more than once for multiple prompts).\n"); printf(" --color colorise output to distinguish prompt and user input from generations\n"); printf(" -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n"); - printf(" -t N, --threads N number of threads to use during generation (default: %d)\n", params.n_threads); + printf(" -t N, --threads N number of threads to use during generation (default: %d)\n", params.n_threads[0]); printf(" -tb N, --threads-batch N\n"); printf(" number of threads to use during batch and prompt processing (default: same as --threads)\n"); printf(" -td N, --threads-draft N"); @@ -1393,6 +1448,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n"); printf(" or for intermediate results and KV (with split-mode = row) (default: %d)\n", params.main_gpu); } +#ifdef GGML_USE_MPI + printf(" --mpi-layer-split N percentiles to split the layers by across nodes\n"); +#endif printf(" --verbose-prompt print a verbose prompt before generation (default: %s)\n", params.verbose_prompt ? "true" : "false"); printf(" --no-display-prompt don't print prompt at generation (default: %s)\n", !params.display_prompt ? 
"true" : "false"); printf(" -gan N, --grp-attn-n N\n"); @@ -1443,9 +1501,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { std::string get_system_info(const gpt_params & params) { std::ostringstream os; - os << "system_info: n_threads = " << params.n_threads; - if (params.n_threads_batch != -1) { - os << " (n_threads_batch = " << params.n_threads_batch << ")"; + os << "system_info: n_threads = " << params.n_threads[0]; + if (params.n_threads_batch[0] != -1) { + os << " (n_threads_batch = " << params.n_threads_batch[0] << ")"; } os << " / " << std::thread::hardware_concurrency() << " | " << llama_print_system_info(); @@ -1590,6 +1648,10 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & mparams.kv_overrides = params.kv_overrides.data(); } + free((void *) mparams.node_layer_weights); + + mparams.node_layer_weights = params.mpi_layer_split.data(); + return mparams; } @@ -1629,8 +1691,8 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.n_seq_max = params.n_parallel; cparams.n_batch = params.n_batch; cparams.n_ubatch = params.n_ubatch; - cparams.n_threads = params.n_threads; - cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; + cparams.n_threads = params.n_threads[0]; + cparams.n_threads_batch = params.n_threads_batch[0] == -1 ? params.n_threads[0] : params.n_threads_batch[0]; cparams.seed = params.seed; cparams.logits_all = params.logits_all; cparams.embeddings = params.embedding; @@ -1916,6 +1978,7 @@ struct llama_model * llama_load_model_from_hf( #endif // LLAMA_USE_CURL std::tuple llama_init_from_gpt_params(gpt_params & params) { + int32_t n_threads = params.n_threads[0]; auto mparams = llama_model_params_from_gpt_params(params); llama_model * model = nullptr; @@ -1942,6 +2005,16 @@ std::tuple llama_init_from_gpt_par return std::make_tuple(nullptr, nullptr); } +#ifdef GGML_USE_MPI + int node_id = llama_node_id(lctx); + n_threads = (node_id >= params.n_threads.size()) ? get_num_physical_cores() : params.n_threads[node_id]; + int32_t n_threads_batch = (node_id >= params.n_threads_batch.size()) ? -1 : params.n_threads_batch[node_id]; + + params.n_threads[0] = n_threads; // So we can treat index 0 as what our n_threads is elsewhere + params.n_threads_batch[0] = n_threads_batch; + llama_set_n_threads(lctx, n_threads, (n_threads_batch > 0) ? n_threads_batch : get_num_physical_cores()); +#endif + if (!params.control_vectors.empty()) { if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1; if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_n_layer(model); @@ -1975,7 +2048,7 @@ std::tuple llama_init_from_gpt_par ((i > 0) || params.lora_base.empty()) ? NULL : params.lora_base.c_str(), - params.n_threads); + n_threads); if (err != 0) { fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); llama_free(lctx); @@ -1991,10 +2064,17 @@ std::tuple llama_init_from_gpt_par { LOG("warming up the model with an empty run\n"); +#ifndef GGML_USE_MPI + // When using MPI, llama_decode() enters into an infinite loop + // on non-head nodes. Thus, we only want to warmup the model here + // if we aren't using MPI. 
+ // FIXME have a way to terminate the infinite loop so we can warmup the model + // in MPI mode std::vector tmp = { llama_token_bos(model), llama_token_eos(model), }; llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); llama_kv_cache_clear(lctx); llama_synchronize(lctx); +#endif llama_reset_timings(lctx); } @@ -2385,7 +2465,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l dump_vector_float_yaml(stream, "tensor_split", tensor_split_vector); fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z); - fprintf(stream, "threads: %d # default: %u\n", params.n_threads, std::thread::hardware_concurrency()); + fprintf(stream, "threads: %d # default: %u\n", params.n_threads[0], std::thread::hardware_concurrency()); fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p); diff --git a/common/common.h b/common/common.h index d827d4df7..9849a6566 100644 --- a/common/common.h +++ b/common/common.h @@ -47,11 +47,10 @@ int32_t get_num_physical_cores(); struct gpt_params { uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed - - int32_t n_threads = get_num_physical_cores(); - int32_t n_threads_draft = -1; - int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads) - int32_t n_threads_batch_draft = -1; + std::vector n_threads = {get_num_physical_cores()}; + std::vector n_threads_batch = {-1}; // number of threads to use for batch processing (-1 = use n_threads) + std::vector n_threads_draft = {get_num_physical_cores()}; + std::vector n_threads_batch_draft = {-1}; // number of threads to use for batch processing (-1 = use n_threads) int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 512; // context size int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS) @@ -65,6 +64,7 @@ struct gpt_params { int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs + std::vector mpi_layer_split = {1.0}; // list of percentages of the total number of layers int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs int32_t n_beams = 0; // if non-zero then use beam search of given width. diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 19674dfd3..df9f95bb0 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -102,8 +102,8 @@ int main(int argc, char ** argv) { ctx_params.n_ctx = n_kv_max; ctx_params.n_batch = 512; - ctx_params.n_threads = params.n_threads; - ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; + ctx_params.n_threads = params.n_threads[0]; + ctx_params.n_threads_batch = params.n_threads_batch[0] == -1 ? 
params.n_threads[0] : params.n_threads_batch[0]; // ensure enough sequences are available ctx_params.n_seq_max = *std::max_element(n_pl.begin(), n_pl.end()); diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 7aaf63ceb..e2d004373 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -83,8 +83,8 @@ int main(int argc, char ** argv) { ctx_params.n_ctx = n_kv_req; ctx_params.n_batch = std::max(n_len, n_parallel); ctx_params.n_seq_max = n_parallel; - ctx_params.n_threads = params.n_threads; - ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; + ctx_params.n_threads = params.n_threads[0]; + ctx_params.n_threads_batch = params.n_threads_batch[0] == -1 ? params.n_threads[0] : params.n_threads_batch[0]; llama_context * ctx = llama_new_context_with_model(model, ctx_params); diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index e29da6cb2..3dd1013ea 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -125,14 +125,14 @@ static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_para if (!params->image.empty()) { fprintf(stderr, "using base64 encoded image instead of command line image path\n"); } - embed = llava_image_embed_make_with_prompt_base64(ctx_llava->ctx_clip, params->n_threads, prompt); + embed = llava_image_embed_make_with_prompt_base64(ctx_llava->ctx_clip, params->n_threads[0], prompt); if (!embed) { fprintf(stderr, "%s: can't load image from prompt\n", __func__); return NULL; } params->prompt = remove_image_from_prompt(prompt); } else { - embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->n_threads, params->image.c_str()); + embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->n_threads[0], params->image.c_str()); if (!embed) { fprintf(stderr, "%s: is %s really an image file?\n", __func__, params->image.c_str()); return NULL; diff --git a/examples/main/main.cpp b/examples/main/main.cpp index e2d07a631..b4b3f8a6c 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -207,6 +207,8 @@ int main(int argc, char ** argv) { return 1; } + llama_split_layers_weighted(ctx, params.mpi_layer_split.data(), params.mpi_layer_split.size()); + const int n_ctx_train = llama_n_ctx_train(model); const int n_ctx = llama_n_ctx(ctx); LOG("n_ctx: %d\n", n_ctx); diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index 2cbc9e1fa..7fafc9e0b 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -94,8 +94,8 @@ int main(int argc, char ** argv) { ctx_params.seed = seed; ctx_params.n_ctx = llama_n_ctx_train(model)*n_grp + n_keep; ctx_params.n_batch = 512; - ctx_params.n_threads = params.n_threads; - ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; + ctx_params.n_threads = params.n_threads[0]; + ctx_params.n_threads_batch = params.n_threads_batch[0] == -1 ? 
params.n_threads[0] : params.n_threads_batch[0]; GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp"); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 27bd2dd70..08177726b 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2421,7 +2421,7 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, invalid_param = true; break; } - params.n_threads = std::stoi(argv[i]); + params.n_threads[0] = std::stoi(argv[i]); } else if (arg == "--grp-attn-n" || arg == "-gan") { if (++i >= argc) { invalid_param = true; @@ -2441,7 +2441,7 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, invalid_param = true; break; } - params.n_threads_batch = std::stoi(argv[i]); + params.n_threads_batch[0] = std::stoi(argv[i]); } else if (arg == "--threads-http") { if (++i >= argc) { invalid_param = true; diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 39e2d8ea4..b31d77fbe 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -53,8 +53,8 @@ int main(int argc, char ** argv) { ctx_params.seed = 1234; ctx_params.n_ctx = 2048; - ctx_params.n_threads = params.n_threads; - ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; + ctx_params.n_threads = params.n_threads[0]; + ctx_params.n_threads_batch = params.n_threads_batch[0] == -1 ? params.n_threads[0] : params.n_threads_batch[0]; llama_context * ctx = llama_new_context_with_model(model, ctx_params); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index e991b8846..e447c9949 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -71,7 +71,7 @@ int main(int argc, char ** argv) { // load the draft model params.model = params.model_draft; params.n_gpu_layers = params.n_gpu_layers_draft; - if (params.n_threads_draft > 0) { + if (params.n_threads_draft.size() > 0) { params.n_threads = params.n_threads_draft; } params.n_threads_batch = params.n_threads_batch_draft; diff --git a/ggml-backend.c b/ggml-backend.c index 6026570ae..7e07d5d71 100644 --- a/ggml-backend.c +++ b/ggml-backend.c @@ -1719,7 +1719,7 @@ ggml_backend_sched_t ggml_backend_sched_new( bool parallel) { GGML_ASSERT(n_backends > 0); GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS); - GGML_ASSERT(ggml_backend_is_cpu(backends[n_backends - 1])); // last backend must be CPU + GGML_ASSERT(ggml_backend_buft_is_host(ggml_backend_get_default_buffer_type(backends[n_backends - 1]))); // last backend must be host struct ggml_backend_sched * sched = calloc(sizeof(struct ggml_backend_sched), 1); diff --git a/ggml-mpi.c b/ggml-mpi.c deleted file mode 100644 index ae176d707..000000000 --- a/ggml-mpi.c +++ /dev/null @@ -1,216 +0,0 @@ -#include "ggml-mpi.h" - -#include "ggml.h" - -#include - -#include -#include - -#define MIN(a, b) ((a) < (b) ? 
(a) : (b)) - -#define UNUSED GGML_UNUSED - -struct ggml_mpi_context { - int rank; - int size; -}; - -void ggml_mpi_backend_init(void) { - MPI_Init(NULL, NULL); -} - -void ggml_mpi_backend_free(void) { - MPI_Finalize(); -} - -struct ggml_mpi_context * ggml_mpi_init(void) { - struct ggml_mpi_context * ctx = calloc(1, sizeof(struct ggml_mpi_context)); - - MPI_Comm_rank(MPI_COMM_WORLD, &ctx->rank); - MPI_Comm_size(MPI_COMM_WORLD, &ctx->size); - - return ctx; -} - -void ggml_mpi_free(struct ggml_mpi_context * ctx) { - free(ctx); -} - -int ggml_mpi_rank(struct ggml_mpi_context * ctx) { - return ctx->rank; -} - -void ggml_mpi_eval_init( - struct ggml_mpi_context * ctx_mpi, - int * n_tokens, - int * n_past, - int * n_threads) { - UNUSED(ctx_mpi); - - // synchronize the worker node parameters with the root node - MPI_Barrier(MPI_COMM_WORLD); - - MPI_Bcast(n_tokens, 1, MPI_INT, 0, MPI_COMM_WORLD); - MPI_Bcast(n_past, 1, MPI_INT, 0, MPI_COMM_WORLD); - MPI_Bcast(n_threads, 1, MPI_INT, 0, MPI_COMM_WORLD); -} - -static int ggml_graph_get_node_idx(struct ggml_cgraph * gf, const char * name) { - struct ggml_tensor * t = ggml_graph_get_tensor(gf, name); - if (t == NULL) { - fprintf(stderr, "%s: tensor %s not found\n", __func__, name); - return -1; - } - - for (int i = 0; i < gf->n_nodes; i++) { - if (gf->nodes[i] == t) { - return i; - } - } - - fprintf(stderr, "%s: tensor %s not found in graph (should not happen)\n", __func__, name); - return -1; -} - -static void ggml_mpi_tensor_send(struct ggml_tensor * t, int mpi_rank_dst) { - MPI_Datatype mpi_type; - - switch (t->type) { - case GGML_TYPE_I32: mpi_type = MPI_INT32_T; break; - case GGML_TYPE_F32: mpi_type = MPI_FLOAT; break; - default: GGML_ASSERT(false && "not implemented"); - } - - const int retval = MPI_Send(t->data, ggml_nelements(t), mpi_type, mpi_rank_dst, 0, MPI_COMM_WORLD); - GGML_ASSERT(retval == MPI_SUCCESS); -} - -static void ggml_mpi_tensor_recv(struct ggml_tensor * t, int mpi_rank_src) { - MPI_Datatype mpi_type; - - switch (t->type) { - case GGML_TYPE_I32: mpi_type = MPI_INT32_T; break; - case GGML_TYPE_F32: mpi_type = MPI_FLOAT; break; - default: GGML_ASSERT(false && "not implemented"); - } - - MPI_Status status; UNUSED(status); - - const int retval = MPI_Recv(t->data, ggml_nelements(t), mpi_type, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status); - GGML_ASSERT(retval == MPI_SUCCESS); -} - -// TODO: there are many improvements that can be done to this implementation -void ggml_mpi_graph_compute_pre( - struct ggml_mpi_context * ctx_mpi, - struct ggml_cgraph * gf, - int n_layers) { - const int mpi_rank = ctx_mpi->rank; - const int mpi_size = ctx_mpi->size; - - struct ggml_tensor * inp_tokens = ggml_graph_get_tensor(gf, "inp_tokens"); - if (inp_tokens == NULL) { - fprintf(stderr, "%s: tensor 'inp_tokens' not found\n", __func__); - return; - } - - struct ggml_tensor * inp0 = ggml_graph_get_tensor(gf, "layer_inp_0"); - if (inp0 == NULL) { - fprintf(stderr, "%s: tensor 'inp0' not found\n", __func__); - return; - } - - GGML_ASSERT(inp0 == gf->nodes[0]); - - // distribute the compute graph into slices across the MPI nodes - // - // the main node (0) processes the last layers + the remainder of the compute graph - // and is responsible to pass the input tokens to the first node (1) - // - // node 1: [( 0) * n_per_node, ( 1) * n_per_node) - // node 2: [( 1) * n_per_node, ( 2) * n_per_node) - // ... 
- // node n-1: [(n-2) * n_per_node, (n-1) * n_per_node) - // node 0: [(n-1) * n_per_node, n_nodes) - // - if (mpi_rank > 0) { - if (mpi_rank == 1) { - // the first node (1) receives the input tokens from the main node (0) - ggml_mpi_tensor_recv(inp_tokens, 0); - } else { - // recv input data for each node into the "inp0" tensor (i.e. the first node in the compute graph) - ggml_mpi_tensor_recv(inp0, mpi_rank - 1); - } - } else if (mpi_size > 1) { - // node 0 sends the input tokens to node 1 - ggml_mpi_tensor_send(inp_tokens, 1); - - // recv the output data from the last node - ggml_mpi_tensor_recv(inp0, mpi_size - 1); - } - - { - const int n_per_node = (n_layers + (mpi_size - 1)) / mpi_size; - - const int mpi_idx = mpi_rank > 0 ? mpi_rank - 1 : mpi_size - 1; - - const int il0 = (mpi_idx + 0) * n_per_node; - const int il1 = MIN(n_layers, (mpi_idx + 1) * n_per_node); - - char name_l0[GGML_MAX_NAME]; - char name_l1[GGML_MAX_NAME]; - - snprintf(name_l0, sizeof(name_l0), "layer_inp_%d", il0); - snprintf(name_l1, sizeof(name_l1), "layer_inp_%d", il1); - - const int idx_l0 = ggml_graph_get_node_idx(gf, name_l0); - const int idx_l1 = mpi_rank > 0 ? ggml_graph_get_node_idx(gf, name_l1) + 1 : gf->n_nodes; - - if (idx_l0 < 0 || idx_l1 < 0) { - fprintf(stderr, "%s: layer input nodes not found\n", __func__); - return; - } - - // attach the input data to all nodes that need it - // TODO: not great - should be able to do this without modifying the compute graph (see next TODO below) - for (int i = idx_l0; i < idx_l1; i++) { - if (gf->nodes[i]->src[0] == gf->nodes[idx_l0]) { - gf->nodes[i]->src[0] = inp0; - } - if (gf->nodes[i]->src[1] == gf->nodes[idx_l0]) { - gf->nodes[i]->src[1] = inp0; - } - } - - // TODO: instead of rearranging the nodes, we should be able to execute a subset of the compute graph - for (int i = 1; i < idx_l1 - idx_l0; i++) { - gf->nodes[i] = gf->nodes[idx_l0 + i]; - gf->grads[i] = gf->grads[idx_l0 + i]; - } - - // the first node performs the "get_rows" operation, the rest of the nodes get the data from the previous node - if (mpi_idx != 0) { - gf->nodes[0]->op = GGML_OP_NONE; - } - - gf->n_nodes = idx_l1 - idx_l0; - - //fprintf(stderr, "%s: node %d: processing %d nodes [%d, %d)\n", __func__, mpi_rank, gf->n_nodes, il0, il1); - } -} - -void ggml_mpi_graph_compute_post( - struct ggml_mpi_context * ctx_mpi, - struct ggml_cgraph * gf, - int n_layers) { - UNUSED(n_layers); - - const int mpi_rank = ctx_mpi->rank; - const int mpi_size = ctx_mpi->size; - - // send the output data to the next node - if (mpi_rank > 0) { - ggml_mpi_tensor_send(gf->nodes[gf->n_nodes - 1], (mpi_rank + 1) % mpi_size); - } -} diff --git a/ggml-mpi.cpp b/ggml-mpi.cpp new file mode 100644 index 000000000..f8c87f2d6 --- /dev/null +++ b/ggml-mpi.cpp @@ -0,0 +1,1044 @@ +#include "ggml-mpi.h" + +#include "ggml.h" +#include "ggml-backend.h" +#include "ggml-backend-impl.h" + +#include + +#include +#include +#include + +#define MIN(a, b) ((a) < (b) ? 
(a) : (b)) + +#define UNUSED GGML_UNUSED + +static bool have_init = false; + +static void* send_buffer; + +struct ggml_mpi_context { + int rank; + int size; + MPI_Comm comm; + int layer_start; + int layer_end; + MPI_Status status; + + struct ggml_tensor *inp0; + std::string name; + struct ggml_backend * wrapped_backend; + std::vector backends; + ggml_backend_sched_t scheduler; + bool remote; + void* send_buffer; + int trans_id; + int recv_trans_id; +}; + +void ggml_mpi_backend_init(void) { + int ret; + + GGML_ASSERT(MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &ret) == MPI_SUCCESS); + have_init = true; + const int buffer_size = 128*1024*1024*8; + send_buffer = calloc(1, buffer_size); // 128MB buffer +// fprintf(stderr, "BUFFER ATTACH RETCODE=%d\n", MPI_Buffer_attach(send_buffer, buffer_size)); +} + +void ggml_mpi_sync_pipelined( + struct ggml_mpi_context * ctx_mpi, + void * val, + int count, + MPI_Datatype datatype, + int tag +); + +void ggml_mpi_backend_free(void) { + MPI_Finalize(); +} + +struct ggml_mpi_context * ggml_mpi_init(void) { + + if (!have_init) { + ggml_mpi_backend_init(); + } + + auto * ctx = new ggml_mpi_context; + + MPI_Comm_rank(MPI_COMM_WORLD, &ctx->rank); + MPI_Comm_size(MPI_COMM_WORLD, &ctx->size); + ctx->comm = MPI_COMM_WORLD; + ctx->remote = false; + + ctx->send_buffer = send_buffer; + + return ctx; +} + +struct ggml_mpi_context * ggml_mpi_split_comm(struct ggml_mpi_context * ctx, int color, int key) { + auto * newCtx = static_cast(calloc(1, sizeof(struct ggml_mpi_context))); + MPI_Comm_split(ctx->comm, color, key, &newCtx->comm); + MPI_Comm_rank(newCtx->comm, &newCtx->rank); + MPI_Comm_size(newCtx->comm, &newCtx->size); + return newCtx; +} + +void ggml_mpi_free(struct ggml_mpi_context * ctx) { + MPI_Comm_free(&(ctx->comm)); + free(ctx); +} + +int ggml_mpi_rank(struct ggml_mpi_context * ctx) { + return ctx->rank; +} + +size_t ggml_mpi_size(struct ggml_mpi_context * ctx) { + return ctx->size; +} + +int ggml_mpi_next_node(struct ggml_mpi_context * ctx_mpi) { + return (ctx_mpi->rank + 1) % ctx_mpi->size; +} + +int ggml_mpi_prev_node(struct ggml_mpi_context * ctx_mpi) { + int temp = (ctx_mpi->rank - 1); + return (temp >= 0) ? temp : ctx_mpi->size - 1; +} + +void ggml_mpi_sync_pipelined( + struct ggml_mpi_context * ctx_mpi, + void * val, + int count, + MPI_Datatype datatype, + int tag +) { + if(ctx_mpi->comm == MPI_COMM_NULL) { + return; + } + +// printf("Rank %d sync pipelined with tag %d\n", ctx_mpi->rank, tag); + + + if (ctx_mpi->rank != 0) { + MPI_Recv(val, count, datatype, ggml_mpi_prev_node(ctx_mpi), tag, ctx_mpi->comm, MPI_STATUS_IGNORE); + } + if(ctx_mpi->rank < ctx_mpi->size - 1) { + GGML_ASSERT(ctx_mpi->send_buffer != nullptr); + GGML_ASSERT(val != nullptr || count == 0); + GGML_ASSERT(count < 128*1024*1024); + + const int retval = MPI_Bsend(val, count, datatype, ggml_mpi_next_node(ctx_mpi), tag, ctx_mpi->comm); + GGML_ASSERT(retval == MPI_SUCCESS); + + } +} + +void ggml_mpi_barrier(struct ggml_mpi_context * ctx_mpi) { + MPI_Barrier(ctx_mpi->comm); +} + +void ggml_mpi_probe(struct ggml_mpi_context * ctx_mpi, int src, int tag) { + MPI_Probe((src >= 0) ? src : MPI_ANY_SOURCE, (tag >= 0) ? tag : MPI_ANY_TAG, ctx_mpi->comm, &(ctx_mpi->status)); +} + +int ggml_mpi_iprobe(struct ggml_mpi_context * ctx_mpi, int src, int tag) { + if(ctx_mpi->comm == MPI_COMM_NULL) { + return 0; + } + + int ret; + MPI_Iprobe((src >= 0) ? src : MPI_ANY_SOURCE, (tag >= 0) ? 
tag : MPI_ANY_TAG, ctx_mpi->comm, &ret, &(ctx_mpi->status)); + return ret; +} + +int ggml_mpi_status_tag(struct ggml_mpi_context * ctx_mpi) { + return ctx_mpi->status.MPI_TAG; +} + +int ggml_mpi_status_count_int32(struct ggml_mpi_context * ctx_mpi) { + int32_t count; + MPI_Get_count(&ctx_mpi->status, MPI_INT32_T, &count); + return count; +} + +void ggml_mpi_eval_init( + struct ggml_mpi_context * ctx_mpi, + int32_t * n_tokens, + int32_t ** pos, + int32_t ** n_seq_ids, + int32_t *** seq_id, + int8_t ** logits, + uint32_t n_seq_max) { + + + if(ctx_mpi->comm == MPI_COMM_NULL) { + return; + } + + int32_t old_n_tokens = *n_tokens; + ggml_mpi_sync_pipelined(ctx_mpi, n_tokens, 1, MPI_INT, GGML_MPI_N_TOKENS); + + if (old_n_tokens != *n_tokens) { + *pos = static_cast(realloc(*pos, *n_tokens * sizeof(int32_t))); + *n_seq_ids = static_cast(realloc(*n_seq_ids, *n_tokens * sizeof(int32_t))); + *logits = static_cast(realloc(*logits, *n_tokens * sizeof(int32_t))); + } + + int8_t* temp_logits = (int8_t*) calloc(*n_tokens, sizeof(int8_t)); + + if (ctx_mpi->rank == 0 && *logits != nullptr) { + ggml_mpi_sync_pipelined(ctx_mpi, *logits, *n_tokens, MPI_INT8_T, GGML_MPI_BATCH_LOGITS); + } else { + ggml_mpi_sync_pipelined(ctx_mpi, temp_logits, *n_tokens, MPI_INT8_T, GGML_MPI_BATCH_LOGITS); + } + + + + if (ctx_mpi->rank != 0) { + bool should_set_batch_logits = false; + for (int i = 0; i < *n_tokens; i++) { + if (temp_logits[i]) { + should_set_batch_logits = true; + break; + } + } + if (should_set_batch_logits) { + if (*logits != NULL) { + free(*logits); + *logits = NULL; + } + *logits = temp_logits; + } else { + if (*logits != NULL) { + free(*logits); + *logits = NULL; + } + free(temp_logits); + } + } else { + free(temp_logits); + } + + // For now, we assume that the pos, seq_ids, tokens, etc have been + // pre-allocated for the largest possible sizes, even on worker nodes. 
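+    // Sanity-check the batch metadata before syncing it across the pipeline.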
+ + GGML_ASSERT(n_seq_ids != nullptr); + GGML_ASSERT(*n_seq_ids != nullptr); + + GGML_ASSERT(n_tokens != nullptr); + + + // FIXME Syncing n_seq_ids causes MPI to throw an invalid buffer error in Bsend + ggml_mpi_sync_pipelined(ctx_mpi, *n_seq_ids, *n_tokens, MPI_INT32_T, GGML_MPI_N_SEQ_IDS); + + // We need to know the total number of sequence + // ids, so we count them all up + int32_t total_n_seq_ids = 0; + for (int32_t i = 0; i < *n_tokens; i++) { + total_n_seq_ids += (*n_seq_ids)[i]; + } + + // MPI can't chase the pointers for multidimensional arrays, so we flatten them first + // for transit + int32_t * flattened_seq_ids = static_cast(calloc(total_n_seq_ids, sizeof(int32_t))); + + int32_t current_index = 0; + + // Only rank 0 needs to flatten since the others don't have the real seq_id + if (ctx_mpi->rank == 0) { + for (int32_t i = 0; i < *n_tokens; i++) { + for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) { + flattened_seq_ids[current_index] = (*seq_id)[i][j]; + current_index++; + } + } + } + + + + ggml_mpi_sync_pipelined(ctx_mpi, *pos, *n_tokens, MPI_INT32_T, GGML_MPI_POS); + ggml_mpi_sync_pipelined(ctx_mpi, flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, GGML_MPI_SEQ_IDS); + + current_index = 0; + for (int32_t i = 0; i < *n_tokens; i++) { + for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) { + (*seq_id)[i][j] = flattened_seq_ids[current_index]; + current_index++; + } + + } + free(flattened_seq_ids); +} + + +void ggml_mpi_sync_int( + struct ggml_mpi_context * ctx_mpi, + int32_t * val +) { + MPI_Bcast(val, 1, MPI_INT32_T, 0, ctx_mpi->comm); +} + +void ggml_mpi_sync_ints_pipelined( + struct ggml_mpi_context * ctx_mpi, + int32_t * vals, + int count, + int tag +) { + ggml_mpi_sync_pipelined(ctx_mpi, vals, count, MPI_INT32_T, tag); + int old_trans = ctx_mpi->trans_id; + ggml_mpi_sync_pipelined(ctx_mpi, &ctx_mpi->trans_id, 1, MPI_INT32_T, GGML_MPI_TRANS_ID); + ctx_mpi->recv_trans_id = ctx_mpi->trans_id; + ctx_mpi->trans_id = old_trans; +} + +static void ggml_mpi_tensor_send(const struct ggml_tensor * t, const void* data, int mpi_rank_dst, MPI_Comm comm) { + MPI_Datatype mpi_type; + +// fprintf(stderr, "Type: %d\n", t->type); + + switch (t->type) { + case GGML_TYPE_I32: mpi_type = MPI_INT32_T; break; + case GGML_TYPE_F32: mpi_type = MPI_FLOAT; break; + case GGML_TYPE_F16: mpi_type = MPI_INT16_T; break; + default: GGML_ASSERT(false && "not implemented"); + } + int rank; + MPI_Comm_rank(comm, &rank); +// fprintf(stderr, "Sending tensor %s (buffer %s) from %d to %d\n", t->name, ggml_backend_buffer_name(t->buffer), rank, mpi_rank_dst); + + const int retval = MPI_Send(data, ggml_nelements(t), mpi_type, mpi_rank_dst, 0, comm); + GGML_ASSERT(retval == MPI_SUCCESS); + +} +static void ggml_mpi_tensor_send(const struct ggml_tensor * t, int mpi_rank_dst, MPI_Comm comm) { + ggml_mpi_tensor_send(t, t->data, mpi_rank_dst, comm); +} + +static void ggml_mpi_tensor_recv(const struct ggml_tensor * t, void * data, int mpi_rank_src, MPI_Comm comm) { + MPI_Datatype mpi_type; + + switch (t->type) { + case GGML_TYPE_I32: mpi_type = MPI_INT32_T; break; + case GGML_TYPE_F32: mpi_type = MPI_FLOAT; break; + default: GGML_ASSERT(false && "not implemented"); + } + + MPI_Status status; UNUSED(status); +// fprintf(stderr, "%s: tensor receive == null: %d\n", __func__, t->data == NULL); + int rank; + MPI_Comm_rank(comm, &rank); +// fprintf(stderr, "Receiving tensor %s (buffer %s) from %d at %d\n", t->name, ggml_backend_buffer_name(t->buffer), mpi_rank_src, rank); + const int retval = MPI_Recv(data, ggml_nelements(t), mpi_type, 
mpi_rank_src, MPI_ANY_TAG, comm, &status); + GGML_ASSERT(retval == MPI_SUCCESS); +} + +static void ggml_mpi_tensor_recv(struct ggml_tensor * t, int mpi_rank_src, MPI_Comm comm) { + ggml_mpi_tensor_recv(t, t->data, mpi_rank_src, comm); +} + +uint16_t** ggml_mpi_split_range( + struct ggml_mpi_context * ctx_mpi, + uint16_t start, + uint16_t end, + const float node_weights[] +) { + // Splits the range given by start and end + // over the available nodes. This implementation + // assumes that node 0 handles the final part of the range + // while node 1 handles the beginning, to form a ring pipeline + + uint16_t range_length = end - start + 1; + uint16_t ** ranges = (uint16_t**) malloc(sizeof(uint16_t*) * ctx_mpi->size); + for (int i = 0; i < ctx_mpi->size; i++) { + ranges[i] = (uint16_t*) malloc(sizeof(uint16_t) * 2); + } + uint16_t next_layer = 0; + for (int i=0; i < ctx_mpi->size; i++) { + ranges[i][0] = next_layer; + ranges[i][1] = MIN(end, ranges[i][0] + (node_weights[i] * range_length) + start); + next_layer = ranges[i][1]; + } + +// ranges[0][0] = next_layer; +// ranges[0][1] = MIN(end, next_layer + (node_weights[0] * range_length) + start); + return ranges; + +} + +// BACKEND V2 + +struct ggml_backend_mpi_buffer_context { + ggml_backend_buffer_t wrapped_buffer; + ggml_mpi_context * ctx_mpi; +}; + +struct ggml_backend_mpi_buffer_type_context { + std::string name; + ggml_backend_buffer_type_t wrapped_buffer_type; + ggml_mpi_context * ctx_mpi; +}; + +int ggml_backend_mpi_buffer_type_rank(ggml_backend_buffer_type_t buft); + +int ggml_backend_mpi_buffer_type_local_rank(ggml_backend_buffer_type_t buft); + +GGML_CALL static const char * ggml_backend_mpi_buffer_type_name(ggml_backend_buffer_type_t buft) { + auto * ctx = (ggml_backend_mpi_buffer_type_context *) buft->context; + + + return strdup( + ( + ctx->name + + " Buffer Type(Rank " + + std::to_string( + ggml_backend_mpi_buffer_type_rank(buft) + ) + + ", local rank " + + std::to_string(ggml_backend_mpi_buffer_type_local_rank(buft)) + + "):" + + std::string( + ctx->wrapped_buffer_type->iface.get_name(ctx->wrapped_buffer_type) + ) + ).c_str() + ); +} + +MPI_Comm ggml_backend_mpi_buffer_type_get_comm(ggml_backend_buffer_type_t buft) { + auto * buft_ctx = (ggml_backend_mpi_buffer_type_context *) buft->context; + return buft_ctx->ctx_mpi->comm; + +} + +MPI_Comm ggml_backend_mpi_buffer_get_comm(ggml_backend_buffer_t buffer) { + return ggml_backend_mpi_buffer_type_get_comm(buffer->buft); +} + +MPI_Comm ggml_backend_mpi_get_comm(ggml_backend_t backend) { + auto * ctx = (ggml_mpi_context *) backend->context; + + return ctx->comm; +} + +int ggml_backend_mpi_buffer_local_rank(ggml_backend_buffer_t buffer) { + int rank; + int ret = MPI_Comm_rank(ggml_backend_mpi_buffer_get_comm(buffer), &rank); + GGML_ASSERT(ret == MPI_SUCCESS); + return rank; +} + +int ggml_backend_mpi_buffer_type_local_rank(ggml_backend_buffer_type_t buft) { + int rank; + int ret = MPI_Comm_rank(ggml_backend_mpi_buffer_type_get_comm(buft), &rank); + GGML_ASSERT(ret == MPI_SUCCESS); + return rank; +} + +int ggml_backend_mpi_local_rank(ggml_backend_t backend) { + int rank; + int ret = MPI_Comm_rank(ggml_backend_mpi_get_comm(backend), &rank); + GGML_ASSERT(ret == MPI_SUCCESS); + return rank; +} + +int ggml_backend_mpi_buffer_rank(ggml_backend_buffer_t buffer) { + auto * ctx = (ggml_backend_mpi_buffer_context *) buffer->context; + return ctx->ctx_mpi->rank; +} + +int ggml_backend_mpi_buffer_type_rank(ggml_backend_buffer_type_t buft) { + GGML_ASSERT(buft->iface.get_name == 
ggml_backend_mpi_buffer_type_name); + auto * ctx = (ggml_backend_mpi_buffer_type_context *) buft->context; + GGML_ASSERT(ctx != nullptr); + GGML_ASSERT(ctx->ctx_mpi != nullptr); + return ctx->ctx_mpi->rank; +} + +int ggml_backend_mpi_rank(ggml_backend_t backend) { + auto * ctx = (ggml_mpi_context *) backend->context; + return ctx->rank; +} + +ggml_backend_buffer_t ggml_backend_mpi_buffer_unwrap(ggml_backend_buffer_t buffer) { + auto * ctx = (ggml_backend_mpi_buffer_context *) buffer->context; + + ggml_backend_buffer_t wrapped_buffer = ctx->wrapped_buffer; + wrapped_buffer->usage = buffer->usage; + wrapped_buffer->size = buffer->size; + return wrapped_buffer; + +} + +ggml_backend_buffer_type_t ggml_backend_mpi_buffer_type_unwrap(ggml_backend_buffer_type_t buft) { + auto * ctx = (ggml_backend_mpi_buffer_type_context *) buft->context; + + ggml_backend_buffer_type_t wrapped_buffer_type = ctx->wrapped_buffer_type; + return wrapped_buffer_type; + +} + + +GGML_CALL static const char * ggml_backend_mpi_buffer_name(ggml_backend_buffer_t buffer) { + + + + return strdup( + ( + + "MPI Buffer(Rank " + + std::to_string(ggml_backend_mpi_buffer_rank(buffer)) + + ", local rank " + + std::to_string(ggml_backend_mpi_buffer_local_rank(buffer)) + + "):" + + std::string( + ggml_backend_buffer_name( + ggml_backend_mpi_buffer_unwrap(buffer) + ) + ) + ).c_str() + ); +} + + +GGML_CALL static const char * ggml_backend_mpi_buffer_type_name(ggml_backend_buffer_type_t buft); + +GGML_CALL void ggml_backend_mpi_buffer_type_copy_ctx(ggml_backend_buffer_type_t src, ggml_backend_buffer_type_t dst) { + if (src->iface.get_name == ggml_backend_mpi_buffer_type_name) { + *((ggml_backend_mpi_buffer_type_context *) dst->context)->ctx_mpi = *((ggml_backend_mpi_buffer_type_context *) src->context)->ctx_mpi; + } else { + GGML_ASSERT(!"Buffer type must be wrapped in ggml_backend_mpi_buffer_type_t"); + } +} + +GGML_CALL void ggml_backend_mpi_buffer_copy_ctx(ggml_backend_buffer_t src, ggml_backend_buffer_t dst) { + if (src->iface.get_name == ggml_backend_mpi_buffer_name) { + *((ggml_backend_mpi_buffer_context *) dst->context)->ctx_mpi = *((ggml_backend_mpi_buffer_context *) src->context)->ctx_mpi; + ggml_backend_mpi_buffer_type_copy_ctx(src->buft, dst->buft); + } else { + GGML_ASSERT(!"Buffer must be wrapped in ggml_backend_mpi_buffer_t"); + } +} + +GGML_CALL void ggml_backend_mpi_buffer_copy_ctx_from_type(ggml_backend_buffer_type_t src, ggml_backend_buffer_t dst) { + if (src->iface.get_name == ggml_backend_mpi_buffer_type_name) { + *((ggml_backend_mpi_buffer_context *) dst->context)->ctx_mpi = *((ggml_backend_mpi_buffer_type_context *) src->context)->ctx_mpi; + ggml_backend_mpi_buffer_type_copy_ctx(src, dst->buft); + } else { + GGML_ASSERT(!"Buffer must be wrapped in ggml_backend_mpi_buffer_t"); + } +} + +GGML_CALL static enum ggml_status ggml_backend_mpi_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + + struct ggml_mpi_context * ctx = (ggml_mpi_context *) backend->context; + + std::vector backend_buft; + for (auto *curr_backend: ctx->backends) { + if (ggml_backend_is_cpu(curr_backend)) { + // use host buffers for the CPU backend compute buffer + backend_buft.push_back(ggml_backend_cpu_buffer_type()); + } else { + backend_buft.push_back(ggml_backend_get_default_buffer_type(curr_backend)); + } + } + + + std::vector>> old_buffs( + cgraph->n_nodes); + std::vector old_view_buffs(cgraph->n_nodes); + + + for (int i = 0; i < cgraph->n_nodes; i++) { + old_buffs[i].first = cgraph->nodes[i]->buffer; + + + for (auto &src: 
cgraph->nodes[i]->src) { + if (src == nullptr) { + break; + } + old_buffs[i].second.push_back(src->buffer); + + } + + auto *src = cgraph->nodes[i]->view_src; + if (src != nullptr) { + if (src->buffer->buft != nullptr) { + old_view_buffs[i] = src->buffer; + + } + } + } + + size_t n_srcs = 0; + + for (int i = 0; i < cgraph->n_nodes; i++) { + if (cgraph->nodes[i]->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) { + cgraph->nodes[i]->buffer = ggml_backend_mpi_buffer_unwrap(cgraph->nodes[i]->buffer); + } + + for (auto &src: cgraph->nodes[i]->src) { + if (src == nullptr) { + break; + } + if (src->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) { + n_srcs++; + src->buffer = ggml_backend_mpi_buffer_unwrap(src->buffer); + } + } + + auto *src = cgraph->nodes[i]->view_src; + if (src != nullptr) { + if (src->buffer->buft != nullptr) { + + if (src->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) { + n_srcs++; + src->buffer = ggml_backend_mpi_buffer_unwrap(src->buffer); + } + } + } + } + std::vector old_buffs_leaves; + for (int i = 0; i < cgraph->n_leafs; i++) { + old_buffs_leaves.push_back(cgraph->leafs[i]->buffer->buft); + if (cgraph->leafs[i]->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) { + cgraph->leafs[i]->buffer = ggml_backend_mpi_buffer_unwrap(cgraph->leafs[i]->buffer); + } + } + + // TODO exploding memory usage cause we replace the buffer with the wrapped buffer, + // but don't free the contexts, and then create new ones when we re-wrap + + + if (!ctx->remote) { + ggml_backend_sched_t sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), + (int) ctx->backends.size(), cgraph->n_nodes + cgraph->n_leafs + n_srcs, false); + + ggml_backend_sched_reserve(sched, cgraph); + ggml_backend_sched_graph_compute_async(sched, cgraph); + ggml_backend_sched_free(sched); + + } + + for (int i = 0; i < cgraph->n_nodes; i++) { + cgraph->nodes[i]->buffer = ggml_backend_mpi_wrap_buffer(cgraph->nodes[i]->buffer); + + ggml_backend_mpi_buffer_set_rank(cgraph->nodes[i]->buffer, ggml_backend_mpi_buffer_rank(old_buffs[i].first)); + + + for (int iter = 0; iter < GGML_MAX_SRC; iter++) { + auto* src_node = cgraph->nodes[i]->src[iter]; + if (src_node == nullptr) { + break; + } + + if (src_node->buffer->iface.get_name == ggml_backend_mpi_buffer_name) { + continue; + } + + src_node->buffer = ggml_backend_mpi_wrap_buffer(src_node->buffer); + + ggml_backend_mpi_buffer_set_rank(src_node->buffer, ggml_backend_mpi_buffer_rank(old_buffs[i].second[iter])); + } + if(cgraph->nodes[i]->view_src != nullptr && cgraph->nodes[i]->view_src->buffer->buft != nullptr) { + + if (old_view_buffs[i] != nullptr) { + if (old_view_buffs[i]->iface.get_name == ggml_backend_mpi_buffer_name) { + ggml_backend_mpi_buffer_set_rank(cgraph->nodes[i]->view_src->buffer, + ggml_backend_mpi_buffer_rank(old_view_buffs[i])); + } + } + } + + } + + + // FIXME check if this is correct or not (it's probably not) + for (int i = 0; i < cgraph->n_leafs; i++) { + cgraph->leafs[i]->buffer = ggml_backend_mpi_wrap_buffer(cgraph->leafs[i]->buffer); + ggml_backend_mpi_buffer_type_set_rank(cgraph->leafs[i]->buffer->buft, ctx->rank); + } + + return GGML_STATUS_SUCCESS; +} + + +static const char * ggml_backend_mpi_name(ggml_backend_t backend) { + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + return strdup(("MPI(Rank " + std::to_string(ggml_backend_mpi_rank(backend)) + ", local rank " + std::to_string(ggml_backend_mpi_local_rank(backend)) + ")").c_str()); +} + +static void 
ggml_backend_mpi_free(ggml_backend_t backend) { + auto * ctx = static_cast(backend->context); + + delete ctx; + + + delete backend; +} + +static ggml_backend_buffer_type_t ggml_backend_mpi_get_default_buffer_type(ggml_backend_t backend) { + auto * ctx = static_cast(backend->context); + if (ctx->backends.empty()) { + auto * buff = ggml_backend_mpi_wrap_buffer_type(ggml_backend_cpu_buffer_type()); + ggml_backend_mpi_buffer_type_set_rank(buff, ctx->rank); + return buff; + } + + auto * buff = ggml_backend_mpi_wrap_buffer_type(ctx->backends.back()->iface.get_default_buffer_type(ctx->backends.back())); + ggml_backend_mpi_buffer_type_set_rank(buff, ctx->rank); + return buff; +} + +GGML_CALL static bool ggml_backend_mpi_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { + switch (op->op) { + case GGML_OP_CPY: + return op->type != GGML_TYPE_IQ2_XXS && op->type != GGML_TYPE_IQ2_XS; // missing type_traits.from_float + case GGML_OP_MUL_MAT: + return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type; + default: + return true; + } + + GGML_UNUSED(backend); +} + + + +std::vector ggml_mpi_available_devices_internal() { + static bool has_init = false; + if (!has_init) { + ggml_mpi_backend_init(); + has_init = true; + } + std::vector devices; + int s; + MPI_Comm_size(MPI_COMM_WORLD, &s); + devices.resize(s); + for (int i = 0; i < s; i++) { + devices[i] = ggml_mpi_device{ + i, + ggml_mpi_init(), + ("MPI_COMM_WORLD:" + std::to_string(i)).c_str(), + 1 + }; + } + return devices; +} + + + +GGML_CALL bool ggml_backend_is_mpi(ggml_backend_t backend) { + return backend && backend->iface.get_name == ggml_backend_mpi_name; +} + + + + + + +GGML_CALL static ggml_backend_buffer_t ggml_backend_mpi_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + + auto* buffer = ggml_backend_mpi_wrap_buffer( + ggml_backend_buft_alloc_buffer(ggml_backend_mpi_buffer_type_unwrap(buft), size) + ); + + ggml_backend_mpi_buffer_copy_ctx_from_type(buft, buffer); + + return buffer; +} + +GGML_CALL static size_t ggml_backend_mpi_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return ggml_backend_buft_get_alignment(ggml_backend_mpi_buffer_type_unwrap(buft)); +} + +GGML_CALL static size_t ggml_backend_mpi_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { + return ggml_backend_buft_get_max_size(ggml_backend_mpi_buffer_type_unwrap(buft)); +} + +GGML_CALL static size_t ggml_backend_mpi_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) { + // Have to do this instead of calling ggml_backend_type_get_alloc_size because that signature doesn't have const on tensor + return ggml_backend_mpi_buffer_type_unwrap(buft)->iface.get_alloc_size(ggml_backend_mpi_buffer_type_unwrap(buft), tensor); +} + +GGML_CALL static bool ggml_backend_mpi_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) { + return backend != nullptr && ggml_backend_is_mpi(backend) && ggml_backend_mpi_buffer_type_rank(buft) == ggml_backend_mpi_rank(backend); +} + +GGML_CALL static bool ggml_backend_mpi_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + + return ggml_backend_mpi_buffer_type_rank(buft) == ggml_backend_mpi_buffer_type_local_rank(buft) && ggml_backend_buft_is_host(ggml_backend_mpi_buffer_type_unwrap(buft)); +} + + +static std::map cached_wrappers; + +static std::map cached_buffer_wrappers; + +static std::map cached_backends; + +GGML_CALL ggml_backend_buffer_type_t 
ggml_backend_mpi_wrap_buffer_type(ggml_backend_buffer_type_t buft) { + +// if (cached_wrappers.find(buft) != cached_wrappers.end()) { +// fprintf(stderr, "Returning cached buffer type with name %s\n", cached_wrappers[buft]->iface.get_name(cached_wrappers[buft])); +// +// auto * ret = new ggml_backend_buffer_type; +// *ret = *cached_wrappers[buft]; +// return ret; +// } + + ggml_backend_buffer_type_i ggml_backend_mpi_buffer_type_interface = { + /* .get_name = */ ggml_backend_mpi_buffer_type_name, + /* .alloc_buffer = */ ggml_backend_mpi_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_mpi_buffer_type_get_alignment, + /* .get_max_size = */ (buft->iface.get_max_size != nullptr ) ? ggml_backend_mpi_buffer_type_get_max_size : nullptr, + /* .get_alloc_size = */ (buft->iface.get_alloc_size != nullptr ) ? ggml_backend_mpi_buffer_type_get_alloc_size : nullptr, + /* .supports_backend = */ ggml_backend_mpi_buffer_type_supports_backend, + /* .is_host = */ (buft->iface.is_host != nullptr ) ? ggml_backend_mpi_buffer_type_is_host : nullptr, + }; + + + + auto* ggml_backend_wrapped_buffer_type = new ggml_backend_buffer_type { + /* .iface = */ ggml_backend_mpi_buffer_type_interface, + /* .context = */ new ggml_backend_mpi_buffer_type_context{ + /* .name = */ "MPI", + /* .wrapped_buffer_type = */ buft, + /* .ctx_mpi = */ ggml_mpi_init() + } + }; + + // Set rank to 0 as default + ggml_backend_mpi_buffer_type_set_rank(ggml_backend_wrapped_buffer_type, 0); + + cached_wrappers[buft] = ggml_backend_wrapped_buffer_type; + + return ggml_backend_wrapped_buffer_type; +} + + + +GGML_CALL static void * ggml_backend_mpi_buffer_get_base(ggml_backend_buffer_t buffer) { + auto * ctx = (ggml_backend_mpi_buffer_context *) buffer->context; + return ctx->wrapped_buffer->iface.get_base(ctx->wrapped_buffer); +} + +GGML_CALL static void ggml_backend_mpi_buffer_free_buffer(ggml_backend_buffer_t buffer) { + auto * ctx = (ggml_backend_mpi_buffer_context *) buffer->context; + return ctx->wrapped_buffer->iface.free_buffer(ctx->wrapped_buffer); +} + +GGML_CALL static void ggml_backend_mpi_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + auto * ctx = (ggml_backend_mpi_buffer_context *) buffer->context; + + if (ggml_backend_mpi_buffer_rank(buffer) != ggml_backend_mpi_buffer_local_rank(buffer)) { + return; + } + +// fprintf(stderr, "SETTING TENSOR WITHOUT MPI CALLS FOR %s (%s) AND TGT BUFFER %s\n", tensor->name, ggml_backend_buffer_name(tensor->buffer), ggml_backend_buffer_name(buffer)); + ctx->wrapped_buffer->iface.set_tensor(ctx->wrapped_buffer, tensor, data, offset, size); +} + +GGML_CALL static void ggml_backend_mpi_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + int rank = ggml_backend_mpi_buffer_local_rank(tensor->buffer); + + int src_rank = ggml_backend_mpi_buffer_rank(tensor->buffer); + + if (rank != src_rank) { + + ggml_mpi_tensor_recv(tensor, data, ggml_backend_mpi_buffer_rank(tensor->buffer), ggml_backend_mpi_buffer_get_comm(tensor->buffer)); + return; + } + + ggml_backend_mpi_buffer_unwrap(buffer)->iface.get_tensor(ggml_backend_mpi_buffer_unwrap(buffer), tensor, data, offset, size); +} + +GGML_CALL static bool ggml_backend_mpi_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) { + if (ggml_backend_mpi_buffer_rank(src->buffer) == ggml_backend_mpi_buffer_rank(dst->buffer)) { + return 
ggml_backend_mpi_buffer_unwrap(buffer)->iface.cpy_tensor(ggml_backend_mpi_buffer_unwrap(buffer), src, + dst); + } + + return true; +} + +GGML_CALL static void ggml_backend_mpi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + return ggml_backend_mpi_buffer_unwrap(buffer)->iface.clear(ggml_backend_mpi_buffer_unwrap(buffer), value); +} + +static struct ggml_backend_buffer_i mpi_backend_buffer_i = { + /* .get_name = */ ggml_backend_mpi_buffer_name, + /* .free_buffer = */ ggml_backend_mpi_buffer_free_buffer, + /* .get_base = */ ggml_backend_mpi_buffer_get_base, + /* .init_tensor = */ NULL, // no initialization required + /* .set_tensor = */ ggml_backend_mpi_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_mpi_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_mpi_buffer_cpy_tensor, + /* .clear = */ ggml_backend_mpi_buffer_clear, + /* .reset = */ NULL, +}; + +GGML_CALL ggml_backend_buffer_t ggml_backend_mpi_wrap_buffer(ggml_backend_buffer_t buf) { + +// if (cached_buffer_wrappers.find(buf) != cached_buffer_wrappers.end()) { +// fprintf(stderr, "Returning cached buffer with name %s\n", cached_buffer_wrappers[buf]->iface.get_name(cached_buffer_wrappers[buf])); +// auto * ret = new ggml_backend_buffer; +// *ret = *cached_buffer_wrappers[buf]; +// auto * ret_type = new ggml_backend_buffer_type; +// *ret_type = *ret->buft; +// ret->buft = ret_type; +// return ret; +// } + + + ggml_backend_buffer_type_t t = ggml_backend_mpi_wrap_buffer_type(buf->buft); + + auto *buffer = new ggml_backend_buffer { + /* .interface = */ mpi_backend_buffer_i, + /* .buft = */ t, + /* .context = */ new ggml_backend_mpi_buffer_context{ + buf, ggml_mpi_init()}, + /* .size = */ buf->size, + /* .usage = */ buf->usage + }; + + // Default to node 0 when wrapping buffers + ggml_backend_mpi_buffer_set_rank(buffer, 0); + + cached_buffer_wrappers[buf] = buffer; + + + + return buffer; +} + +bool ggml_backend_mpi_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) { +// int src_rank = ggml_backend_mpi_buffer_rank(src->buffer); +// int dst_rank = ggml_backend_mpi_buffer_rank(dst->buffer); +// +// auto * ctx = static_cast(backend->context); +// +// if (ctx->remote) { +// return true; +// } +// +// if (src_rank == dst_rank) { +//// src->buffer->iface.cpy_tensor(src->buffer, src, dst); +// return true; +// } +// +// if (src_rank == ggml_backend_mpi_local_rank(backend)) { +// ggml_mpi_tensor_send(src, dst_rank, ctx->comm); +// } else if (dst_rank == ggml_backend_mpi_local_rank(backend)){ +// ggml_mpi_tensor_recv(dst, src_rank, ctx->comm); +// } +// fprintf(stderr, "ATTEMPTING ASYNC COPY FOR SRC TENSOR %s TO DST TENSOR %s WITH SRC BACKEND %s AND DST BACKEND %s\n", src->name, dst->name, ggml_backend_name(backend_src), ggml_backend_name(backend_dst)); + return false; + +} + +void ggml_backend_mpi_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * dst, const void* data, size_t offset, size_t size) { + int dst_rank = ggml_backend_mpi_buffer_rank(dst->buffer); + + + auto * ctx = static_cast(backend->context); + + GGML_ASSERT(ctx->rank == dst_rank); + + ggml_mpi_tensor_send(dst, data, ctx->rank, ctx->comm); + + +} + + +ggml_backend_t ggml_backend_mpi_init(ggml_backend_t * wrapped_backends, size_t num_backends, int rank) { + + static ggml_guid backend_mpi_guid = {0xec, 0x39, 0xce, 0x40, 0xc3, 0x43, 0x49, 0x36, 0x96, 0x03, 0x55, 0x77, 0x5c, 0x1f, 0x44, 0xd3}; + + + ggml_mpi_context * ctx = ggml_mpi_init(); + std::vector wrapped_backends_v; 
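+    // Only the process whose MPI rank matches the requested rank keeps the wrapped local backends;
+    // all other processes mark this backend as remote so graph computation is skipped for it.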
+ if (ctx->rank == rank) { + for (size_t i = 0; i < num_backends; i++) { + wrapped_backends_v.push_back(wrapped_backends[i]); + } + } else { + ctx->remote = true; + } + ctx->backends = wrapped_backends_v; + ctx->rank = rank; + struct ggml_backend_i mpi_backend_i = { + /* .get_name = */ ggml_backend_mpi_name, + /* .free = */ ggml_backend_mpi_free, + /* .get_default_buffer_type = */ ggml_backend_mpi_get_default_buffer_type, + /* .set_tensor_async = */ ggml_backend_mpi_set_tensor_async, + /* .get_tensor_async = */ NULL, + /* .cpy_tensor_async = */ ggml_backend_mpi_cpy_tensor_async, + /* .synchronize = */ NULL, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_mpi_graph_compute, + /* .supports_op = */ ggml_backend_mpi_supports_op, + }; + + auto *mpi_backend = new ggml_backend { + /* .guid = */ &backend_mpi_guid, + /* .interface = */ mpi_backend_i, + /* .context = */ ctx, + }; + + cached_backends[wrapped_backends] = mpi_backend; + + return mpi_backend; +} + +static ggml_backend_t ggml_backend_reg_mpi_init(const char * params, void * user_data) { + // TODO check what the parameters are for. Could use it to setup the MPI comms and routes? + GGML_UNUSED(params); + ggml_mpi_backend_init(); + auto * v = new std::vector(); + v->push_back(ggml_backend_cpu_init()); + return ggml_backend_mpi_init(v->data(), 1, 0); +} + + + + +extern "C" GGML_CALL int ggml_backend_mpi_reg_devices(); + +int ggml_backend_mpi_reg_devices() { + auto devices = ggml_mpi_available_devices_internal(); + for (const auto & device : devices) { + ggml_backend_register( + device.name, + ggml_backend_reg_mpi_init, + ggml_backend_mpi_wrap_buffer_type(ggml_backend_cpu_buffer_type()), + reinterpret_cast(intptr_t(device.index)) + ); + } + return devices.size(); +} + + + +GGML_CALL void ggml_backend_mpi_buffer_type_set_rank(ggml_backend_buffer_type_t buft, int rank) { + if (buft->iface.get_name == ggml_backend_mpi_buffer_type_name) { + ((ggml_backend_mpi_buffer_type_context *) buft->context)->ctx_mpi->rank = rank; + } else { + GGML_ASSERT(!"Buffer type must be wrapped in ggml_backend_mpi_buffer_type"); + } +} + +GGML_CALL void ggml_backend_mpi_buffer_set_rank(ggml_backend_buffer_t buf, int rank) { + if (buf->iface.get_name == ggml_backend_mpi_buffer_name) { + ((ggml_backend_mpi_buffer_context *) buf->context)->ctx_mpi->rank = rank; + ggml_backend_mpi_buffer_type_set_rank(buf->buft, rank); + } else { + GGML_ASSERT(!"Buffer type must be wrapped in ggml_backend_mpi_buffer_type"); + } +} + + diff --git a/ggml-mpi.h b/ggml-mpi.h index eda119d44..d988a81e4 100644 --- a/ggml-mpi.h +++ b/ggml-mpi.h @@ -1,4 +1,8 @@ #pragma once +#include +#include +#include "ggml.h" +#include "ggml-backend.h" struct ggml_context; struct ggml_tensor; @@ -8,31 +12,238 @@ struct ggml_cgraph; extern "C" { #endif +#define GGML_MPI_DECODE 0 + +#define GGML_MPI_KV_CLEAR 1 + +#define GGML_MPI_KV_SEQ_RM 2 + +#define GGML_MPI_KV_SEQ_CP 3 + +#define GGML_MPI_KV_SEQ_KEEP 4 + +#define GGML_MPI_KV_SEQ_ADD 5 + +#define GGML_MPI_SHUTDOWN 6 + +#define GGML_MPI_TRANSFER_TENSORS 7 + +#define GGML_MPI_SYNC_LOGITS 8 + +#define GGML_MPI_CANCEL_RUN 9 + +#define GGML_MPI_KV_SEQ_CP_BACK 10 + +#define GGML_MPI_TRANS_ID 11 + +#define GGML_MPI_BATCH_ID 12 + +#define GGML_MPI_N_TOKENS 13 + +#define GGML_MPI_TOKENS 14 + +#define GGML_MPI_N_SEQ_IDS 15 + +#define GGML_MPI_SEQ_IDS 16 + +#define GGML_MPI_POS 17 + +#define GGML_MPI_BEGIN_TRANSACTION 18 + +#define GGML_MPI_MAX_N_SEQ 19 + +#define 
GGML_MPI_BATCH_LOGITS 20
+
+#define GGML_MPI_KV_SEQ_DIV 21
+
+
+
+/**
+ * The context used for MPI operations;
+ * a program may make use of more than one
+ * context but must always have at least one.
+ *
+ * The context stores required information like the
+ * node rank and a communicator to use for MPI operations.
+ * A context is guaranteed to be internally consistent,
+ * meaning that a context's stored rank is valid within
+ * the context's communicator.
+ */
 struct ggml_mpi_context;
+
+/**
+ * Initialize the MPI library and the GGML MPI backend.
+ * Calling it more than once during the lifetime of the program
+ * leads to undefined behavior. This function must be called before
+ * any MPI operations.
+ */
 void ggml_mpi_backend_init(void);
+
+/**
+ * Frees the MPI backend; it must be called only once, at termination
+ * of the program. No MPI operations may be performed after calling this function,
+ * and attempting to do so will lead to undefined behavior.
+ */
 void ggml_mpi_backend_free(void);
+/**
+ * Construct a new MPI context using the MPI_COMM_WORLD
+ * communicator. This is useful only to create the
+ * initial context, as calling it multiple times
+ * will only create effective copies of the same data.
+ *
+ * @return A context for use in the global communicator.
+ */
 struct ggml_mpi_context * ggml_mpi_init(void);
+
+GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_mpi_wrap_buffer_type(ggml_backend_buffer_type_t buft);
+
+GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_mpi_wrap_buffer(ggml_backend_buffer_t buf);
+
+
+void ggml_mpi_sync_ints_pipelined(
+        struct ggml_mpi_context * ctx_mpi,
+        int32_t * vals,
+        int count,
+        int tag
+);
+
+void ggml_mpi_sync_ints_pipelined_back(
+        struct ggml_mpi_context * ctx_mpi,
+        int32_t * vals,
+        int count,
+        int tag
+);
+// clear = 1, rm = 2, cp = 3, keep = 4, seq_shift = 5
+void ggml_mpi_probe(struct ggml_mpi_context * ctx_mpi, int src, int tag);
+int ggml_mpi_status_tag(struct ggml_mpi_context * ctx_mpi);
+
+int ggml_mpi_iprobe(struct ggml_mpi_context * ctx_mpi, int src, int tag);
+int ggml_mpi_status_count_int32(struct ggml_mpi_context * ctx_mpi);
+
+/**
+ * Create a new context by splitting the given context's
+ * communicator, creating a "sub-communicator." This is a collective
+ * operation and must be performed by all nodes within the same communicator.
+ * The color and key have the same meaning as in MPI_Comm_split(), i.e.
+ * the color is used to determine the sub-communicator this node will belong to,
+ * and the key is the relative rank of this node in the new communicator.
+ *
+ * An example: if a node passes a color of 1, and a different node passes a color of 2,
+ * the nodes will belong to two different sub-communicators. If two nodes pass the same
+ * color, then their ranks will be ordered by the order of their keys. If they pass the same
+ * key, then the tie will be broken by the nodes' ranks in the old communicator.
+ *
+ * The communicator used by the given context remains entirely valid, so it is advisable
+ * to store both the old and new contexts. This allows an application to
+ * select at runtime which communicator to perform MPI operations with. An example
+ * would be to segregate the nodes into multiple domains categorized by the functions
+ * they perform, and use the original context to broadcast to all nodes in the cluster.
+ *
+ * @param ctx The context containing the communicator to split.
+ * @param color The sub-communicator that this node will belong to.
+ * @param key The relative rank of this node in the new communicator.
+ * @return A new context with all values referencing the newly-created communicator.
+ */
+struct ggml_mpi_context * ggml_mpi_split_comm(struct ggml_mpi_context * ctx, int color, int key);
+
+/**
+ * Frees the given context, including the communicator. No MPI
+ * operations besides ggml_mpi_backend_free(void) should be executed after
+ * running this function.
+ *
+ * @param ctx The context to free.
+ */
 void ggml_mpi_free(struct ggml_mpi_context * ctx);
+/**
+ * Get the rank of this node in the given context's communicator.
+ *
+ * @param ctx The context whose communicator is used to determine the rank.
+ * @return The rank of this node.
+ */
 int ggml_mpi_rank(struct ggml_mpi_context * ctx);
+/**
+ * Get the number of nodes that are a part of
+ * the communicator referenced by the given context.
+ *
+ * @param ctx The context containing the communicator used for this size check.
+ * @return The number of nodes that are a part of the given context's communicator.
+ */
+size_t ggml_mpi_size(struct ggml_mpi_context * ctx);
+
+/**
+ * Synchronize the information needed among the nodes
+ * to prepare for running an evaluation iteration.
+ * This is a collective operation and all nodes must
+ * call this function. It will block until all
+ * nodes have entered it, to prevent any desync
+ * between nodes.
+ *
+ * @param ctx_mpi The context in which to prepare for evaluation.
+ * @param n_tokens A pointer to the token count, which will be synchronized by this function.
+ * @param pos A pointer to the pos array, which will be synchronized by this function.
+ * @param n_seq_ids A pointer to the n_seq_ids array, which will be synchronized by this function.
+ * @param seq_id A pointer to the seq_id 2D array, which will be synchronized by this function.
+ * @param logits A pointer to the logits array, which is currently unused since only node 0 needs them.
+ */
 void ggml_mpi_eval_init(
-        struct ggml_mpi_context * ctx_mpi,
-        int * n_tokens,
-        int * n_past,
-        int * n_threads);
+        struct ggml_mpi_context * ctx_mpi,
+        int32_t * n_tokens,
+        int32_t ** pos,
+        int32_t ** n_seq_ids,
+        int32_t *** seq_id,
+        int8_t ** logits,
+        uint32_t n_seq_max);
-void ggml_mpi_graph_compute_pre(
-        struct ggml_mpi_context * ctx_mpi,
-        struct ggml_cgraph * gf,
-        int n_layers);
+void ggml_mpi_sync_int(
+        struct ggml_mpi_context * ctx_mpi,
+        int32_t * val
+        );
-void ggml_mpi_graph_compute_post(
-        struct ggml_mpi_context * ctx_mpi,
-        struct ggml_cgraph * gf,
-        int n_layers);
+/**
+ * Split a range across all nodes within the given
+ * context, weighting the allocations by the given weights.
+ * The dimensions of the returned 2d array are (number of nodes in the context, 2).
+ * The first element in the inner array is the starting point of the range allocated
+ * to the node indicated by the index into the outer array,
+ * and the second element is the end point of the allocated range, inclusive.
+ *
+ * @param ctx_mpi The context used to determine the number of nodes
+ *                to split the range across.
+ * @param start The starting point of the range.
+ * @param end The end point of the range, inclusive.
+ * @param node_weights How to weight the allocations across the nodes,
+ *                     must sum to 1.0.
+ * @return A 2d array, the first dimension is the number of nodes in the context
+ *         and the second dimension is 2.
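+ *
+ * A possible call pattern (illustrative sketch only; the two-node weighting and the
+ * 32-layer range are hypothetical values, not taken from this patch):
+ *
+ *     const float node_weights[2] = {0.75f, 0.25f};  // must sum to 1.0
+ *     uint16_t ** ranges = ggml_mpi_split_range(ctx_mpi, 0, 31, node_weights);
+ *     // ranges[i][0] .. ranges[i][1] (inclusive) is the slice of layers assigned to node i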
+ */ +uint16_t** ggml_mpi_split_range( + struct ggml_mpi_context * ctx_mpi, + uint16_t start, + uint16_t end, + const float node_weights[] +); + +// BACKEND V2 + +struct ggml_mpi_device { + int index; + struct ggml_mpi_context * ctx_mpi; + const char * name; + int subgroupSize; +}; + +#define MPI_BACKEND_NAME "MPI" +GGML_CALL int ggml_backend_mpi_reg_devices(); + +GGML_CALL ggml_backend_t ggml_backend_mpi_init(ggml_backend_t * wrapped_backends, size_t num_backends, int rank); + +GGML_CALL void ggml_backend_mpi_buffer_type_set_rank(ggml_backend_buffer_type_t buft, int rank); + +GGML_CALL void ggml_backend_mpi_buffer_set_rank(ggml_backend_buffer_t buft, int rank); #ifdef __cplusplus } diff --git a/ggml.h b/ggml.h index c937d4a53..5d6c9e4ab 100644 --- a/ggml.h +++ b/ggml.h @@ -226,7 +226,7 @@ #define GGML_MAX_DIMS 4 #define GGML_MAX_PARAMS 2048 -#define GGML_MAX_CONTEXTS 64 +#define GGML_MAX_CONTEXTS 256 #define GGML_MAX_SRC 10 #ifndef GGML_MAX_NAME #define GGML_MAX_NAME 64 @@ -381,6 +381,7 @@ extern "C" { GGML_BACKEND_TYPE_CPU = 0, GGML_BACKEND_TYPE_GPU = 10, GGML_BACKEND_TYPE_GPU_SPLIT = 20, + GGML_BACKEND_TYPE_MPI_SPLIT = 30, }; // model file types diff --git a/llama.cpp b/llama.cpp index 91bd6b8d0..66414d2bb 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1118,6 +1118,10 @@ struct llama_mmap { int flags = MAP_SHARED; // prefetch/readahead impairs performance on NUMA systems if (numa) { prefetch = 0; } + +#ifdef GGML_USE_MPI + prefetch = 0; +#endif #ifdef __linux__ // advise the kernel to read the file sequentially (increases readahead) if (posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL)) { @@ -1126,6 +1130,7 @@ struct llama_mmap { } if (prefetch) { flags |= MAP_POPULATE; } #endif + addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0); if (addr == MAP_FAILED) { // NOLINT throw std::runtime_error(format("mmap failed: %s", strerror(errno))); @@ -1487,6 +1492,10 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer if (buft == nullptr) { buft = ggml_backend_cpu_buffer_type(); } + +#if defined(GGML_USE_MPI) + buft = ggml_backend_mpi_wrap_buffer_type(buft); +#endif return buft; GGML_UNUSED(host_buffer); @@ -1538,6 +1547,8 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_split(int fallback_g if (buft == nullptr) { buft = llama_default_buffer_type_offload(fallback_gpu); } + + return buft; GGML_UNUSED(tensor_split); @@ -1761,6 +1772,7 @@ struct llama_cparams { ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; + uint32_t n_seq_max; }; struct llama_layer { @@ -2036,6 +2048,10 @@ struct llama_model { int64_t t_load_us = 0; int64_t t_start_us = 0; +#ifdef GGML_USE_MPI + ggml_mpi_context * ctx_mpi = nullptr; +#endif + ~llama_model() { for (struct ggml_context * ctx : ctxs) { ggml_free(ctx); @@ -2139,12 +2155,9 @@ struct llama_context { struct ggml_tensor * inp_s_mask; // F32 [1, kv_size] struct ggml_tensor * inp_s_seq; // I32 [kv_size, n_batch] + // control vectors struct llama_control_vector cvec; - -#ifdef GGML_USE_MPI - ggml_mpi_context * ctx_mpi = NULL; -#endif }; // @@ -2218,7 +2231,7 @@ static bool llama_kv_cache_init( }; ggml_context * ctx = ggml_init(params); if (!ctx) { - LLAMA_LOG_ERROR("%s: failed to allocate context for kv cache\n", __func__); + LLAMA_LOG_ERROR("%s: failed to allocate context for kv cache, n_layers=%d\n", __func__, n_layers); return false; } ctx_map[it.first] = ctx; @@ -3321,6 +3334,11 @@ static void llm_load_hparams( auto & hparams = model.hparams; const gguf_context * ctx = ml.ctx_gguf; +#ifdef GGML_USE_MPI + 
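+    // In MPI builds the model owns its MPI context; it is created here, as soon as the
+    // hyperparameters are read, so that later stages (layer placement in llm_load_tensors,
+    // decode-time synchronization) can query rank and world size through model.ctx_mpi.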
model.ctx_mpi = ggml_mpi_init(); + +#endif + // get metadata as string for (int i = 0; i < gguf_get_n_kv(ctx); i++) { enum gguf_type type = gguf_get_kv_type(ctx, i); @@ -4062,6 +4080,7 @@ static bool llm_load_tensors( enum llama_split_mode split_mode, int main_gpu, const float * tensor_split, + const float * node_split, bool use_mlock, llama_progress_callback progress_callback, void * progress_callback_user_data) { @@ -4150,6 +4169,32 @@ static bool llm_load_tensors( } } +#ifdef GGML_USE_MPI + uint16_t** ranges = ggml_mpi_split_range(model.ctx_mpi, 0, n_layer - 1, node_split); + + + size_t size = ggml_mpi_size(model.ctx_mpi); + + for (size_t i = 0; i < size; i++) { + for (uint16_t j = ranges[i][0]; j < ranges[i][1]; j++) { + printf("Setting buffer rank for i %zu and j %d\n", i, j); + ggml_backend_mpi_buffer_type_set_rank(model.buft_layer[j].buft, (int)i); + ggml_backend_mpi_buffer_type_set_rank(model.buft_layer[j].buft_matrix, (int)i); + } + } + + + // Will run with inputs on other nodes, but output may not be correct. + // Default is node 0 anyway, but better to be explicit about it + ggml_backend_mpi_buffer_type_set_rank(model.buft_input.buft, 0); + ggml_backend_mpi_buffer_type_set_rank(model.buft_input.buft_matrix, 0); + + + // Outputs *must* be on node 0, otherwise a deadlock occurs + ggml_backend_mpi_buffer_type_set_rank(model.buft_output.buft, 0); + ggml_backend_mpi_buffer_type_set_rank(model.buft_output.buft_matrix, 0); +#endif + // count used buffer types std::map buft_layer_count; buft_layer_count[model.buft_input.buft]++; @@ -5041,6 +5086,7 @@ static bool llm_load_tensors( size_t first, last; ml.get_mapping_range(&first, &last, ctx); buf = ggml_backend_cpu_buffer_from_ptr((char *) ml.mapping->addr + first, last - first); + #ifdef GGML_USE_CUBLAS if (n_layer >= n_gpu_layers) { ggml_backend_cuda_register_host_buffer( @@ -5066,9 +5112,14 @@ static bool llm_load_tensors( mlock_buf->grow_to(ggml_backend_buffer_get_size(buf)); } } + if (buf == nullptr) { throw std::runtime_error("failed to allocate buffer"); } + + #ifdef GGML_USE_MPI + buf = ggml_backend_mpi_wrap_buffer(buf); + #endif // indicate that this buffer contains weights // this is used by ggml_backend_sched to improve op scheduling -> ops that use a weight are preferably scheduled to the backend that contains the weight ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); @@ -5181,7 +5232,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam #endif if (!llm_load_tensors( - ml, model, params.n_gpu_layers, params.split_mode, params.main_gpu, params.tensor_split, params.use_mlock, + ml, model, params.n_gpu_layers, params.split_mode, params.main_gpu, params.tensor_split, params.node_layer_weights, params.use_mlock, params.progress_callback, params.progress_callback_user_data )) { return -2; @@ -5840,6 +5891,7 @@ struct llm_build_context { struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); for (int il = 0; il < n_layer; ++il) { + ggml_format_name(inpL, "layer_inp_%d", il); //MPI struct ggml_tensor * inpSA = inpL; // norm @@ -6024,6 +6076,7 @@ struct llm_build_context { struct ggml_tensor * KQ_pos = build_inp_KQ_pos(); for (int il = 0; il < n_layer; ++il) { + ggml_format_name(inpL, "layer_inp_%d", il); //MPI struct ggml_tensor * inpSA = inpL; cur = llm_build_norm(ctx0, inpL, hparams, @@ -6132,6 +6185,7 @@ struct llm_build_context { struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); for (int il = 0; il < n_layer; ++il) { + ggml_format_name(inpL, "layer_inp_%d", il); //MPI struct 
ggml_tensor * attn_norm; attn_norm = llm_build_norm(ctx0, inpL, hparams, @@ -6250,6 +6304,7 @@ struct llm_build_context { cb(inpL, "inpL", -1); for (int il = 0; il < n_layer; ++il) { + ggml_format_name(inpL, "layer_inp_%d", il); //MPI cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, model.layers[il].attn_norm_b, @@ -6337,6 +6392,7 @@ struct llm_build_context { struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); for (int il = 0; il < n_layer; ++il) { + ggml_format_name(inpL, "layer_inp_%d", il); //MPI struct ggml_tensor * residual = inpL; cur = llm_build_norm(ctx0, inpL, hparams, @@ -6536,6 +6592,7 @@ struct llm_build_context { struct ggml_tensor * KQ_pos = build_inp_KQ_pos(); for (int il = 0; il < n_layer; ++il) { + ggml_format_name(inpL, "layer_inp_%d", il); //MPI struct ggml_tensor * inpSA = inpL; cur = llm_build_norm(ctx0, inpL, hparams, @@ -6815,6 +6872,7 @@ struct llm_build_context { cb(inpL, "inp_norm", -1); for (int il = 0; il < n_layer; ++il) { + ggml_format_name(inpL, "layer_inp_%d", il); //MPI cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, model.layers[il].attn_norm_b, @@ -6902,6 +6960,7 @@ struct llm_build_context { struct ggml_tensor * KQ_pos = build_inp_KQ_pos(); for (int il = 0; il < n_layer; ++il) { + ggml_format_name(inpL, "layer_inp_%d", il); //MPI struct ggml_tensor * attn_norm; attn_norm = llm_build_norm(ctx0, inpL, hparams, @@ -8975,10 +9034,7 @@ static void llama_graph_compute( llama_context & lctx, ggml_cgraph * gf, int n_threads) { -#ifdef GGML_USE_MPI - const int64_t n_layer = lctx.model.hparams.n_layer; - ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer); -#endif + #ifdef GGML_USE_METAL if (ggml_backend_is_metal(lctx.backend_metal)) { @@ -8995,9 +9051,6 @@ static void llama_graph_compute( // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched)); -#ifdef GGML_USE_MPI - ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer); -#endif } // decode a batch of tokens by evaluating the transformer @@ -9011,9 +9064,62 @@ static void llama_graph_compute( // static int llama_decode_internal( llama_context & lctx, - llama_batch batch_all) { // TODO: rename back to batch + llama_batch & batch_all) { // TODO: rename back to batch - const uint32_t n_tokens_all = batch_all.n_tokens; + +#ifdef GGML_USE_MPI + if (ggml_mpi_rank(lctx.model.ctx_mpi) == 0 && ggml_mpi_size(lctx.model.ctx_mpi) > 1) { + int transaction_type = GGML_MPI_DECODE; + ggml_mpi_sync_ints_pipelined(lctx.model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION); + } +// ggml_mpi_sync_ints_pipelined(lctx.model.ctx_mpi, &batch_all.batch_id, 1, GGML_MPI_BATCH_ID); + int old_tokens = batch_all.n_tokens; + + ggml_mpi_sync_ints_pipelined(lctx.model.ctx_mpi, &batch_all.n_tokens, 1, GGML_MPI_N_TOKENS); + + ggml_mpi_sync_ints_pipelined(lctx.model.ctx_mpi, reinterpret_cast(&lctx.cparams.n_seq_max), 1, GGML_MPI_MAX_N_SEQ); + if (ggml_mpi_rank(lctx.model.ctx_mpi) > 0) { + int new_n_tokens = batch_all.n_tokens; + llama_batch_free(batch_all); + batch_all = llama_batch_init(new_n_tokens, 0, (int32_t)lctx.cparams.n_seq_max); + } +#endif + + uint32_t n_tokens_all = batch_all.n_tokens; + + std::vector pos; + std::vector n_seq_id; + std::vector seq_id_arr; + std::vector> seq_id; + + if (batch_all.pos == nullptr) { + pos.resize(n_tokens_all); + for (uint32_t i = 0; i < n_tokens_all; i++) { + pos[i] = batch_all.all_pos_0 + i*batch_all.all_pos_1; + } + + batch_all.pos = pos.data(); + } + + if (batch_all.seq_id == nullptr) { + n_seq_id.resize(n_tokens_all); + 
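+        // No sequence ids were supplied with the batch: give every token the batch-wide
+        // default id, and size each per-token list to n_seq_max (rather than 1) to match
+        // the same change made for the per-ubatch path further down in this function.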
seq_id.resize(n_tokens_all); + seq_id_arr.resize(n_tokens_all); + for (uint32_t i = 0; i < n_tokens_all; i++) { + n_seq_id[i] = 1; + seq_id[i].resize(lctx.cparams.n_seq_max); + seq_id[i][0] = batch_all.all_seq_id; + seq_id_arr[i] = seq_id[i].data(); + } + + batch_all.n_seq_id = n_seq_id.data(); + batch_all.seq_id = seq_id_arr.data(); + } + +#ifdef GGML_USE_MPI + ggml_mpi_eval_init(lctx.model.ctx_mpi, &(batch_all.n_tokens), &(batch_all.pos), &(batch_all.n_seq_id), &(batch_all.seq_id), &(batch_all.logits), lctx.cparams.n_seq_max); + n_tokens_all = batch_all.n_tokens; +#endif if (n_tokens_all == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__); @@ -9035,11 +9141,7 @@ static int llama_decode_internal( } lctx.n_queued_tokens += n_tokens_all; -#ifdef GGML_USE_MPI - // TODO: needs fix after #3228 - GGML_ASSERT(false && "not implemented"); - //ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads); -#endif + auto & kv_self = lctx.kv_self; @@ -9059,13 +9161,9 @@ static int llama_decode_internal( const auto n_ubatch = cparams.n_ubatch; - std::vector pos; - std::vector n_seq_id; - std::vector seq_id_arr; - std::vector> seq_id; for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) { - const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token); + uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token); llama_batch u_batch = { /* .n_tokens = */ (int32_t) n_tokens, /* .token = */ batch_all.token ? batch_all.token + cur_token : nullptr, @@ -9099,7 +9197,7 @@ static int llama_decode_internal( seq_id_arr.resize(n_tokens); for (uint32_t i = 0; i < n_tokens; i++) { n_seq_id[i] = 1; - seq_id[i].resize(1); + seq_id[i].resize(lctx.cparams.n_seq_max); seq_id[i][0] = u_batch.all_seq_id; seq_id_arr[i] = seq_id[i].data(); } @@ -9200,11 +9298,17 @@ static int llama_decode_internal( // ggml_graph_dump_dot(gf, NULL, "llama.dot"); //} +#ifdef GGML_USE_MPI + if (ggml_mpi_rank(lctx.model.ctx_mpi) == 0) { +#endif + // extract logits // TODO: do not compute and extract logits if only embeddings are needed // update the graphs to skip "result_output" if logits are not needed if (res) { - ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res); + + + ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res); GGML_ASSERT(backend_res != nullptr); if (u_batch.logits) { int32_t i_first = -1; @@ -9288,6 +9392,10 @@ static int llama_decode_internal( } break; } } + +#ifdef GGML_USE_MPI + } +#endif } // wait for the computation to finish (automatically done when obtaining the model output) @@ -9305,6 +9413,8 @@ static int llama_decode_internal( } } + + return 0; } @@ -12824,6 +12934,7 @@ static int llama_apply_lora_from_file_internal( // struct llama_model_params llama_model_default_params() { struct llama_model_params result = { + static_cast(calloc(1, sizeof(float))), /*.n_gpu_layers =*/ 0, /*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER, /*.main_gpu =*/ 0, @@ -12891,6 +13002,14 @@ struct llama_model_quantize_params llama_model_quantize_default_params() { return result; } +int llama_node_id(struct llama_context * ctx) { +#ifdef GGML_USE_MPI + return ggml_mpi_rank(ctx->model.ctx_mpi); + +#endif + return 0; +} + size_t llama_max_devices(void) { #if defined(GGML_USE_METAL) return 1; @@ -12936,6 +13055,7 @@ void llama_backend_init(void) { #ifdef GGML_USE_MPI ggml_mpi_backend_init(); #endif + } void llama_numa_init(enum ggml_numa_strategy numa) { @@ -13022,7 +13142,7 @@ struct llama_context * llama_new_context_with_model( const auto 
& hparams = model->hparams; auto & cparams = ctx->cparams; - // TODO: maybe add n_seq_max here too + cparams.n_seq_max = params.n_seq_max; cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor; @@ -13192,6 +13312,8 @@ struct llama_context * llama_new_context_with_model( ctx->backends.push_back(backend); } #endif + + ctx->backend_cpu = ggml_backend_cpu_init(); if (ctx->backend_cpu == nullptr) { LLAMA_LOG_ERROR("%s: failed to initialize CPU backend\n", __func__); @@ -13200,6 +13322,23 @@ struct llama_context * llama_new_context_with_model( } ctx->backends.push_back(ctx->backend_cpu); +#ifdef GGML_USE_MPI + + std::vector new_backends; + + for (size_t i = 0; i < ggml_mpi_size(model->ctx_mpi); i++) { + new_backends.push_back(ggml_backend_mpi_init(ctx->backends.data(), ctx->backends.size(), (int) i)); + } + + ctx->backends = new_backends; + + + +// ctx->backend_cpu = ctx->backends.back(); + ctx->backends.push_back(ggml_backend_mpi_init(&ctx->backend_cpu, 1, ggml_mpi_rank(model->ctx_mpi))); + +#endif + if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, kv_size, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); @@ -13311,23 +13450,15 @@ struct llama_context * llama_new_context_with_model( } } -#ifdef GGML_USE_MPI - ctx->ctx_mpi = ggml_mpi_init(); - if (ggml_mpi_rank(ctx->ctx_mpi) > 0) { - // Enter a blocking eval loop with dummy input, letting rank=0 drive the process - // TODO: needs fix after #3228 - GGML_ASSERT(false && "not implemented"); - //const std::vector tmp(ctx->model.hparams.n_ctx, llama_token_bos(ctx)); - //while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)) {}; - llama_backend_free(); - exit(1); - } -#endif return ctx; } +void llama_split_layers_weighted(struct llama_context * ctx, float device_weights[], size_t num_weights) { + +} + void llama_free(struct llama_context * ctx) { delete ctx; } @@ -13711,14 +13842,56 @@ int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx) { } void llama_kv_cache_clear(struct llama_context * ctx) { +#ifdef GGML_USE_MPI + if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) { + int transaction_type = GGML_MPI_KV_CLEAR; + ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION); + } + ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, nullptr, 0, GGML_MPI_KV_CLEAR); +#endif llama_kv_cache_clear(ctx->kv_self); } bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { +#ifdef GGML_USE_MPI + if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) { + int transaction_type = GGML_MPI_KV_SEQ_RM; + ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION); + } + int32_t vals[3] = {seq_id, p0, p1}; + ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, vals, 3, GGML_MPI_KV_SEQ_RM); + seq_id = vals[0]; + p0 = vals[1]; + p1 = vals[2]; +// if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) { +// printf("\nRemoving sequence %d from %d to %d\n", seq_id, p0, p1); +// } +#endif return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1); } void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { +#ifdef GGML_USE_MPI + if 
(ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) { + int transaction_type = GGML_MPI_KV_SEQ_CP; + ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION); + } + + int32_t vals[4] = {seq_id_src, seq_id_dst, p0, p1}; + ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, vals, 4, GGML_MPI_KV_SEQ_CP); +// if(ggml_mpi_recv_trans_id(ctx->model.ctx_mpi) < ggml_mpi_trans_id(ctx->model.ctx_mpi)) { +//// return; +// } +// ggml_mpi_inc_trans_id(ctx->model.ctx_mpi); + seq_id_src = vals[0]; + seq_id_dst = vals[1]; + p0 = vals[2]; + p1 = vals[3]; +// if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) { +// printf("\nCopying sequence %d to sequence %d from %d to %d\n", seq_id_src, seq_id_dst, p0, p1); +// } +#endif + if (seq_id_src == seq_id_dst) { return; } @@ -13726,10 +13899,35 @@ void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, } void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { +#ifdef GGML_USE_MPI + if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) { + int transaction_type = GGML_MPI_KV_SEQ_KEEP; + ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION); + } + int32_t vals[1] = {seq_id}; + ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, vals, 1, GGML_MPI_KV_SEQ_KEEP); + seq_id = vals[0]; +#endif llama_kv_cache_seq_keep(ctx->kv_self, seq_id); } void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { +#ifdef GGML_USE_MPI + if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) { + int transaction_type = GGML_MPI_KV_SEQ_ADD; + ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION); + } + int32_t vals[4] = {seq_id, p0, p1, delta}; + ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, vals, 4, GGML_MPI_KV_SEQ_ADD); + seq_id = vals[0]; + p0 = vals[1]; + p1 = vals[2]; + delta = vals[3]; +// if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) { +// printf("\nRemoving sequence %d from %d to %d\n", seq_id, p0, p1); +// } +#endif + if (delta == 0) { return; } @@ -13738,6 +13936,22 @@ void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, lla } void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { +#ifdef GGML_USE_MPI + if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) { + int transaction_type = GGML_MPI_KV_SEQ_DIV; + ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION); + } + int32_t vals[4] = {seq_id, p0, p1, d}; + ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, vals, 4, GGML_MPI_KV_SEQ_DIV); + seq_id = vals[0]; + p0 = vals[1]; + p1 = vals[2]; + d = vals[3]; +// if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) { +// printf("\nRemoving sequence %d from %d to %d\n", seq_id, p0, p1); +// } +#endif + if (d == 1) { return; } @@ -14244,11 +14458,111 @@ void llama_batch_free(struct llama_batch batch) { free(batch.seq_id); } if (batch.logits) free(batch.logits); + + batch.token = nullptr; + batch.embd = nullptr; + batch.pos = nullptr; + batch.n_seq_id = nullptr; + batch.seq_id = nullptr; + batch.logits 
= nullptr; } +#ifdef GGML_USE_MPI + +int llama_process_mpi_transaction( + struct llama_context * ctx, + struct llama_batch & batch, + int tag) { +// if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1) { +// printf("\nBeginning transaction type %d\n", tag); +// } + + switch (tag) { + case GGML_MPI_DECODE: +// llama_batch_free(batch); + return llama_decode_internal(*ctx, batch); + break; + case GGML_MPI_KV_CLEAR: + llama_kv_cache_clear(ctx); + break; + case GGML_MPI_KV_SEQ_RM: + llama_kv_cache_seq_rm(ctx, 1, -1, -1); + break; + case GGML_MPI_KV_SEQ_CP: + llama_kv_cache_seq_cp(ctx, 0, 0, 0, 0); + break; +// case GGML_MPI_KV_SEQ_CP_BACK: +// llama_kv_cache_seq_cp_back(ctx, 0, 0, 0, 0); +// break; + case GGML_MPI_KV_SEQ_KEEP: + llama_kv_cache_seq_keep(ctx, 0); + break; + case GGML_MPI_KV_SEQ_ADD: + llama_kv_cache_seq_add(ctx, 0, 0, 0, 0); + break; + case GGML_MPI_KV_SEQ_DIV: + llama_kv_cache_seq_div(ctx, 0, 0, 0, 0); + break; + default: + printf("Unknown operation, exiting\n"); + exit(1); + break; + } + return 0; +} + +int llama_process_mpi_worker( + struct llama_context * ctx, + struct llama_batch & batch) { + ggml_mpi_probe(ctx->model.ctx_mpi, -1, -1); + int tag = ggml_mpi_status_tag(ctx->model.ctx_mpi); + int32_t count; + int32_t trans_type; +// if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1) { +// printf("\nReceived command %d\n", tag); +// } + switch (tag) { + case GGML_MPI_BEGIN_TRANSACTION: + + ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &trans_type, 1, GGML_MPI_BEGIN_TRANSACTION); + return llama_process_mpi_transaction(ctx, batch, trans_type); + break; + case GGML_MPI_SHUTDOWN: + llama_free(ctx); + llama_backend_free(); + exit(0); + break; + case GGML_MPI_CANCEL_RUN: +// count = ggml_mpi_status_count_int32(ctx->model.ctx_mpi); +//// printf("Received cancel run\n"); +// { +// std::vector canceled(count, -1); +// llama_cancel_run(ctx, canceled.data(), canceled.size()); +// +// } +// break; + default: + printf("Unknown operation, exiting\n"); + exit(1); + break; + } + return 0; +} + +#endif + int32_t llama_decode( struct llama_context * ctx, struct llama_batch batch) { + +#ifdef GGML_USE_MPI + if (ggml_mpi_rank(ctx->model.ctx_mpi) > 0) { + // Enter a blocking eval loop with dummy input, letting rank=0 drive the process + while (llama_process_mpi_worker(ctx, batch) >= 0){}; + llama_backend_free(); + exit(1); + } +#endif const int ret = llama_decode_internal(*ctx, batch); if (ret < 0) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); diff --git a/llama.h b/llama.h index 40dcf54e3..96988fd09 100644 --- a/llama.h +++ b/llama.h @@ -8,7 +8,6 @@ #include #include #include - #ifdef LLAMA_SHARED # if defined(_WIN32) && !defined(__MINGW32__) # ifdef LLAMA_BUILD @@ -203,6 +202,9 @@ extern "C" { }; struct llama_model_params { + // Array of layers to allocate to each node + const float * node_layer_weights; + int32_t n_gpu_layers; // number of layers to store in VRAM enum llama_split_mode split_mode; // how to split the model across multiple GPUs @@ -358,6 +360,8 @@ extern "C" { const char * path_model, struct llama_model_params params); + LLAMA_API void llama_split_layers_weighted(struct llama_context * ctx, float device_weights[], size_t num_weights); + LLAMA_API void llama_free_model(struct llama_model * model); LLAMA_API struct llama_context * llama_new_context_with_model( @@ -371,6 +375,9 @@ extern "C" { LLAMA_API size_t llama_max_devices(void); + // Get the ID of this compute node, usually 0 + // unless running 
MPI, in which case it is the rank of the node
+    LLAMA_API int llama_node_id(struct llama_context * ctx);
     LLAMA_API bool llama_supports_mmap       (void);
     LLAMA_API bool llama_supports_mlock      (void);
     LLAMA_API bool llama_supports_gpu_offload(void);
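For orientation, here is a minimal sketch (not part of the patch) of how an application might exercise the new API end to end when launched with something like `mpirun -np 2`. `llama_context_default_params` and `llama_batch_get_one` are the stock llama.h helpers of this vintage; the model path, the two-node layer weights, and the single placeholder token are hypothetical values.

    #include "llama.h"

    #include <cstdio>
    #include <vector>

    int main(void) {
        // Initializes ggml and, in MPI builds, the MPI backend on every node.
        llama_backend_init();

        llama_model_params mparams = llama_model_default_params();

        // Hypothetical two-node split: ~70% of the layers on rank 0, ~30% on rank 1.
        const float node_weights[2] = {0.70f, 0.30f};
        mparams.node_layer_weights = node_weights;

        llama_model * model = llama_load_model_from_file("model.gguf", mparams);  // placeholder path
        if (model == nullptr) {
            return 1;
        }

        llama_context_params cparams = llama_context_default_params();
        llama_context * ctx = llama_new_context_with_model(model, cparams);

        // Placeholder prompt; a real program would tokenize its input here.
        std::vector<llama_token> tokens = { 1 };
        llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size(), 0, 0);

        // On ranks > 0 this call does not return: llama_decode() drops into the worker loop
        // (llama_process_mpi_worker) and services commands sent by rank 0 until shutdown.
        if (llama_decode(ctx, batch) != 0) {
            fprintf(stderr, "decode failed on node %d\n", llama_node_id(ctx));
            return 1;
        }

        // Only rank 0 reaches this point; it owns the logits and drives sampling.
        printf("decode finished on node %d\n", llama_node_id(ctx));

        llama_free(ctx);
        llama_free_model(model);
        llama_backend_free();
        return 0;
    }

Note that `llama_split_layers_weighted` is declared but left as an empty stub in this patch, so the per-node weighting currently travels through `llama_model_params::node_layer_weights` at model-load time rather than being adjustable on an existing context.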