Fix some MPI memory leaks, add --mpi-layer-split to the help output when using MPI

Branden Butler 2023-10-31 15:55:15 -05:00
parent 888d4f591b
commit b7599f7a56
4 changed files with 38 additions and 21 deletions

common/common.cpp

@@ -1094,6 +1094,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
         printf("  -mg i, --main-gpu i   the GPU to use for the model (with split-mode = none),\n");
         printf("                        or for intermediate results and KV (with split-mode = row) (default: %d)\n", params.main_gpu);
     }
+#ifdef GGML_USE_MPI
+    printf("  --mpi-layer-split N   percentiles to split the layers by across nodes\n");
+#endif
     printf("  --verbose-prompt      print a verbose prompt before generation (default: %s)\n", params.verbose_prompt ? "true" : "false");
     printf("  --no-display-prompt   don't print prompt at generation (default: %s)\n", !params.display_prompt ? "true" : "false");
     printf("  -gan N, --grp-attn-n N\n");
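The new option only shows up in the help text when the binary is built with GGML_USE_MPI, and its value feeds the per-node weights that llama_split_layers_weighted consumes (see the llama.cpp hunk below). The argument parsing itself is not part of this diff; the helper below is a hypothetical sketch that assumes a comma-separated list of percentiles, not the actual common.cpp code.

```c
// Hypothetical sketch: turn "0.5,0.25,0.25" into a heap-allocated float array.
// The real --mpi-layer-split parsing in common.cpp is not shown in this commit,
// and no validation (e.g. that the values sum to 1.0) is attempted here.
#include <stdlib.h>
#include <string.h>

static float * parse_layer_split(const char * arg, size_t * out_n) {
    size_t capacity = 4;
    size_t n = 0;
    float * weights = malloc(capacity * sizeof(float));
    char  * copy    = strdup(arg);                     // strtok modifies its input
    for (char * tok = strtok(copy, ","); tok != NULL; tok = strtok(NULL, ",")) {
        if (n == capacity) {
            capacity *= 2;
            weights = realloc(weights, capacity * sizeof(float));
        }
        weights[n++] = strtof(tok, NULL);
    }
    free(copy);
    *out_n = n;
    return weights;                                    // caller frees
}
```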

ggml-mpi.c

@@ -47,7 +47,7 @@ struct ggml_mpi_context * ggml_mpi_split_comm(struct ggml_mpi_context * ctx, int
 }

 void ggml_mpi_free(struct ggml_mpi_context * ctx) {
-    MPI_Comm_free(ctx->comm);
+    MPI_Comm_free(&(ctx->comm));
     free(ctx);
 }
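This change matches the MPI API: MPI_Comm_free takes a pointer to the communicator handle and resets it to MPI_COMM_NULL, so the address of ctx->comm has to be passed rather than the handle itself. A minimal sketch of the same teardown pattern, using an assumed, simplified context struct rather than the real ggml_mpi_context:

```c
// Minimal sketch of the teardown pattern; the struct is an assumption,
// not the actual ggml_mpi_context layout.
#include <mpi.h>
#include <stdlib.h>

struct example_mpi_ctx {
    MPI_Comm comm;   // communicator owned by this context (e.g. from MPI_Comm_split)
    int      rank;
};

static void example_mpi_free(struct example_mpi_ctx * ctx) {
    if (ctx == NULL) {
        return;
    }
    // MPI_Comm_free expects MPI_Comm *, frees the communicator object,
    // and sets the handle to MPI_COMM_NULL.
    if (ctx->comm != MPI_COMM_NULL && ctx->comm != MPI_COMM_WORLD) {
        MPI_Comm_free(&ctx->comm);
    }
    free(ctx);
}
```

The MPI_COMM_WORLD guard is there because freeing the predefined world communicator is an error; presumably the real context only frees communicators it obtained by splitting or duplication.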
@@ -55,7 +55,7 @@ int ggml_mpi_rank(struct ggml_mpi_context * ctx) {
     return ctx->rank;
 }

-int ggml_mpi_size(struct ggml_mpi_context * ctx) {
+size_t ggml_mpi_size(struct ggml_mpi_context * ctx) {
     return ctx->size;
 }
@@ -69,30 +69,41 @@ void ggml_mpi_eval_init(
     MPI_Barrier(ctx_mpi->comm);
+    int32_t old_n_tokens = *n_tokens;
     MPI_Bcast(n_tokens, 1, MPI_INT, 0, ctx_mpi->comm);
-    if (ctx_mpi->rank != 0) {
-        *pos = calloc(*n_tokens, sizeof(int32_t));
-        *n_seq_ids = calloc(*n_tokens, sizeof(int32_t));
-        *logits = calloc(*n_tokens, sizeof(int8_t));
+    // If what was passed in differs from what was broadcast,
+    // we can't guarantee the allocated sizes are correct
+    // TODO check how often this is done and if it's a problem,
+    // try to allocate ahead of time
+    if (old_n_tokens != *n_tokens) {
+        *pos = realloc(*pos, *n_tokens * sizeof(int32_t));
+        *n_seq_ids = realloc(*n_seq_ids, *n_tokens * sizeof(int32_t));
+        *logits = realloc(*logits, *n_tokens * sizeof(int32_t));
     }
+    // MPI_Bcast(&total_n_seq_ids, 1, MPI_INT32_T, 0, ctx_mpi->comm);
+    MPI_Bcast(*n_seq_ids, *n_tokens, MPI_INT32_T, 0, ctx_mpi->comm);
+    // We need to know the total number of sequence
+    // ids, so we count them all up
     int32_t total_n_seq_ids = 0;
-    for (size_t i = 0; i < *n_tokens; i++) {
+    for (int32_t i = 0; i < *n_tokens; i++) {
         total_n_seq_ids += (*n_seq_ids)[i];
     }
-    MPI_Bcast(&total_n_seq_ids, 1, MPI_INT32_T, 0, ctx_mpi->comm);
-    MPI_Bcast(*n_seq_ids, *n_tokens, MPI_INT32_T, 0, ctx_mpi->comm);
+    // MPI can't chase the pointers for multidimensional arrays, so we flatten them first
+    // for transit
     int32_t * flattened_seq_ids = calloc(total_n_seq_ids, sizeof(int32_t));
     int32_t current_index = 0;
+    // Only rank 0 needs to flatten since the others don't have the real seq_id
     if (ctx_mpi->rank == 0) {
-        for (size_t i = 0; i < *n_tokens; i++) {
-            for (size_t j = 0; j < (*n_seq_ids)[i]; j++) {
+        for (int32_t i = 0; i < *n_tokens; i++) {
+            for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
                 flattened_seq_ids[current_index] = (*seq_id)[i][j];
                 current_index++;
             }
@@ -100,25 +111,26 @@ void ggml_mpi_eval_init(
     }

     MPI_Bcast(*pos, *n_tokens, MPI_INT32_T, 0, ctx_mpi->comm);
     MPI_Bcast(flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, 0, ctx_mpi->comm);
     //MPI_Bcast(*logits, *n_tokens, MPI_INT8_T, 0, ctx_mpi->comm);

     int32_t ** new_seq_id = calloc(*n_tokens, sizeof(int32_t*));
     current_index = 0;
-    for (size_t i = 0; i < *n_tokens; i++) {
+    for (int32_t i = 0; i < *n_tokens; i++) {
         new_seq_id[i] = calloc((*n_seq_ids)[i], sizeof(int32_t));
-        for (size_t j = 0; j < (*n_seq_ids)[i]; j++) {
+        for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
             new_seq_id[i][j] = flattened_seq_ids[current_index];
             current_index++;
         }
     }
     free(flattened_seq_ids);
+    //free(*seq_id); // <- something is still holding onto this, need to investigate
     *seq_id = new_seq_id;
 }

 void ggml_mpi_synch_int(
         struct ggml_mpi_context * ctx_mpi,
         int32_t * val
     ) {
     MPI_Bcast(val, 1, MPI_INT32_T, 0, ctx_mpi->comm);
 }
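The comments added above describe the core pattern: MPI_Bcast sends one contiguous buffer, so the ragged int32_t ** sequence-id table is flattened on rank 0, broadcast together with the per-row lengths, and rebuilt on every rank; the commented-out free(*seq_id) records that the old outer array is still referenced elsewhere and cannot be released here yet. A self-contained sketch of that flatten/broadcast/rebuild pattern, with generic names rather than the ggml_mpi_eval_init signature:

```c
// Sketch: broadcast a ragged 2D int32_t array by flattening it for transit.
// Generic helper with assumed names, not the actual ggml_mpi_eval_init code.
#include <mpi.h>
#include <stdint.h>
#include <stdlib.h>

// lengths[] is broadcast first so every rank can size the flat buffer;
// rows[] only needs to be valid on the root rank.
static int32_t ** bcast_ragged(int32_t ** rows, int32_t * lengths, int32_t n_rows,
                               int root, MPI_Comm comm) {
    int rank;
    MPI_Comm_rank(comm, &rank);

    MPI_Bcast(lengths, n_rows, MPI_INT32_T, root, comm);

    int32_t total = 0;
    for (int32_t i = 0; i < n_rows; i++) {
        total += lengths[i];
    }

    int32_t * flat = calloc(total, sizeof(int32_t));
    if (rank == root) {
        int32_t k = 0;
        for (int32_t i = 0; i < n_rows; i++) {
            for (int32_t j = 0; j < lengths[i]; j++) {
                flat[k++] = rows[i][j];
            }
        }
    }
    MPI_Bcast(flat, total, MPI_INT32_T, root, comm);

    // Re-expand the flat buffer into a freshly allocated ragged array on every rank.
    int32_t ** out = calloc(n_rows, sizeof(int32_t *));
    int32_t k = 0;
    for (int32_t i = 0; i < n_rows; i++) {
        out[i] = calloc(lengths[i], sizeof(int32_t));
        for (int32_t j = 0; j < lengths[i]; j++) {
            out[i][j] = flat[k++];
        }
    }
    free(flat);
    return out;   // caller owns the result and must free each row, then the array
}
```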
@@ -284,7 +296,7 @@ void ggml_mpi_graph_compute_pre(
     {
-        const int n_per_node = (n_layers + (mpi_size - 1)) / mpi_size;
+        //const int n_per_node = (n_layers + (mpi_size - 1)) / mpi_size;
         const int mpi_idx = mpi_rank > 0 ? mpi_rank - 1 : mpi_size - 1;

ggml-mpi.h

@@ -1,5 +1,6 @@
 #pragma once
 #include <stdint.h>
+#include <stddef.h>

 struct ggml_context;
 struct ggml_tensor;
@@ -98,7 +99,7 @@ int ggml_mpi_rank(struct ggml_mpi_context * ctx);
  * @param ctx The context containing the communicator used for this size check.
  * @return The number of nodes that are a part of the given context's communicator.
  */
-int ggml_mpi_size(struct ggml_mpi_context * ctx);
+size_t ggml_mpi_size(struct ggml_mpi_context * ctx);

 /**
  * Synchronize needed information among the nodes

llama.cpp

@@ -13094,6 +13094,7 @@ void llama_split_layers_weighted(struct llama_context * ctx, float device_weight
     }

     uint16_t** ranges = ggml_mpi_split_range(ctx->ctx_mpi, 0, ctx->model.hparams.n_layer - 1, device_weights);
     ggml_mpi_scatter_layers(ctx->ctx_mpi, ranges);
+    free(ranges);
 #endif
 }
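For context on what free(ranges) releases: ggml_mpi_split_range turns the per-device weights into a table of per-node layer ranges and ggml_mpi_scatter_layers distributes it, after which the table is no longer needed on this node. The implementation of ggml_mpi_split_range is not part of this diff; the sketch below is a hypothetical version that maps weights to inclusive [start, end] layer ranges, just to illustrate the shape of the data being freed.

```c
// Hypothetical sketch of weighted layer-range splitting; the real
// ggml_mpi_split_range implementation is not shown in this commit.
#include <stdint.h>
#include <stdlib.h>

// Returns an array of n_nodes pointers, each to an inclusive {start, end} pair.
static uint16_t ** split_range_weighted(uint16_t start, uint16_t end,
                                        const float * weights, size_t n_nodes) {
    uint16_t ** ranges   = calloc(n_nodes, sizeof(uint16_t *));
    const uint16_t total = (uint16_t)(end - start + 1);
    uint16_t next        = start;
    for (size_t i = 0; i < n_nodes; i++) {
        ranges[i] = calloc(2, sizeof(uint16_t));
        uint16_t count = (uint16_t)(weights[i] * total + 0.5f);   // round to nearest
        if (i == n_nodes - 1 || next + count > end + 1) {
            count = (uint16_t)(end + 1 - next);                   // remainder to the last node
        }
        ranges[i][0] = next;
        ranges[i][1] = (uint16_t)(next + count - 1);
        next = (uint16_t)(next + count);
    }
    return ranges;
}
```

In this sketch every per-node pair is its own allocation, so freeing only the outer array would still leak the pairs; whether the added free(ranges) alone is sufficient depends on how the real ggml_mpi_split_range lays out its result.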