Fix some mpi mem leaks, add mpi-layer-split to help when using mpi
This commit is contained in:
parent
888d4f591b
commit
b7599f7a56
4 changed files with 38 additions and 21 deletions
|
@ -1094,6 +1094,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||
printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n");
|
||||
printf(" or for intermediate results and KV (with split-mode = row) (default: %d)\n", params.main_gpu);
|
||||
}
|
||||
#ifdef GGML_USE_MPI
|
||||
printf(" --mpi-layer-split N percentiles to split the layers by across nodes\n");
|
||||
#endif
|
||||
printf(" --verbose-prompt print a verbose prompt before generation (default: %s)\n", params.verbose_prompt ? "true" : "false");
|
||||
printf(" --no-display-prompt don't print prompt at generation (default: %s)\n", !params.display_prompt ? "true" : "false");
|
||||
printf(" -gan N, --grp-attn-n N\n");
|
||||
|
|
52
ggml-mpi.c
52
ggml-mpi.c
|
@ -47,7 +47,7 @@ struct ggml_mpi_context * ggml_mpi_split_comm(struct ggml_mpi_context * ctx, int
|
|||
}
|
||||
|
||||
void ggml_mpi_free(struct ggml_mpi_context * ctx) {
|
||||
MPI_Comm_free(ctx->comm);
|
||||
MPI_Comm_free(&(ctx->comm));
|
||||
free(ctx);
|
||||
}
|
||||
|
||||
|
@ -55,7 +55,7 @@ int ggml_mpi_rank(struct ggml_mpi_context * ctx) {
|
|||
return ctx->rank;
|
||||
}
|
||||
|
||||
int ggml_mpi_size(struct ggml_mpi_context * ctx) {
|
||||
size_t ggml_mpi_size(struct ggml_mpi_context * ctx) {
|
||||
return ctx->size;
|
||||
}
|
||||
|
||||
|
@ -69,30 +69,41 @@ void ggml_mpi_eval_init(
|
|||
|
||||
|
||||
MPI_Barrier(ctx_mpi->comm);
|
||||
|
||||
int32_t old_n_tokens = *n_tokens;
|
||||
MPI_Bcast(n_tokens, 1, MPI_INT, 0, ctx_mpi->comm);
|
||||
|
||||
if (ctx_mpi->rank != 0) {
|
||||
*pos = calloc(*n_tokens, sizeof(int32_t));
|
||||
*n_seq_ids = calloc(*n_tokens, sizeof(int32_t));
|
||||
*logits = calloc(*n_tokens, sizeof(int8_t));
|
||||
// If what was passed in differs from what was broadcast,
|
||||
// we can't guarantee the allocated sizes are correct
|
||||
// TODO check how often this is done and if it's a problem,
|
||||
// try to allocate ahead of time
|
||||
if (old_n_tokens != *n_tokens) {
|
||||
*pos = realloc(*pos, *n_tokens * sizeof(int32_t));
|
||||
*n_seq_ids = realloc(*n_seq_ids, *n_tokens * sizeof(int32_t ));
|
||||
*logits = realloc(*logits, *n_tokens * sizeof(int32_t));
|
||||
}
|
||||
|
||||
|
||||
|
||||
// MPI_Bcast(&total_n_seq_ids, 1, MPI_INT32_T, 0, ctx_mpi->comm);
|
||||
MPI_Bcast(*n_seq_ids, *n_tokens, MPI_INT32_T, 0, ctx_mpi->comm);
|
||||
|
||||
// We need to know the total number of sequence
|
||||
// ids, so we count them all up
|
||||
int32_t total_n_seq_ids = 0;
|
||||
for (size_t i = 0; i < *n_tokens; i++) {
|
||||
for (int32_t i = 0; i < *n_tokens; i++) {
|
||||
total_n_seq_ids += (*n_seq_ids)[i];
|
||||
}
|
||||
|
||||
MPI_Bcast(&total_n_seq_ids, 1, MPI_INT32_T, 0, ctx_mpi->comm);
|
||||
MPI_Bcast(*n_seq_ids, *n_tokens, MPI_INT32_T, 0, ctx_mpi->comm);
|
||||
|
||||
// MPI can't chase the pointers for multidimensional arrays, so we flatten them first
|
||||
// for transit
|
||||
int32_t * flattened_seq_ids = calloc(total_n_seq_ids, sizeof(int32_t));
|
||||
|
||||
int32_t current_index = 0;
|
||||
|
||||
// Only rank 0 needs to flatten since the others don't have the real seq_id
|
||||
if (ctx_mpi->rank == 0) {
|
||||
for (size_t i = 0; i < *n_tokens; i++) {
|
||||
for (size_t j = 0; j < (*n_seq_ids)[i]; j++) {
|
||||
for (int32_t i = 0; i < *n_tokens; i++) {
|
||||
for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
|
||||
flattened_seq_ids[current_index] = (*seq_id)[i][j];
|
||||
current_index++;
|
||||
}
|
||||
|
@ -100,25 +111,26 @@ void ggml_mpi_eval_init(
|
|||
}
|
||||
|
||||
|
||||
MPI_Bcast(*pos, *n_tokens, MPI_INT32_T, 0, ctx_mpi->comm);
|
||||
MPI_Bcast(flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, 0, ctx_mpi->comm);
|
||||
MPI_Bcast( *pos, *n_tokens, MPI_INT32_T, 0, ctx_mpi->comm);
|
||||
MPI_Bcast(flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, 0, ctx_mpi->comm);
|
||||
//MPI_Bcast(*logits, *n_tokens, MPI_INT8_T, 0, ctx_mpi->comm);
|
||||
int32_t ** new_seq_id = calloc(*n_tokens, sizeof(int32_t*));
|
||||
current_index = 0;
|
||||
for (size_t i = 0; i < *n_tokens; i++) {
|
||||
for (int32_t i = 0; i < *n_tokens; i++) {
|
||||
new_seq_id[i] = calloc((*n_seq_ids)[i], sizeof(int32_t));
|
||||
for (size_t j = 0; j < (*n_seq_ids)[i]; j++) {
|
||||
for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
|
||||
new_seq_id[i][j] = flattened_seq_ids[current_index];
|
||||
current_index++;
|
||||
}
|
||||
}
|
||||
free(flattened_seq_ids);
|
||||
//free(*seq_id); // <- something is still holding onto this, need to investigate
|
||||
*seq_id = new_seq_id;
|
||||
}
|
||||
|
||||
void ggml_mpi_synch_int(
|
||||
struct ggml_mpi_context * ctx_mpi,
|
||||
int32_t * val
|
||||
struct ggml_mpi_context * ctx_mpi,
|
||||
int32_t * val
|
||||
) {
|
||||
MPI_Bcast(val, 1, MPI_INT32_T, 0, ctx_mpi->comm);
|
||||
}
|
||||
|
@ -284,7 +296,7 @@ void ggml_mpi_graph_compute_pre(
|
|||
{
|
||||
|
||||
|
||||
const int n_per_node = (n_layers + (mpi_size - 1)) / mpi_size;
|
||||
//const int n_per_node = (n_layers + (mpi_size - 1)) / mpi_size;
|
||||
|
||||
const int mpi_idx = mpi_rank > 0 ? mpi_rank - 1 : mpi_size - 1;
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
#pragma once
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
|
||||
struct ggml_context;
|
||||
struct ggml_tensor;
|
||||
|
@ -98,7 +99,7 @@ int ggml_mpi_rank(struct ggml_mpi_context * ctx);
|
|||
* @param ctx The context containing the communicator used for this size check.
|
||||
* @return The number of nodes that are a part of the given context's communicator.
|
||||
*/
|
||||
int ggml_mpi_size(struct ggml_mpi_context * ctx);
|
||||
size_t ggml_mpi_size(struct ggml_mpi_context * ctx);
|
||||
|
||||
/**
|
||||
* Synchronize needed information among the nodes
|
||||
|
|
|
@ -13094,6 +13094,7 @@ void llama_split_layers_weighted(struct llama_context * ctx, float device_weight
|
|||
}
|
||||
uint16_t** ranges = ggml_mpi_split_range(ctx->ctx_mpi, 0, ctx->model.hparams.n_layer - 1, device_weights);
|
||||
ggml_mpi_scatter_layers(ctx->ctx_mpi, ranges);
|
||||
free(ranges);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue