From bcfb190c2865389f6805cc89b4709c2a270002f7 Mon Sep 17 00:00:00 2001
From: Branden Butler
Date: Sun, 29 Oct 2023 15:16:16 -0500
Subject: [PATCH] Synchronize batch sequence info, fixing MPI for llama_decode()

---
 common/common.cpp |  2 +-
 ggml-mpi.c        | 63 ++++++++++++++++++++++++++++++++++++++++++-----
 ggml-mpi.h        | 21 +++++++++++-----
 llama.cpp         | 29 +++++++++++++++-------
 4 files changed, 93 insertions(+), 22 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index a6bdae68f..c58477fd6 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1403,7 +1403,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     LOG("warming up the model with an empty run\n");
 
 #ifndef GGML_USE_MPI
-    // When using MPI, llama_eval() enters into an infinite loop
+    // When using MPI, llama_decode() enters into an infinite loop
     // on non-head nodes. Thus, we only want to warmup the model here
     // if we aren't using MPI.
     // FIXME have a way to terminate the infinite loop so we can warmup the model
diff --git a/ggml-mpi.c b/ggml-mpi.c
index 9217651d6..1e4d0b376 100644
--- a/ggml-mpi.c
+++ b/ggml-mpi.c
@@ -60,16 +60,67 @@ int ggml_mpi_size(struct ggml_mpi_context * ctx) {
 }
 
 void ggml_mpi_eval_init(
-        struct ggml_mpi_context * ctx_mpi,
-        int * n_tokens,
-        int * n_past,
-        int * n_threads) {
+        struct ggml_mpi_context * ctx_mpi,
+        int32_t * n_tokens,
+        int32_t ** pos,
+        int32_t ** n_seq_ids,
+        int32_t *** seq_id,
+        int8_t ** logits) {
     MPI_Barrier(ctx_mpi->comm);
 
-    MPI_Bcast(n_tokens, 1, MPI_INT, 0, ctx_mpi->comm);
-    MPI_Bcast(n_past,   1, MPI_INT, 0, ctx_mpi->comm);
+    MPI_Bcast(n_tokens, 1, MPI_INT, 0, ctx_mpi->comm);
+
+    if (ctx_mpi->rank != 0) {
+        *pos       = calloc(*n_tokens, sizeof(int32_t));
+        *n_seq_ids = calloc(*n_tokens, sizeof(int32_t));
+        *logits    = calloc(*n_tokens, sizeof(int8_t));
+    }
+
+    int32_t total_n_seq_ids = 0;
+    for (size_t i = 0; i < *n_tokens; i++) {
+        total_n_seq_ids += (*n_seq_ids)[i];
+    }
+
+    MPI_Bcast(&total_n_seq_ids, 1, MPI_INT32_T, 0, ctx_mpi->comm);
+    MPI_Bcast(*n_seq_ids, *n_tokens, MPI_INT32_T, 0, ctx_mpi->comm);
+
+    int32_t * flattened_seq_ids = calloc(total_n_seq_ids, sizeof(int32_t));
+
+    int32_t current_index = 0;
+
+    if (ctx_mpi->rank == 0) {
+        for (size_t i = 0; i < *n_tokens; i++) {
+            for (size_t j = 0; j < (*n_seq_ids)[i]; j++) {
+                flattened_seq_ids[current_index] = (*seq_id)[i][j];
+                current_index++;
+            }
+        }
+    }
+
+
+    MPI_Bcast(*pos, *n_tokens, MPI_INT32_T, 0, ctx_mpi->comm);
+    MPI_Bcast(flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, 0, ctx_mpi->comm);
+    //MPI_Bcast(*logits, *n_tokens, MPI_INT8_T, 0, ctx_mpi->comm);
+    int32_t ** new_seq_id = calloc(*n_tokens, sizeof(int32_t*));
+    current_index = 0;
+    for (size_t i = 0; i < *n_tokens; i++) {
+        new_seq_id[i] = calloc((*n_seq_ids)[i], sizeof(int32_t));
+        for (size_t j = 0; j < (*n_seq_ids)[i]; j++) {
+            new_seq_id[i][j] = flattened_seq_ids[current_index];
+            current_index++;
+        }
+    }
+    free(flattened_seq_ids);
+    *seq_id = new_seq_id;
+}
+
+void ggml_mpi_synch_int(
+        struct ggml_mpi_context * ctx_mpi,
+        int32_t * val
+) {
+    MPI_Bcast(val, 1, MPI_INT32_T, 0, ctx_mpi->comm);
 }
 
 static int ggml_graph_get_node_idx(struct ggml_cgraph * gf, const char * name) {
diff --git a/ggml-mpi.h b/ggml-mpi.h
index 7eeb3856f..f3c4bf2aa 100644
--- a/ggml-mpi.h
+++ b/ggml-mpi.h
@@ -110,14 +110,23 @@ int ggml_mpi_size(struct ggml_mpi_context * ctx);
  *
  * @param ctx_mpi The context in which to prepare for evaluation.
  * @param n_tokens A pointer to the n_tokens, which will be synchronized after this function.
- * @param n_past A pointer to the n_past, which will be synchronized after this function.
- * @param n_threads A pointer to the n_threads, which is unused currently.
+ * @param pos A pointer to the pos array, which will be synchronized after this function.
+ * @param n_seq_ids A pointer to the n_seq_ids array, which will be synchronized after this function.
+ * @param seq_id A pointer to the seq_id 2D array, which will be synchronized after this function.
+ * @param logits A pointer to the logits array, which is unused currently since only node 0 needs them.
  */
 void ggml_mpi_eval_init(
-        struct ggml_mpi_context * ctx_mpi,
-        int * n_tokens,
-        int * n_past,
-        int * n_threads);
+        struct ggml_mpi_context * ctx_mpi,
+        int32_t * n_tokens,
+        int32_t ** pos,
+        int32_t ** n_seq_ids,
+        int32_t *** seq_id,
+        int8_t ** logits);
+
+void ggml_mpi_synch_int(
+        struct ggml_mpi_context * ctx_mpi,
+        int32_t * val
+    );
 
 /**
  * Split a range across all nodes within the given
diff --git a/llama.cpp b/llama.cpp
index 98ffa1075..a5f56b552 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -8776,8 +8776,7 @@ static int llama_decode_internal(
          llama_context & lctx,
            llama_batch   batch_all) { // TODO: rename back to batch
 
-    const uint32_t n_tokens_all = batch_all.n_tokens;
-
+    uint32_t n_tokens_all = batch_all.n_tokens;
     if (n_tokens_all == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
         return -1;
@@ -8798,11 +8797,7 @@
     }
 
     lctx.n_queued_tokens += n_tokens_all;
-#ifdef GGML_USE_MPI
-    // TODO: needs fix after #3228
-    GGML_ASSERT(false && "not implemented");
-    //ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads);
-#endif
+
 
     auto & kv_self = lctx.kv_self;
 
@@ -8828,7 +8823,7 @@
     std::vector<std::vector<llama_seq_id>> seq_id;
 
     for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
-        const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
+        uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
         llama_batch u_batch = {
             /* .n_tokens   = */ (int32_t) n_tokens,
             /* .token      = */ batch_all.token     ? batch_all.token    + cur_token        : nullptr,
@@ -8881,7 +8876,12 @@
             kv_self.head = 0;
         }
 
-        if (!llama_kv_cache_find_slot(kv_self, u_batch)) {
+#ifdef GGML_USE_MPI
+        // TODO: needs fix after #3228
+        ggml_mpi_eval_init(lctx.ctx_mpi, &(u_batch.n_tokens), &(u_batch.pos), &(u_batch.n_seq_id), &(u_batch.seq_id), &(u_batch.logits));
+        n_tokens = u_batch.n_tokens;
+#endif
+        if (!llama_kv_cache_find_slot(kv_self, u_batch)) {
             return 1;
         }
 
@@ -13923,6 +13923,17 @@ void llama_batch_free(struct llama_batch batch) {
 int32_t llama_decode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
+
+#ifdef GGML_USE_MPI
+    if (ggml_mpi_rank(ctx->ctx_mpi) > 0) {
+        // Enter a blocking eval loop with dummy input, letting rank=0 drive the process
+        const int n_ctx = llama_n_ctx(ctx);
+        std::vector<llama_token> tmp(n_ctx, llama_token_bos(&ctx->model));
+        while (llama_decode_internal(*ctx, batch) >= 0) {};
+        llama_backend_free();
+        exit(1);
+    }
+#endif
     const int ret = llama_decode_internal(*ctx, batch);
     if (ret < 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
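
The core of the new ggml_mpi_eval_init() is a flatten/broadcast/rebuild pattern: seq_id is a ragged 2D array (token i has n_seq_ids[i] entries), so rank 0 first broadcasts n_tokens and the per-token counts, then flattens all ids into one contiguous buffer, broadcasts that, and every rank rebuilds the 2D shape locally. The standalone sketch below shows the same pattern in isolation; it is not part of the patch, and the file name and demo values are made up for illustration.

/* seq_id_bcast.c - standalone sketch of the flatten -> MPI_Bcast -> rebuild
 * pattern that ggml_mpi_eval_init() uses for the ragged seq_id array.
 * Build/run: mpicc seq_id_bcast.c -o seq_id_bcast && mpirun -np 2 ./seq_id_bcast
 */
#include <mpi.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

int main(int argc, char ** argv) {
    MPI_Init(&argc, &argv);

    int rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    // Rank 0 owns the real batch metadata; the other ranks start empty.
    int32_t    n_tokens  = 0;
    int32_t *  n_seq_ids = NULL;  // per-token count of sequence ids
    int32_t ** seq_id    = NULL;  // ragged: seq_id[i] holds n_seq_ids[i] entries

    if (rank == 0) {
        n_tokens  = 3;
        n_seq_ids = calloc(n_tokens, sizeof(int32_t));
        seq_id    = calloc(n_tokens, sizeof(int32_t *));
        for (int32_t i = 0; i < n_tokens; i++) {
            n_seq_ids[i] = i + 1;                 // token i belongs to i+1 sequences
            seq_id[i] = calloc(n_seq_ids[i], sizeof(int32_t));
            for (int32_t j = 0; j < n_seq_ids[i]; j++) {
                seq_id[i][j] = 10 * i + j;        // arbitrary demo ids
            }
        }
    }

    // 1. Everyone learns n_tokens, so the non-root ranks can allocate the count array.
    MPI_Bcast(&n_tokens, 1, MPI_INT32_T, 0, MPI_COMM_WORLD);
    if (rank != 0) {
        n_seq_ids = calloc(n_tokens, sizeof(int32_t));
    }

    // 2. Broadcast the total element count and the per-token counts.
    int32_t total = 0;
    if (rank == 0) {
        for (int32_t i = 0; i < n_tokens; i++) {
            total += n_seq_ids[i];
        }
    }
    MPI_Bcast(&total,    1,        MPI_INT32_T, 0, MPI_COMM_WORLD);
    MPI_Bcast(n_seq_ids, n_tokens, MPI_INT32_T, 0, MPI_COMM_WORLD);

    // 3. Flatten on rank 0, broadcast one contiguous buffer, rebuild the ragged shape everywhere.
    int32_t * flat = calloc(total, sizeof(int32_t));
    if (rank == 0) {
        int32_t k = 0;
        for (int32_t i = 0; i < n_tokens; i++) {
            for (int32_t j = 0; j < n_seq_ids[i]; j++) {
                flat[k++] = seq_id[i][j];
            }
        }
    }
    MPI_Bcast(flat, total, MPI_INT32_T, 0, MPI_COMM_WORLD);

    int32_t ** rebuilt = calloc(n_tokens, sizeof(int32_t *));
    int32_t k = 0;
    for (int32_t i = 0; i < n_tokens; i++) {
        rebuilt[i] = calloc(n_seq_ids[i], sizeof(int32_t));
        for (int32_t j = 0; j < n_seq_ids[i]; j++) {
            rebuilt[i][j] = flat[k++];
        }
    }
    free(flat);

    // Every rank now sees the same ragged array that rank 0 started with.
    printf("rank %d: seq_id[2][1] = %d\n", rank, (int) rebuilt[2][1]);

    MPI_Finalize();
    return 0;
}

Broadcasting the counts before the payload is what lets the non-root ranks size their buffers, which is why the patch sends total_n_seq_ids and n_seq_ids ahead of flattened_seq_ids.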
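
On the llama.cpp side, the contract after this patch is that every rank calls llama_decode(), but only rank 0 supplies real batches: ranks above 0 fall into the internal decode loop on their first call and never return, staying in step through the broadcasts issued from rank 0. A hypothetical driver against the patched tree might look like the sketch below; the model path, the single-token batch, and the error handling are placeholders rather than anything taken from the patch.

/* mpi_driver.c - hypothetical usage sketch, not part of the patch.
 * Every rank runs the same binary; only rank 0 returns from llama_decode().
 */
#include "llama.h"

int main(void) {
    llama_backend_init(false);  // in GGML_USE_MPI builds this also initializes MPI

    struct llama_model_params mparams = llama_model_default_params();
    struct llama_model * model = llama_load_model_from_file("model.gguf", mparams);
    if (model == NULL) {
        return 1;
    }

    struct llama_context_params cparams = llama_context_default_params();
    struct llama_context * ctx = llama_new_context_with_model(model, cparams);

    // On ranks > 0 this call blocks: the patched llama_decode() parks the worker
    // in a decode loop that keeps serving rank 0's broadcasts, and it tears the
    // process down instead of returning. Only rank 0 continues past this point.
    llama_token bos = llama_token_bos(model);
    struct llama_batch batch = llama_batch_get_one(&bos, 1, 0, 0);
    if (llama_decode(ctx, batch) != 0) {
        return 1;
    }

    // ... rank 0 samples tokens and keeps calling llama_decode() as usual ...

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}

This mirrors how the earlier llama_eval()-based MPI path was driven: the same binary runs on all ranks and only rank 0's output matters.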