From 72dcd66c0fe23f7639728641790b96acd0ed575f Mon Sep 17 00:00:00 2001 From: Branden Butler Date: Tue, 12 Mar 2024 14:20:03 -0500 Subject: [PATCH] Resize seq_ids by n_seq_max, port over sync_pipelined instead of using Bcast --- ggml-mpi.cpp | 200 +++++++++++++++++++++++++++++++++++++-------------- ggml-mpi.h | 47 +++++++++++- llama.cpp | 21 ++++-- 3 files changed, 207 insertions(+), 61 deletions(-) diff --git a/ggml-mpi.cpp b/ggml-mpi.cpp index 6ee5168ba..37e5f67c7 100644 --- a/ggml-mpi.cpp +++ b/ggml-mpi.cpp @@ -14,6 +14,10 @@ #define UNUSED GGML_UNUSED +static bool have_init = false; + +static void* send_buffer; + struct ggml_mpi_context { int rank; int size; @@ -26,18 +30,37 @@ struct ggml_mpi_context { std::vector backends; ggml_backend_sched_t scheduler; bool remote; + void* send_buffer; }; void ggml_mpi_backend_init(void) { int ret; - MPI_Init_thread(NULL, NULL, MPI_THREAD_MULTIPLE, &ret); + + GGML_ASSERT(MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &ret) == MPI_SUCCESS); + have_init = true; + const int buffer_size = 128*1024*1024*8; + send_buffer = calloc(1, buffer_size); // 128MB buffer +// fprintf(stderr, "BUFFER ATTACH RETCODE=%d\n", MPI_Buffer_attach(send_buffer, buffer_size)); } +void ggml_mpi_sync_pipelined( + struct ggml_mpi_context * ctx_mpi, + void * val, + int count, + MPI_Datatype datatype, + int tag +); + void ggml_mpi_backend_free(void) { MPI_Finalize(); } struct ggml_mpi_context * ggml_mpi_init(void) { + + if (!have_init) { + ggml_mpi_backend_init(); + } + auto * ctx = new ggml_mpi_context; MPI_Comm_rank(MPI_COMM_WORLD, &ctx->rank); @@ -45,6 +68,8 @@ struct ggml_mpi_context * ggml_mpi_init(void) { ctx->comm = MPI_COMM_WORLD; ctx->remote = false; + ctx->send_buffer = send_buffer; + return ctx; } @@ -69,78 +94,147 @@ size_t ggml_mpi_size(struct ggml_mpi_context * ctx) { return ctx->size; } +int ggml_mpi_next_node(struct ggml_mpi_context * ctx_mpi) { + return (ctx_mpi->rank + 1) % ctx_mpi->size; +} + +int ggml_mpi_prev_node(struct ggml_mpi_context * ctx_mpi) { + int temp = (ctx_mpi->rank - 1); + return (temp >= 0) ? temp : ctx_mpi->size - 1; +} + +void ggml_mpi_sync_pipelined( + struct ggml_mpi_context * ctx_mpi, + void * val, + int count, + MPI_Datatype datatype, + int tag +) { + if(ctx_mpi->comm == MPI_COMM_NULL) { + return; + } + +// printf("Rank %d sync pipelined with tag %d\n", ctx_mpi->rank, tag); + + + if (ctx_mpi->rank != 0) { + MPI_Recv(val, count, datatype, ggml_mpi_prev_node(ctx_mpi), tag, ctx_mpi->comm, MPI_STATUS_IGNORE); + } + if(ctx_mpi->rank < ctx_mpi->size - 1) { + GGML_ASSERT(ctx_mpi->send_buffer != nullptr); + const int retval = MPI_Bsend(val, count, datatype, ggml_mpi_next_node(ctx_mpi), tag, ctx_mpi->comm); + GGML_ASSERT(retval == MPI_SUCCESS); + + } +} + void ggml_mpi_eval_init( struct ggml_mpi_context * ctx_mpi, int32_t * n_tokens, int32_t ** pos, int32_t ** n_seq_ids, int32_t *** seq_id, - int8_t ** logits) { + int8_t ** logits, + uint32_t n_seq_max) { -// fprintf(stderr, "Beginning eval init on rank %d\n", ctx_mpi->rank); - MPI_Barrier(ctx_mpi->comm); + if(ctx_mpi->comm == MPI_COMM_NULL) { + return; + } int32_t old_n_tokens = *n_tokens; - MPI_Bcast(n_tokens, 1, MPI_INT32_T, 0, ctx_mpi->comm); -// fprintf(stderr, "Node %d, old_n_tokens: %d, new n_tokens: %d\n", ctx_mpi->rank, old_n_tokens, *n_tokens); - // If what was passed in differs from what was broadcast, - // we can't guarantee the allocated sizes are correct - // TODO check how often this is done and if it's a problem, - // try to allocate ahead of time - if (old_n_tokens != *n_tokens) { - *pos = static_cast(realloc(*pos, *n_tokens * sizeof(int32_t))); - *n_seq_ids = static_cast(realloc(*n_seq_ids, *n_tokens * sizeof(int32_t))); - *logits = static_cast(realloc(*logits, *n_tokens * sizeof(int32_t))); + ggml_mpi_sync_pipelined(ctx_mpi, n_tokens, 1, MPI_INT, GGML_MPI_N_TOKENS); + int8_t* temp_logits = (int8_t*) calloc(*n_tokens, sizeof(int8_t)); + + if (ctx_mpi->rank == 0 && *logits != nullptr) { + ggml_mpi_sync_pipelined(ctx_mpi, *logits, *n_tokens, MPI_INT8_T, GGML_MPI_BATCH_LOGITS); + } else { + ggml_mpi_sync_pipelined(ctx_mpi, temp_logits, *n_tokens, MPI_INT8_T, GGML_MPI_BATCH_LOGITS); } -// MPI_Bcast(&total_n_seq_ids, 1, MPI_INT32_T, 0, ctx_mpi->comm); - MPI_Bcast(*n_seq_ids, *n_tokens, MPI_INT32_T, 0, ctx_mpi->comm); + if (ctx_mpi->rank != 0) { + bool should_set_batch_logits = false; + for (int i = 0; i < *n_tokens; i++) { + if (temp_logits[i]) { + should_set_batch_logits = true; + break; + } + } + if (should_set_batch_logits) { + if (*logits != NULL) { + free(*logits); + *logits = NULL; + } + *logits = temp_logits; + } else { + if (*logits != NULL) { + free(*logits); + *logits = NULL; + } + free(temp_logits); + } + } else { + free(temp_logits); + } + + // For now, we assume that the pos, seq_ids, tokens, etc have been + // pre-allocated for the largest possible sizes, even on worker nodes. + //if (old_n_tokens != *n_tokens) { + // *pos = realloc(*pos, *n_tokens * sizeof(int32_t)); + // *n_seq_ids = realloc(*n_seq_ids, *n_tokens * sizeof(int32_t )); + // *tokens = realloc(*tokens, *n_tokens * sizeof(int32_t )); + //} + + GGML_ASSERT(n_seq_ids != nullptr); + GGML_ASSERT(n_tokens != nullptr); + + + // FIXME Syncing n_seq_ids causes MPI to throw an invalid buffer error in Bsend +// ggml_mpi_sync_pipelined(ctx_mpi, *n_seq_ids, *n_tokens, MPI_INT32_T, GGML_MPI_N_SEQ_IDS); // We need to know the total number of sequence // ids, so we count them all up - int32_t total_n_seq_ids = 0; - for (int32_t i = 0; i < *n_tokens; i++) { - total_n_seq_ids += (*n_seq_ids)[i]; - } - - // MPI can't chase the pointers for multidimensional arrays, so we flatten them first - // for transit - auto * flattened_seq_ids = static_cast(calloc(total_n_seq_ids, sizeof(int32_t))); - - int32_t current_index = 0; - - // Only rank 0 needs to flatten since the others don't have the real seq_id - if (ctx_mpi->rank == 0) { - for (int32_t i = 0; i < *n_tokens; i++) { - for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) { - flattened_seq_ids[current_index] = (*seq_id)[i][j]; - current_index++; - } - } - } - - - MPI_Bcast( *pos, *n_tokens, MPI_INT32_T, 0, ctx_mpi->comm); - MPI_Bcast(flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, 0, ctx_mpi->comm); - //MPI_Bcast(*logits, *n_tokens, MPI_INT8_T, 0, ctx_mpi->comm); - auto ** new_seq_id = static_cast(calloc(*n_tokens, sizeof(int32_t *))); - current_index = 0; - for (int32_t i = 0; i < *n_tokens; i++) { - new_seq_id[i] = static_cast(calloc((*n_seq_ids)[i], sizeof(int32_t))); - for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) { - new_seq_id[i][j] = flattened_seq_ids[current_index]; - current_index++; - } - } - free(flattened_seq_ids); - //free(*seq_id); // <- something is still holding onto this, need to investigate - *seq_id = new_seq_id; +// int32_t total_n_seq_ids = 0; +// for (int32_t i = 0; i < *n_tokens; i++) { +// total_n_seq_ids += (*n_seq_ids)[i]; +// } +// +// // MPI can't chase the pointers for multidimensional arrays, so we flatten them first +// // for transit +// int32_t * flattened_seq_ids = static_cast(calloc(total_n_seq_ids, sizeof(int32_t))); +// +// int32_t current_index = 0; +// +// // Only rank 0 needs to flatten since the others don't have the real seq_id +// if (ctx_mpi->rank == 0) { +// for (int32_t i = 0; i < *n_tokens; i++) { +// for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) { +// flattened_seq_ids[current_index] = (*seq_id)[i][j]; +// current_index++; +// } +// } +// } +// +// +// +// ggml_mpi_sync_pipelined(ctx_mpi, *pos, *n_tokens, MPI_INT32_T, GGML_MPI_POS); +// ggml_mpi_sync_pipelined(ctx_mpi, flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, GGML_MPI_SEQ_IDS); +// +// current_index = 0; +// for (int32_t i = 0; i < *n_tokens; i++) { +// for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) { +// (*seq_id)[i][j] = flattened_seq_ids[current_index]; +// current_index++; +// } +// +// } +// free(flattened_seq_ids); } + void ggml_mpi_synch_int( struct ggml_mpi_context * ctx_mpi, int32_t * val diff --git a/ggml-mpi.h b/ggml-mpi.h index 0a5d69e14..6497f47c8 100644 --- a/ggml-mpi.h +++ b/ggml-mpi.h @@ -12,6 +12,50 @@ struct ggml_cgraph; extern "C" { #endif +#define GGML_MPI_DECODE 0 + +#define GGML_MPI_KV_CLEAR 1 + +#define GGML_MPI_KV_SEQ_RM 2 + +#define GGML_MPI_KV_SEQ_CP 3 + +#define GGML_MPI_KV_SEQ_KEEP 4 + +#define GGML_MPI_KV_SEQ_SHIFT 5 + +#define GGML_MPI_SHUTDOWN 6 + +#define GGML_MPI_TRANSFER_TENSORS 7 + +#define GGML_MPI_SYNC_LOGITS 8 + +#define GGML_MPI_CANCEL_RUN 9 + +#define GGML_MPI_KV_SEQ_CP_BACK 10 + +#define GGML_MPI_TRANS_ID 11 + +#define GGML_MPI_BATCH_ID 12 + +#define GGML_MPI_N_TOKENS 13 + +#define GGML_MPI_TOKENS 14 + +#define GGML_MPI_N_SEQ_IDS 15 + +#define GGML_MPI_SEQ_IDS 16 + +#define GGML_MPI_POS 17 + +#define GGML_MPI_BEGIN_TRANSACTION 18 + +#define GGML_MPI_MAX_N_SEQ 19 + +#define GGML_MPI_BATCH_LOGITS 20 + + + /** * The context used for MPI operations, * a program may make use of more than one @@ -131,7 +175,8 @@ void ggml_mpi_eval_init( int32_t ** pos, int32_t ** n_seq_ids, int32_t *** seq_id, - int8_t ** logits); + int8_t ** logits, + uint32_t n_seq_max); void ggml_mpi_synch_int( struct ggml_mpi_context * ctx_mpi, diff --git a/llama.cpp b/llama.cpp index 0c0c783b1..9e0343cad 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1750,6 +1750,7 @@ struct llama_cparams { ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; + uint32_t n_seq_max; }; struct llama_layer { @@ -8812,7 +8813,7 @@ static int llama_decode_internal( uint32_t n_tokens_all = batch_all.n_tokens; #ifdef GGML_USE_MPI - ggml_mpi_eval_init(lctx.ctx_mpi, &(batch_all.n_tokens), &(batch_all.pos), &(batch_all.n_seq_id), &(batch_all.seq_id), &(batch_all.logits)); + ggml_mpi_eval_init(lctx.ctx_mpi, &(batch_all.n_tokens), &(batch_all.pos), &(batch_all.n_seq_id), &(batch_all.seq_id), &(batch_all.logits), lctx.cparams.n_seq_max); n_tokens_all = batch_all.n_tokens; #endif @@ -8896,7 +8897,7 @@ static int llama_decode_internal( seq_id_arr.resize(n_tokens); for (uint32_t i = 0; i < n_tokens; i++) { n_seq_id[i] = 1; - seq_id[i].resize(1); + seq_id[i].resize(lctx.cparams.n_seq_max); seq_id[i][0] = u_batch.all_seq_id; seq_id_arr[i] = seq_id[i].data(); } @@ -12753,6 +12754,9 @@ void llama_backend_init(void) { ggml_free(ctx); } +#ifdef GGML_USE_MPI + ggml_mpi_backend_init(); +#endif } @@ -12760,10 +12764,6 @@ void llama_numa_init(enum ggml_numa_strategy numa) { if (numa != GGML_NUMA_STRATEGY_DISABLED) { ggml_numa_init(numa); } - -#ifdef GGML_USE_MPI - ggml_mpi_backend_init(); -#endif } void llama_backend_free(void) { @@ -12844,7 +12844,7 @@ struct llama_context * llama_new_context_with_model( const auto & hparams = model->hparams; auto & cparams = ctx->cparams; - // TODO: maybe add n_seq_max here too + cparams.n_seq_max = params.n_seq_max; cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor; @@ -13984,6 +13984,13 @@ void llama_batch_free(struct llama_batch batch) { free(batch.seq_id); } if (batch.logits) free(batch.logits); + + batch.token = nullptr; + batch.embd = nullptr; + batch.pos = nullptr; + batch.n_seq_id = nullptr; + batch.seq_id = nullptr; + batch.logits = nullptr; } int32_t llama_decode(