Fix some mpi mem leaks, add mpi-layer-split to help when using mpi

This commit is contained in:
Branden Butler 2023-10-31 15:55:15 -05:00
parent 888d4f591b
commit b7599f7a56
4 changed files with 38 additions and 21 deletions

View file

@ -1094,6 +1094,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n");
printf(" or for intermediate results and KV (with split-mode = row) (default: %d)\n", params.main_gpu);
}
#ifdef GGML_USE_MPI
printf(" --mpi-layer-split N percentiles to split the layers by across nodes\n");
#endif
printf(" --verbose-prompt print a verbose prompt before generation (default: %s)\n", params.verbose_prompt ? "true" : "false");
printf(" --no-display-prompt don't print prompt at generation (default: %s)\n", !params.display_prompt ? "true" : "false");
printf(" -gan N, --grp-attn-n N\n");

View file

@ -47,7 +47,7 @@ struct ggml_mpi_context * ggml_mpi_split_comm(struct ggml_mpi_context * ctx, int
}
void ggml_mpi_free(struct ggml_mpi_context * ctx) {
MPI_Comm_free(ctx->comm);
MPI_Comm_free(&(ctx->comm));
free(ctx);
}
@ -55,7 +55,7 @@ int ggml_mpi_rank(struct ggml_mpi_context * ctx) {
return ctx->rank;
}
int ggml_mpi_size(struct ggml_mpi_context * ctx) {
size_t ggml_mpi_size(struct ggml_mpi_context * ctx) {
return ctx->size;
}
@ -69,30 +69,41 @@ void ggml_mpi_eval_init(
MPI_Barrier(ctx_mpi->comm);
int32_t old_n_tokens = *n_tokens;
MPI_Bcast(n_tokens, 1, MPI_INT, 0, ctx_mpi->comm);
if (ctx_mpi->rank != 0) {
*pos = calloc(*n_tokens, sizeof(int32_t));
*n_seq_ids = calloc(*n_tokens, sizeof(int32_t));
*logits = calloc(*n_tokens, sizeof(int8_t));
// If what was passed in differs from what was broadcast,
// we can't guarantee the allocated sizes are correct
// TODO check how often this is done and if it's a problem,
// try to allocate ahead of time
if (old_n_tokens != *n_tokens) {
*pos = realloc(*pos, *n_tokens * sizeof(int32_t));
*n_seq_ids = realloc(*n_seq_ids, *n_tokens * sizeof(int32_t ));
*logits = realloc(*logits, *n_tokens * sizeof(int32_t));
}
// MPI_Bcast(&total_n_seq_ids, 1, MPI_INT32_T, 0, ctx_mpi->comm);
MPI_Bcast(*n_seq_ids, *n_tokens, MPI_INT32_T, 0, ctx_mpi->comm);
// We need to know the total number of sequence
// ids, so we count them all up
int32_t total_n_seq_ids = 0;
for (size_t i = 0; i < *n_tokens; i++) {
for (int32_t i = 0; i < *n_tokens; i++) {
total_n_seq_ids += (*n_seq_ids)[i];
}
MPI_Bcast(&total_n_seq_ids, 1, MPI_INT32_T, 0, ctx_mpi->comm);
MPI_Bcast(*n_seq_ids, *n_tokens, MPI_INT32_T, 0, ctx_mpi->comm);
// MPI can't chase the pointers for multidimensional arrays, so we flatten them first
// for transit
int32_t * flattened_seq_ids = calloc(total_n_seq_ids, sizeof(int32_t));
int32_t current_index = 0;
// Only rank 0 needs to flatten since the others don't have the real seq_id
if (ctx_mpi->rank == 0) {
for (size_t i = 0; i < *n_tokens; i++) {
for (size_t j = 0; j < (*n_seq_ids)[i]; j++) {
for (int32_t i = 0; i < *n_tokens; i++) {
for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
flattened_seq_ids[current_index] = (*seq_id)[i][j];
current_index++;
}
@ -100,25 +111,26 @@ void ggml_mpi_eval_init(
}
MPI_Bcast(*pos, *n_tokens, MPI_INT32_T, 0, ctx_mpi->comm);
MPI_Bcast(flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, 0, ctx_mpi->comm);
MPI_Bcast( *pos, *n_tokens, MPI_INT32_T, 0, ctx_mpi->comm);
MPI_Bcast(flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, 0, ctx_mpi->comm);
//MPI_Bcast(*logits, *n_tokens, MPI_INT8_T, 0, ctx_mpi->comm);
int32_t ** new_seq_id = calloc(*n_tokens, sizeof(int32_t*));
current_index = 0;
for (size_t i = 0; i < *n_tokens; i++) {
for (int32_t i = 0; i < *n_tokens; i++) {
new_seq_id[i] = calloc((*n_seq_ids)[i], sizeof(int32_t));
for (size_t j = 0; j < (*n_seq_ids)[i]; j++) {
for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
new_seq_id[i][j] = flattened_seq_ids[current_index];
current_index++;
}
}
free(flattened_seq_ids);
//free(*seq_id); // <- something is still holding onto this, need to investigate
*seq_id = new_seq_id;
}
void ggml_mpi_synch_int(
struct ggml_mpi_context * ctx_mpi,
int32_t * val
struct ggml_mpi_context * ctx_mpi,
int32_t * val
) {
MPI_Bcast(val, 1, MPI_INT32_T, 0, ctx_mpi->comm);
}
@ -284,7 +296,7 @@ void ggml_mpi_graph_compute_pre(
{
const int n_per_node = (n_layers + (mpi_size - 1)) / mpi_size;
//const int n_per_node = (n_layers + (mpi_size - 1)) / mpi_size;
const int mpi_idx = mpi_rank > 0 ? mpi_rank - 1 : mpi_size - 1;

View file

@ -1,5 +1,6 @@
#pragma once
#include <stdint.h>
#include <stddef.h>
struct ggml_context;
struct ggml_tensor;
@ -98,7 +99,7 @@ int ggml_mpi_rank(struct ggml_mpi_context * ctx);
* @param ctx The context containing the communicator used for this size check.
* @return The number of nodes that are a part of the given context's communicator.
*/
int ggml_mpi_size(struct ggml_mpi_context * ctx);
size_t ggml_mpi_size(struct ggml_mpi_context * ctx);
/**
* Synchronize needed information among the nodes

View file

@ -13094,6 +13094,7 @@ void llama_split_layers_weighted(struct llama_context * ctx, float device_weight
}
uint16_t** ranges = ggml_mpi_split_range(ctx->ctx_mpi, 0, ctx->model.hparams.n_layer - 1, device_weights);
ggml_mpi_scatter_layers(ctx->ctx_mpi, ranges);
free(ranges);
#endif
}