diff --git a/ggml-mpi.c b/ggml-mpi.c
index f1bbcabfe..6282b7276 100644
--- a/ggml-mpi.c
+++ b/ggml-mpi.c
@@ -56,7 +56,7 @@ void ggml_mpi_eval_init(
     MPI_Bcast(n_threads, 1, MPI_INT, 0, MPI_COMM_WORLD);
 }
 
-int ggml_graph_get_node_idx(struct ggml_cgraph * gf, const char * name) {
+static int ggml_graph_get_node_idx(struct ggml_cgraph * gf, const char * name) {
     struct ggml_tensor * t = ggml_graph_get_tensor(gf, name);
     if (t == NULL) {
         fprintf(stderr, "%s: tensor %s not found\n", __func__, name);
@@ -73,6 +73,46 @@ int ggml_graph_get_node_idx(struct ggml_cgraph * gf, const char * name) {
     return -1;
 }
 
+// map a ggml tensor type to the matching MPI datatype
+// only I32 and F32 tensors can be transferred - any other type fails the GGML_ASSERT
+static MPI_Datatype ggml_mpi_datatype(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_I32: return MPI_INT32_T;
+        case GGML_TYPE_F32: return MPI_FLOAT;
+        default: break;
+    }
+
+    GGML_ASSERT(false && "not implemented");
+
+    return MPI_DATATYPE_NULL; // only reached if GGML_ASSERT does not stop the program
+}
+
+// element count of tensor t, checked to fit the int "count" argument of MPI_Send / MPI_Recv
+// a count that does not survive the round trip through int fails the GGML_ASSERT instead of being silently truncated
+static int ggml_mpi_tensor_count(struct ggml_tensor * t) {
+    const int64_t ne    = ggml_nelements(t);
+    const int     count = (int) ne;
+
+    GGML_ASSERT(ne >= 0 && (int64_t) count == ne);
+
+    return count;
+}
+
+// send all elements of t->data to rank mpi_rank_dst with a blocking MPI_Send (tag 0, MPI_COMM_WORLD)
+// t must be an I32 or F32 tensor
+static void ggml_mpi_tensor_send(struct ggml_tensor * t, int mpi_rank_dst) {
+    const int retval = MPI_Send(t->data, ggml_mpi_tensor_count(t), ggml_mpi_datatype(t->type), mpi_rank_dst, 0, MPI_COMM_WORLD);
+    GGML_ASSERT(retval == MPI_SUCCESS);
+}
+
+// receive all elements of t into t->data from rank mpi_rank_src with a blocking MPI_Recv (any tag, MPI_COMM_WORLD)
+// t must be an I32 or F32 tensor
+static void ggml_mpi_tensor_recv(struct ggml_tensor * t, int mpi_rank_src) {
+    // the receive status is not needed, so MPI_STATUS_IGNORE is passed instead of a status object
+    const int retval = MPI_Recv(t->data, ggml_mpi_tensor_count(t), ggml_mpi_datatype(t->type), mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
+    GGML_ASSERT(retval == MPI_SUCCESS);
+}
+
 // TODO: there are many improvements that can be done to this implementation
 void ggml_mpi_graph_compute(
         struct ggml_mpi_context * ctx_mpi,
@@ -109,41 +149,19 @@ void ggml_mpi_graph_compute(
     // node 0:   [(n-1) * n_per_node,            n_nodes)
     //
     if (mpi_rank > 0) {
-        if (mpi_rank == 1) { // the first node receives the input tokens from the main node
-            MPI_Status status; UNUSED(status);
-
-            const int mpi_rank_src = mpi_rank - 1;
-
-            const int retval = MPI_Recv(inp_tokens->data, ggml_nelements(inp_tokens), MPI_INT32_T, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
-            GGML_ASSERT(retval == MPI_SUCCESS);
-        } else { // recv input data for each node into the "inp0" tensor (i.e. the first node in the compute graph)
-            MPI_Status status; UNUSED(status);
-
-            const int mpi_rank_src = mpi_rank - 1;
-
-            //printf("%s: node %d: waiting for %d elements from %d\n", __func__, mpi_rank, (int) ggml_nelements(inp0), mpi_rank_src);
-            const int retval = MPI_Recv(inp0->data, ggml_nelements(inp0), MPI_FLOAT, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
-            GGML_ASSERT(retval == MPI_SUCCESS);
+        if (mpi_rank == 1) {
+            // the first node (1) receives the input tokens from the main node (0)
+            ggml_mpi_tensor_recv(inp_tokens, 0);
+        } else {
+            // recv input data for each node into the "inp0" tensor (i.e. the first node in the compute graph)
+            ggml_mpi_tensor_recv(inp0, mpi_rank - 1);
         }
     } else if (mpi_size > 1) {
         // node 0 sends the input tokens to node 1
-        {
-            const int mpi_rank_dst = mpi_rank + 1;
-
-            const int retval = MPI_Send(inp_tokens->data, ggml_nelements(inp_tokens), MPI_INT32_T, mpi_rank_dst, 0, MPI_COMM_WORLD);
-            GGML_ASSERT(retval == MPI_SUCCESS);
-        }
+        ggml_mpi_tensor_send(inp_tokens, 1);
 
         // recv the output data from the last node
-        {
-            MPI_Status status; UNUSED(status);
-
-            const int mpi_rank_src = mpi_size - 1;
-
-            //fprintf(stderr, "%s: node %d: waiting for %d elements from %d\n", __func__, mpi_rank, (int) ggml_nelements(inp0), mpi_rank_src);
-            const int retval = MPI_Recv(inp0->data, ggml_nelements(inp0), MPI_FLOAT, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
-            GGML_ASSERT(retval == MPI_SUCCESS);
-        }
+        ggml_mpi_tensor_recv(inp0, mpi_size - 1);
     }
 
     {
@@ -201,13 +219,6 @@ void ggml_mpi_graph_compute(
 
     // send the output data to the next node
     if (mpi_rank > 0) {
-        struct ggml_tensor * output = gf->nodes[gf->n_nodes - 1];
-
-        const int mpi_rank_dst = (mpi_rank + 1) % mpi_size;
-
-        //fprintf(stderr, "%s: node %d: sending %d elements to node %d\n", __func__, mpi_rank, ggml_nelements(output), mpi_rank_dst);
-
-        const int retval = MPI_Send(output->data, ggml_nelements(output), MPI_FLOAT, mpi_rank_dst, 0, MPI_COMM_WORLD);
-        GGML_ASSERT(retval == MPI_SUCCESS);
+        ggml_mpi_tensor_send(gf->nodes[gf->n_nodes - 1], (mpi_rank + 1) % mpi_size);
     }
 }