From eaef2d0e76d8a64b80154bc0c253408a659928e3 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 10 Jul 2023 18:47:24 +0300
Subject: [PATCH] mpi : extend API to allow usage with outer backends (e.g. Metal)

---
 .gitignore   |  1 +
 ggml-metal.m |  1 +
 ggml-mpi.c   | 16 ++++++++------
 ggml-mpi.h   | 11 ++++++----
 llama.cpp    | 59 +++++++++++++++++++++++++---------------------------
 5 files changed, 47 insertions(+), 41 deletions(-)

diff --git a/.gitignore b/.gitignore
index 4fccec31b..faec869e0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -20,6 +20,7 @@ build-static/
 build-cublas/
 build-opencl/
 build-metal/
+build-mpi/
 build-no-accel/
 build-sanitize-addr/
 build-sanitize-thread/
diff --git a/ggml-metal.m b/ggml-metal.m
index 3f15f791f..6473644c2 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -450,6 +450,7 @@ void ggml_metal_graph_compute(
                 //}
 
                 switch (dst->op) {
+                    case GGML_OP_NONE:
                     case GGML_OP_RESHAPE:
                     case GGML_OP_VIEW:
                     case GGML_OP_TRANSPOSE:
diff --git a/ggml-mpi.c b/ggml-mpi.c
index 6282b7276..872e808de 100644
--- a/ggml-mpi.c
+++ b/ggml-mpi.c
@@ -102,12 +102,10 @@ static void ggml_mpi_tensor_recv(struct ggml_tensor * t, int mpi_rank_src) {
 }
 
 // TODO: there are many improvements that can be done to this implementation
-void ggml_mpi_graph_compute(
+void ggml_mpi_graph_compute_pre(
         struct ggml_mpi_context * ctx_mpi,
-             struct ggml_context * ctx,
              struct ggml_cgraph * gf,
-                             int   n_layers,
-                             int   n_threads) {
+                             int   n_layers) {
     const int mpi_rank = ctx_mpi->rank;
     const int mpi_size = ctx_mpi->size;
 
@@ -200,10 +198,16 @@ void ggml_mpi_graph_compute(
 
         //fprintf(stderr, "%s: node %d: processing %d nodes [%d, %d)\n", __func__, mpi_rank, gf->n_nodes, il0, il1);
     }
+}
 
-    ggml_graph_compute_with_ctx(ctx, gf, n_threads);
+void ggml_mpi_graph_compute_post(
+        struct ggml_mpi_context * ctx_mpi,
+             struct ggml_cgraph * gf,
+                             int   n_layers) {
+    UNUSED(n_layers);
 
-    //fprintf(stderr, "%s: node %d: done\n", __func__, mpi_rank);
+    const int mpi_rank = ctx_mpi->rank;
+    const int mpi_size = ctx_mpi->size;
 
     // send the output data to the next node
     if (mpi_rank > 0) {
diff --git a/ggml-mpi.h b/ggml-mpi.h
index 2ad0a4386..eda119d44 100644
--- a/ggml-mpi.h
+++ b/ggml-mpi.h
@@ -24,12 +24,15 @@ void ggml_mpi_eval_init(
                             int * n_past,
                             int * n_threads);
 
-void ggml_mpi_graph_compute(
+void ggml_mpi_graph_compute_pre(
         struct ggml_mpi_context * ctx_mpi,
-             struct ggml_context * ctx,
              struct ggml_cgraph * gf,
-                             int   n_layers,
-                             int   n_threads);
+                             int   n_layers);
+
+void ggml_mpi_graph_compute_post(
+        struct ggml_mpi_context * ctx_mpi,
+             struct ggml_cgraph * gf,
+                             int   n_layers);
 
 #ifdef __cplusplus
 }
diff --git a/llama.cpp b/llama.cpp
index 322e37a7d..ad7283faf 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1632,6 +1632,10 @@ static bool llama_eval_internal(
     // run the computation
     ggml_build_forward_expand(&gf, cur);
 
+#if GGML_USE_MPI
+    ggml_mpi_graph_compute_pre(lctx.ctx_mpi, &gf, n_layer);
+#endif
+
 #ifdef GGML_USE_METAL
     if (lctx.ctx_metal && N == 1) {
         ggml_metal_set_n_cb     (lctx.ctx_metal, n_threads);
@@ -1656,14 +1660,19 @@ static bool llama_eval_internal(
 
         ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads);
     }
-#elif GGML_USE_MPI
-    ggml_mpi_graph_compute(lctx.ctx_mpi, ctx0, &gf, n_layer, n_threads);
-
-    cur = gf.nodes[gf.n_nodes - 1];
 #else
     ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads);
 #endif
 
+#if GGML_USE_MPI
+    ggml_mpi_graph_compute_post(lctx.ctx_mpi, &gf, n_layer);
+#endif
+
+    // update kv token count
+    lctx.kv_self.n = n_past + N;
+
+    struct ggml_tensor * res = gf.nodes[gf.n_nodes - 1];
+
     if (cgraph_fname) {
         ggml_graph_export(&gf, cgraph_fname);
     }
@@ -1679,38 +1688,26 @@ static bool llama_eval_internal(
     //    ggml_graph_dump_dot(&gf, NULL, "llama.dot");
     //}
 
-    //embd_w.resize(n_vocab*N);
-    //memcpy(embd_w.data(), ggml_get_data(cur), sizeof(float)*n_vocab*N);
-
-    // update kv token count
-    lctx.kv_self.n = n_past + N;
-
-#ifdef GGML_USE_MPI
-    if (ggml_mpi_rank(lctx.ctx_mpi) == 0) {
-#else
+    // extract logits
     {
-#endif
-        // extract logits
-        {
-            auto & logits_out = lctx.logits;
+        auto & logits_out = lctx.logits;
 
-            if (lctx.logits_all) {
-                logits_out.resize(n_vocab * N);
-                memcpy(logits_out.data(), (float *) ggml_get_data(cur), sizeof(float)*n_vocab*N);
-            } else {
-                // return result for just the last token
-                logits_out.resize(n_vocab);
-                memcpy(logits_out.data(), (float *) ggml_get_data(cur) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
-            }
+        if (lctx.logits_all) {
+            logits_out.resize(n_vocab * N);
+            memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*N);
+        } else {
+            // return result for just the last token
+            logits_out.resize(n_vocab);
+            memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
         }
+    }
 
-        // extract embeddings
-        if (!lctx.embedding.empty()) {
-            auto & embedding_out = lctx.embedding;
+    // extract embeddings
+    if (!lctx.embedding.empty()) {
+        auto & embedding_out = lctx.embedding;
 
-            embedding_out.resize(n_embd);
-            memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
-        }
+        embedding_out.resize(n_embd);
+        memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
     }
 
     if (mem_per_token == 0) {
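
Note on intended usage (not part of the patch): the old ggml_mpi_graph_compute() both rewrote the graph for the current rank and ran it internally via ggml_graph_compute_with_ctx(), so an outer backend such as Metal could not execute the graph. With the split API, the pre-pass prepares the graph, any backend runs it, and the post-pass sends the output data to the next node. A minimal sketch of the call order, reusing the identifiers from the llama.cpp hunk above (lctx, ctx0, gf, n_layer, n_threads):

    // 1) rewrite the graph so that this rank processes its share of the layers
    ggml_mpi_graph_compute_pre(lctx.ctx_mpi, &gf, n_layer);

    // 2) run the graph with whichever backend is active, e.g. the default CPU path ...
    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
    // ... or Metal: ggml_metal_graph_compute(lctx.ctx_metal, &gf);

    // 3) send the output data to the next node
    ggml_mpi_graph_compute_post(lctx.ctx_mpi, &gf, n_layer);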