diff --git a/ggml-mpi.cpp b/ggml-mpi.cpp
index 7fb4e752c..9fa3e6717 100644
--- a/ggml-mpi.cpp
+++ b/ggml-mpi.cpp
@@ -6,8 +6,8 @@
 #include
-#include
-#include
+#include
+#include
 #include
 
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -24,6 +24,8 @@ struct ggml_mpi_context {
     MPI_Comm comm;
     int layer_start;
     int layer_end;
+    MPI_Status status;
+    struct ggml_tensor *inp0;
 
     std::string name;
     struct ggml_backend * wrapped_backend;
@@ -31,6 +33,8 @@ struct ggml_mpi_context {
     ggml_backend_sched_t scheduler;
     bool remote;
     void* send_buffer;
+    int trans_id;
+    int recv_trans_id;
 };
 
 void ggml_mpi_backend_init(void) {
@@ -122,12 +126,43 @@ void ggml_mpi_sync_pipelined(
     }
     if(ctx_mpi->rank < ctx_mpi->size - 1) {
         GGML_ASSERT(ctx_mpi->send_buffer != nullptr);
+        GGML_ASSERT(val != nullptr);
+        GGML_ASSERT(count < 128*1024*1024);
+
         const int retval = MPI_Bsend(val, count, datatype, ggml_mpi_next_node(ctx_mpi), tag, ctx_mpi->comm);
         GGML_ASSERT(retval == MPI_SUCCESS);
     }
 }
 
+void ggml_mpi_barrier(struct ggml_mpi_context * ctx_mpi) {
+    MPI_Barrier(ctx_mpi->comm);
+}
+
+void ggml_mpi_probe(struct ggml_mpi_context * ctx_mpi, int src, int tag) {
+    MPI_Probe((src >= 0) ? src : MPI_ANY_SOURCE, (tag >= 0) ? tag : MPI_ANY_TAG, ctx_mpi->comm, &(ctx_mpi->status));
+}
+
+int ggml_mpi_iprobe(struct ggml_mpi_context * ctx_mpi, int src, int tag) {
+    if(ctx_mpi->comm == MPI_COMM_NULL) {
+        return 0;
+    }
+
+    int ret;
+    MPI_Iprobe((src >= 0) ? src : MPI_ANY_SOURCE, (tag >= 0) ? tag : MPI_ANY_TAG, ctx_mpi->comm, &ret, &(ctx_mpi->status));
+    return ret;
+}
+
+int ggml_mpi_status_tag(struct ggml_mpi_context * ctx_mpi) {
+    return ctx_mpi->status.MPI_TAG;
+}
+
+int ggml_mpi_status_count_int32(struct ggml_mpi_context * ctx_mpi) {
+    int32_t count;
+    MPI_Get_count(&ctx_mpi->status, MPI_INT32_T, &count);
+    return count;
+}
+
 void ggml_mpi_eval_init(
         struct ggml_mpi_context * ctx_mpi,
         int32_t * n_tokens,
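Note (illustrative, not part of the patch): MPI_Bsend only succeeds if a buffered-send buffer has been attached to MPI beforehand, which is what the send_buffer assert above relies on, and the new count assert presumably keeps a single message from exceeding that buffer. A minimal sketch of the assumed init-time setup, with the helper name and sizing chosen here purely for illustration:

    // Attach a user buffer once so later MPI_Bsend calls can copy messages into it.
    // ctx_mpi->send_buffer is assumed to be this same allocation.
    static void example_attach_send_buffer(struct ggml_mpi_context * ctx_mpi, size_t payload_size) {
        const size_t size = payload_size + MPI_BSEND_OVERHEAD; // MPI needs extra room for bookkeeping
        ctx_mpi->send_buffer = calloc(1, size);
        MPI_Buffer_attach(ctx_mpi->send_buffer, (int) size);
    }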
@@ -142,8 +177,15 @@ void ggml_mpi_eval_init(
         return;
     }
 
-
+    int32_t old_n_tokens = *n_tokens;
     ggml_mpi_sync_pipelined(ctx_mpi, n_tokens, 1, MPI_INT, GGML_MPI_N_TOKENS);
+
+    if (old_n_tokens != *n_tokens) {
+        *pos = static_cast<int32_t *>(realloc(*pos, *n_tokens * sizeof(int32_t)));
+        *n_seq_ids = static_cast<int32_t *>(realloc(*n_seq_ids, *n_tokens * sizeof(int32_t)));
+        *logits = static_cast<int8_t *>(realloc(*logits, *n_tokens * sizeof(int32_t)));
+    }
+
     int8_t* temp_logits = (int8_t*) calloc(*n_tokens, sizeof(int8_t));
 
     if (ctx_mpi->rank == 0 && *logits != nullptr) {
@@ -183,49 +225,51 @@ void ggml_mpi_eval_init(
     // pre-allocated for the largest possible sizes, even on worker nodes.
     GGML_ASSERT(n_seq_ids != nullptr);
+    GGML_ASSERT(*n_seq_ids != nullptr);
+    GGML_ASSERT(n_tokens != nullptr);
 
     // FIXME Syncing n_seq_ids causes MPI to throw an invalid buffer error in Bsend
-//    ggml_mpi_sync_pipelined(ctx_mpi, *n_seq_ids, *n_tokens, MPI_INT32_T, GGML_MPI_N_SEQ_IDS);
+    ggml_mpi_sync_pipelined(ctx_mpi, *n_seq_ids, *n_tokens, MPI_INT32_T, GGML_MPI_N_SEQ_IDS);
 
     // We need to know the total number of sequence
     // ids, so we count them all up
-//    int32_t total_n_seq_ids = 0;
-//    for (int32_t i = 0; i < *n_tokens; i++) {
-//        total_n_seq_ids += (*n_seq_ids)[i];
-//    }
-//
-//    // MPI can't chase the pointers for multidimensional arrays, so we flatten them first
-//    // for transit
-//    int32_t * flattened_seq_ids = static_cast<int32_t *>(calloc(total_n_seq_ids, sizeof(int32_t)));
-//
-//    int32_t current_index = 0;
-//
-//    // Only rank 0 needs to flatten since the others don't have the real seq_id
-//    if (ctx_mpi->rank == 0) {
-//        for (int32_t i = 0; i < *n_tokens; i++) {
-//            for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
-//                flattened_seq_ids[current_index] = (*seq_id)[i][j];
-//                current_index++;
-//            }
-//        }
-//    }
-//
-//
-//
-//    ggml_mpi_sync_pipelined(ctx_mpi, *pos, *n_tokens, MPI_INT32_T, GGML_MPI_POS);
-//    ggml_mpi_sync_pipelined(ctx_mpi, flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, GGML_MPI_SEQ_IDS);
-//
-//    current_index = 0;
-//    for (int32_t i = 0; i < *n_tokens; i++) {
-//        for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
-//            (*seq_id)[i][j] = flattened_seq_ids[current_index];
-//            current_index++;
-//        }
-//
-//    }
-//    free(flattened_seq_ids);
+    int32_t total_n_seq_ids = 0;
+    for (int32_t i = 0; i < *n_tokens; i++) {
+        total_n_seq_ids += (*n_seq_ids)[i];
+    }
+
+    // MPI can't chase the pointers for multidimensional arrays, so we flatten them first
+    // for transit
+    int32_t * flattened_seq_ids = static_cast<int32_t *>(calloc(total_n_seq_ids, sizeof(int32_t)));
+
+    int32_t current_index = 0;
+
+    // Only rank 0 needs to flatten since the others don't have the real seq_id
+    if (ctx_mpi->rank == 0) {
+        for (int32_t i = 0; i < *n_tokens; i++) {
+            for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
+                flattened_seq_ids[current_index] = (*seq_id)[i][j];
+                current_index++;
+            }
+        }
+    }
+
+
+
+    ggml_mpi_sync_pipelined(ctx_mpi, *pos, *n_tokens, MPI_INT32_T, GGML_MPI_POS);
+    ggml_mpi_sync_pipelined(ctx_mpi, flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, GGML_MPI_SEQ_IDS);
+
+    current_index = 0;
+    for (int32_t i = 0; i < *n_tokens; i++) {
+        for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
+            (*seq_id)[i][j] = flattened_seq_ids[current_index];
+            current_index++;
+        }
+
+    }
+    free(flattened_seq_ids);
 }
 
@@ -236,6 +280,19 @@ void ggml_mpi_sync_int(
     MPI_Bcast(val, 1, MPI_INT32_T, 0, ctx_mpi->comm);
 }
 
+void ggml_mpi_sync_ints_pipelined(
+        struct ggml_mpi_context * ctx_mpi,
+        int32_t * vals,
+        int count,
+        int tag
+) {
+    ggml_mpi_sync_pipelined(ctx_mpi, vals, count, MPI_INT32_T, tag);
+    int old_trans = ctx_mpi->trans_id;
+    ggml_mpi_sync_pipelined(ctx_mpi, &ctx_mpi->trans_id, 1, MPI_INT32_T, GGML_MPI_TRANS_ID);
+    ctx_mpi->recv_trans_id = ctx_mpi->trans_id;
+    ctx_mpi->trans_id = old_trans;
+}
+
 static void ggml_mpi_tensor_send(const struct ggml_tensor * t, const void* data, int mpi_rank_dst, MPI_Comm comm) {
     MPI_Datatype mpi_type;
@@ -549,6 +606,8 @@ GGML_CALL static enum ggml_status ggml_backend_mpi_graph_compute(ggml_backend_t
         }
     }
 
+    // TODO exploding memory usage because we replace the buffer with the wrapped buffer,
+    // but don't free the contexts, and then create new ones when we re-wrap
     if (!ctx->remote) {
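Note (illustrative, not part of the patch): ggml_mpi_sync_ints_pipelined is a collective over the pipeline; every rank must call it with the same tag, rank 0 supplies the real values, and each other rank receives them from its upstream neighbor and forwards them downstream, picking up the sender's transaction id in ctx_mpi->recv_trans_id along the way. A minimal sketch of the intended call pattern, assuming a llama_batch that is only valid on rank 0:

    // Returns the same n_tokens on every rank once the call completes.
    static int32_t example_sync_n_tokens(struct ggml_mpi_context * ctx_mpi, const struct llama_batch * batch /* valid on rank 0 only */) {
        int32_t n_tokens = (ggml_mpi_rank(ctx_mpi) == 0) ? batch->n_tokens : 0; // workers start with a placeholder
        ggml_mpi_sync_ints_pipelined(ctx_mpi, &n_tokens, 1, GGML_MPI_N_TOKENS);
        return n_tokens;
    }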
diff --git a/ggml-mpi.h b/ggml-mpi.h
index fe8358f2d..e7880b704 100644
--- a/ggml-mpi.h
+++ b/ggml-mpi.h
@@ -100,6 +100,26 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_mpi_wrap_buffer_type(
 
 GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_mpi_wrap_buffer(ggml_backend_buffer_t buf);
 
+void ggml_mpi_sync_ints_pipelined(
+        struct ggml_mpi_context * ctx_mpi,
+        int32_t * vals,
+        int count,
+        int tag
+);
+
+void ggml_mpi_sync_ints_pipelined_back(
+        struct ggml_mpi_context * ctx_mpi,
+        int32_t * vals,
+        int count,
+        int tag
+);
+// clear = 1, rm = 2, cp = 3, keep = 4, seq_shift = 5
+void ggml_mpi_probe(struct ggml_mpi_context * ctx_mpi, int src, int tag);
+int ggml_mpi_status_tag(struct ggml_mpi_context * ctx_mpi);
+
+int ggml_mpi_iprobe(struct ggml_mpi_context * ctx_mpi, int src, int tag);
+int ggml_mpi_status_count_int32(struct ggml_mpi_context * ctx_mpi);
+
 /**
  * Create a new context by splitting the given context's
  * communicator, creating a "sub-communicator." This is a collective
diff --git a/llama.cpp b/llama.cpp
index d20c17677..5d7d10318 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -9064,11 +9064,58 @@ static void llama_graph_compute(
 //
 static int llama_decode_internal(
          llama_context & lctx,
-          llama_batch   batch_all) { // TODO: rename back to batch
+          llama_batch & batch_all) { // TODO: rename back to batch
+#ifdef GGML_USE_MPI
+    if (ggml_mpi_rank(lctx.model.ctx_mpi) == 0 && ggml_mpi_size(lctx.model.ctx_mpi) > 1) {
+        int transaction_type = GGML_MPI_DECODE;
+        ggml_mpi_sync_ints_pipelined(lctx.model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
+    }
+//    ggml_mpi_sync_ints_pipelined(lctx.model.ctx_mpi, &batch_all.batch_id, 1, GGML_MPI_BATCH_ID);
+    int old_tokens = batch_all.n_tokens;
+
+    ggml_mpi_sync_ints_pipelined(lctx.model.ctx_mpi, &batch_all.n_tokens, 1, GGML_MPI_N_TOKENS);
+
+    ggml_mpi_sync_ints_pipelined(lctx.model.ctx_mpi, reinterpret_cast<int32_t *>(&lctx.cparams.n_seq_max), 1, GGML_MPI_MAX_N_SEQ);
+    if (ggml_mpi_rank(lctx.model.ctx_mpi) > 0) {
+        int new_n_tokens = batch_all.n_tokens;
+        llama_batch_free(batch_all);
+        batch_all = llama_batch_init(new_n_tokens, 0, (int32_t)lctx.cparams.n_seq_max);
+    }
+#endif
+    uint32_t n_tokens_all = batch_all.n_tokens;
+    std::vector<llama_pos>                 pos;
+    std::vector<int32_t>                   n_seq_id;
+    std::vector<llama_seq_id *>            seq_id_arr;
+    std::vector<std::vector<llama_seq_id>> seq_id;
+
+    if (batch_all.pos == nullptr) {
+        pos.resize(n_tokens_all);
+        for (uint32_t i = 0; i < n_tokens_all; i++) {
+            pos[i] = batch_all.all_pos_0 + i*batch_all.all_pos_1;
+        }
+
+        batch_all.pos = pos.data();
+    }
+
+    if (batch_all.seq_id == nullptr) {
+        n_seq_id.resize(n_tokens_all);
+        seq_id.resize(n_tokens_all);
+        seq_id_arr.resize(n_tokens_all);
+        for (uint32_t i = 0; i < n_tokens_all; i++) {
+            n_seq_id[i] = 1;
+            seq_id[i].resize(lctx.cparams.n_seq_max);
+            seq_id[i][0] = batch_all.all_seq_id;
+            seq_id_arr[i] = seq_id[i].data();
+        }
+
+        batch_all.n_seq_id = n_seq_id.data();
+        batch_all.seq_id = seq_id_arr.data();
+    }
+
 #ifdef GGML_USE_MPI
     ggml_mpi_eval_init(lctx.model.ctx_mpi, &(batch_all.n_tokens), &(batch_all.pos), &(batch_all.n_seq_id), &(batch_all.seq_id), &(batch_all.logits), lctx.cparams.n_seq_max);
     n_tokens_all = batch_all.n_tokens;
@@ -9114,10 +9161,6 @@ static int llama_decode_internal(
 
     const auto n_ubatch = cparams.n_ubatch;
 
-    std::vector<llama_pos>                 pos;
-    std::vector<int32_t>                   n_seq_id;
-    std::vector<llama_seq_id *>            seq_id_arr;
-    std::vector<std::vector<llama_seq_id>> seq_id;
 
     for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
         uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
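Note (illustrative, not part of the patch): under the changes above, one decode becomes a fixed message sequence that every worker replays. The sketch below shows the rank-0 side; ctx_mpi, n_tokens and n_seq_max stand in for lctx.model.ctx_mpi, batch_all.n_tokens and lctx.cparams.n_seq_max.

    int32_t transaction_type = GGML_MPI_DECODE;
    ggml_mpi_sync_ints_pipelined(ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION); // 1. announce the transaction type
    ggml_mpi_sync_ints_pipelined(ctx_mpi, &n_tokens, 1, GGML_MPI_N_TOKENS);                  // 2. batch size, so workers can size a fresh batch
    ggml_mpi_sync_ints_pipelined(ctx_mpi, &n_seq_max, 1, GGML_MPI_MAX_N_SEQ);                // 3. seq_id capacity for llama_batch_init
    // 4. ggml_mpi_eval_init() then streams logits, pos, n_seq_id and the flattened seq_id
    //    arrays with ggml_mpi_sync_pipelined() so every rank ends up with an identical batch.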
@@ -14344,6 +14387,87 @@ void llama_batch_free(struct llama_batch batch) {
     batch.logits = nullptr;
 }
 
+#ifdef GGML_USE_MPI
+
+int llama_process_mpi_transaction(
+        struct llama_context * ctx,
+        struct llama_batch & batch,
+        int tag) {
+//    if (ggml_mpi_rank(ctx->ctx_mpi) == ggml_mpi_size(ctx->ctx_mpi) - 1) {
+//        printf("\nBeginning transaction type %d\n", tag);
+//    }
+
+    switch (tag) {
+        case GGML_MPI_DECODE:
+//            llama_batch_free(batch);
+            return llama_decode_internal(*ctx, batch);
+            break;
+        case GGML_MPI_KV_CLEAR:
+            llama_kv_cache_clear(ctx);
+            break;
+        case GGML_MPI_KV_SEQ_RM:
+            llama_kv_cache_seq_rm(ctx, 1, -1, -1);
+            break;
+        case GGML_MPI_KV_SEQ_CP:
+            llama_kv_cache_seq_cp(ctx, 0, 0, 0, 0);
+            break;
+//        case GGML_MPI_KV_SEQ_CP_BACK:
+//            llama_kv_cache_seq_cp_back(ctx, 0, 0, 0, 0);
+//            break;
+//        case GGML_MPI_KV_SEQ_KEEP:
+//            llama_kv_cache_seq_keep(ctx, 0);
+//            break;
+//        case GGML_MPI_KV_SEQ_SHIFT:
+//            llama_kv_cache_seq_shift(ctx, 0, 0, 0, 0);
+//            break;
+        default:
+            printf("Unknown operation, exiting\n");
+            exit(1);
+            break;
+    }
+    return 0;
+}
+
+int llama_process_mpi_worker(
+        struct llama_context * ctx,
+        struct llama_batch & batch) {
+    ggml_mpi_probe(ctx->model.ctx_mpi, -1, -1);
+    int tag = ggml_mpi_status_tag(ctx->model.ctx_mpi);
+    int32_t count;
+    int32_t trans_type;
+//    if (ggml_mpi_rank(ctx->ctx_mpi) == ggml_mpi_size(ctx->ctx_mpi) - 1) {
+//        printf("\nReceived command %d\n", tag);
+//    }
+    switch (tag) {
+        case GGML_MPI_BEGIN_TRANSACTION:
+
+            ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &trans_type, 1, GGML_MPI_BEGIN_TRANSACTION);
+            return llama_process_mpi_transaction(ctx, batch, trans_type);
+            break;
+        case GGML_MPI_SHUTDOWN:
+            llama_free(ctx);
+            llama_backend_free();
+            exit(0);
+            break;
+        case GGML_MPI_CANCEL_RUN:
+//            count = ggml_mpi_status_count_int32(ctx->model.ctx_mpi);
+////            printf("Received cancel run\n");
+//            {
+//                std::vector canceled(count, -1);
+//                llama_cancel_run(ctx, canceled.data(), canceled.size());
+//
+//            }
+//            break;
+        default:
+            printf("Unknown operation, exiting\n");
+            exit(1);
+            break;
+    }
+    return 0;
+}
+
+#endif
+
 int32_t llama_decode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
@@ -14351,9 +14475,7 @@ int32_t llama_decode(
 #ifdef GGML_USE_MPI
     if (ggml_mpi_rank(ctx->model.ctx_mpi) > 0) {
         // Enter a blocking eval loop with dummy input, letting rank=0 drive the process
-        const int n_ctx = llama_n_ctx(ctx);
-        std::vector<llama_token> tmp(n_ctx, llama_token_bos(&ctx->model));
-        while (llama_decode_internal(*ctx, batch) >= 0){};
+        while (llama_process_mpi_worker(ctx, batch) >= 0){};
        llama_backend_free();
         exit(1);
     }
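Note (illustrative, not part of the patch): workers stay inside llama_process_mpi_worker until a GGML_MPI_SHUTDOWN tag arrives or an unknown tag aborts the process. The rank-0 side that emits the KV-cache tags is outside this diff; presumably each public KV-cache call is mirrored roughly as sketched below (the helper name is hypothetical, and syncing the seq_id/pos arguments for the other KV operations would need additional ggml_mpi_sync_ints_pipelined calls).

    // Hypothetical rank-0 counterpart to the GGML_MPI_KV_CLEAR branch above:
    // announce the transaction so the workers mirror it, then apply it locally.
    static void example_kv_clear_all_ranks(struct llama_context * ctx) {
        int32_t transaction_type = GGML_MPI_KV_CLEAR;
        ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
        llama_kv_cache_clear(ctx); // rank 0 performs the same clear it just asked the workers to do
    }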