mpi : extend API to allow usage with outer backends (e.g. Metal)

This commit is contained in:
Georgi Gerganov 2023-07-10 18:47:24 +03:00
parent c3c3ef11a6
commit eaef2d0e76
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
5 changed files with 47 additions and 41 deletions

1
.gitignore vendored
View file

@@ -20,6 +20,7 @@ build-static/
build-cublas/ build-cublas/
build-opencl/ build-opencl/
build-metal/ build-metal/
build-mpi/
build-no-accel/ build-no-accel/
build-sanitize-addr/ build-sanitize-addr/
build-sanitize-thread/ build-sanitize-thread/

View file

@@ -450,6 +450,7 @@ void ggml_metal_graph_compute(
//} //}
switch (dst->op) { switch (dst->op) {
case GGML_OP_NONE:
case GGML_OP_RESHAPE: case GGML_OP_RESHAPE:
case GGML_OP_VIEW: case GGML_OP_VIEW:
case GGML_OP_TRANSPOSE: case GGML_OP_TRANSPOSE:

View file

@@ -102,12 +102,10 @@ static void ggml_mpi_tensor_recv(struct ggml_tensor * t, int mpi_rank_src) {
} }
// TODO: there are many improvements that can be done to this implementation // TODO: there are many improvements that can be done to this implementation
void ggml_mpi_graph_compute( void ggml_mpi_graph_compute_pre(
struct ggml_mpi_context * ctx_mpi, struct ggml_mpi_context * ctx_mpi,
struct ggml_context * ctx,
struct ggml_cgraph * gf, struct ggml_cgraph * gf,
int n_layers, int n_layers) {
int n_threads) {
const int mpi_rank = ctx_mpi->rank; const int mpi_rank = ctx_mpi->rank;
const int mpi_size = ctx_mpi->size; const int mpi_size = ctx_mpi->size;
@@ -200,10 +198,16 @@ void ggml_mpi_graph_compute(
//fprintf(stderr, "%s: node %d: processing %d nodes [%d, %d)\n", __func__, mpi_rank, gf->n_nodes, il0, il1); //fprintf(stderr, "%s: node %d: processing %d nodes [%d, %d)\n", __func__, mpi_rank, gf->n_nodes, il0, il1);
} }
}
ggml_graph_compute_with_ctx(ctx, gf, n_threads); void ggml_mpi_graph_compute_post(
struct ggml_mpi_context * ctx_mpi,
struct ggml_cgraph * gf,
int n_layers) {
UNUSED(n_layers);
//fprintf(stderr, "%s: node %d: done\n", __func__, mpi_rank); const int mpi_rank = ctx_mpi->rank;
const int mpi_size = ctx_mpi->size;
// send the output data to the next node // send the output data to the next node
if (mpi_rank > 0) { if (mpi_rank > 0) {

View file

@@ -24,12 +24,15 @@ void ggml_mpi_eval_init(
int * n_past, int * n_past,
int * n_threads); int * n_threads);
void ggml_mpi_graph_compute( void ggml_mpi_graph_compute_pre(
struct ggml_mpi_context * ctx_mpi, struct ggml_mpi_context * ctx_mpi,
struct ggml_context * ctx,
struct ggml_cgraph * gf, struct ggml_cgraph * gf,
int n_layers, int n_layers);
int n_threads);
void ggml_mpi_graph_compute_post(
struct ggml_mpi_context * ctx_mpi,
struct ggml_cgraph * gf,
int n_layers);
#ifdef __cplusplus #ifdef __cplusplus
} }

View file

@@ -1632,6 +1632,10 @@ static bool llama_eval_internal(
// run the computation // run the computation
ggml_build_forward_expand(&gf, cur); ggml_build_forward_expand(&gf, cur);
#if GGML_USE_MPI
ggml_mpi_graph_compute_pre(lctx.ctx_mpi, &gf, n_layer);
#endif
#ifdef GGML_USE_METAL #ifdef GGML_USE_METAL
if (lctx.ctx_metal && N == 1) { if (lctx.ctx_metal && N == 1) {
ggml_metal_set_n_cb (lctx.ctx_metal, n_threads); ggml_metal_set_n_cb (lctx.ctx_metal, n_threads);
@@ -1656,14 +1660,19 @@ static bool llama_eval_internal(
ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads); ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads);
} }
#elif GGML_USE_MPI
ggml_mpi_graph_compute(lctx.ctx_mpi, ctx0, &gf, n_layer, n_threads);
cur = gf.nodes[gf.n_nodes - 1];
#else #else
ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads); ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads);
#endif #endif
#if GGML_USE_MPI
ggml_mpi_graph_compute_post(lctx.ctx_mpi, &gf, n_layer);
#endif
// update kv token count
lctx.kv_self.n = n_past + N;
struct ggml_tensor * res = gf.nodes[gf.n_nodes - 1];
if (cgraph_fname) { if (cgraph_fname) {
ggml_graph_export(&gf, cgraph_fname); ggml_graph_export(&gf, cgraph_fname);
} }
@@ -1679,38 +1688,26 @@ static bool llama_eval_internal(
// ggml_graph_dump_dot(&gf, NULL, "llama.dot"); // ggml_graph_dump_dot(&gf, NULL, "llama.dot");
//} //}
//embd_w.resize(n_vocab*N); // extract logits
//memcpy(embd_w.data(), ggml_get_data(cur), sizeof(float)*n_vocab*N);
// update kv token count
lctx.kv_self.n = n_past + N;
#ifdef GGML_USE_MPI
if (ggml_mpi_rank(lctx.ctx_mpi) == 0) {
#else
{ {
#endif auto & logits_out = lctx.logits;
// extract logits
{
auto & logits_out = lctx.logits;
if (lctx.logits_all) { if (lctx.logits_all) {
logits_out.resize(n_vocab * N); logits_out.resize(n_vocab * N);
memcpy(logits_out.data(), (float *) ggml_get_data(cur), sizeof(float)*n_vocab*N); memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*N);
} else { } else {
// return result for just the last token // return result for just the last token
logits_out.resize(n_vocab); logits_out.resize(n_vocab);
memcpy(logits_out.data(), (float *) ggml_get_data(cur) + (n_vocab*(N-1)), sizeof(float)*n_vocab); memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
}
} }
}
// extract embeddings // extract embeddings
if (!lctx.embedding.empty()) { if (!lctx.embedding.empty()) {
auto & embedding_out = lctx.embedding; auto & embedding_out = lctx.embedding;
embedding_out.resize(n_embd); embedding_out.resize(n_embd);
memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd); memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
}
} }
if (mem_per_token == 0) { if (mem_per_token == 0) {