From 9419190533f188be8a165519ee9a622d858381a9 Mon Sep 17 00:00:00 2001
From: Branden Butler
Date: Tue, 19 Mar 2024 15:01:39 -0500
Subject: [PATCH] Pipeline KV operations

---
 ggml-mpi.cpp |   2 +-
 ggml-mpi.h   |   4 +-
 llama.cpp    | 102 +++++++++++++++++++++++++++++++++++++++++++++++----
 3 files changed, 98 insertions(+), 10 deletions(-)

diff --git a/ggml-mpi.cpp b/ggml-mpi.cpp
index 9fa3e6717..f8c87f2d6 100644
--- a/ggml-mpi.cpp
+++ b/ggml-mpi.cpp
@@ -126,7 +126,7 @@ void ggml_mpi_sync_pipelined(
     }
     if(ctx_mpi->rank < ctx_mpi->size - 1) {
         GGML_ASSERT(ctx_mpi->send_buffer != nullptr);
-        GGML_ASSERT(val != nullptr);
+        GGML_ASSERT(val != nullptr || count == 0);
         GGML_ASSERT(count < 128*1024*1024);

         const int retval = MPI_Bsend(val, count, datatype, ggml_mpi_next_node(ctx_mpi), tag, ctx_mpi->comm);
diff --git a/ggml-mpi.h b/ggml-mpi.h
index e7880b704..d988a81e4 100644
--- a/ggml-mpi.h
+++ b/ggml-mpi.h
@@ -22,7 +22,7 @@ extern "C" {

 #define GGML_MPI_KV_SEQ_KEEP 4

-#define GGML_MPI_KV_SEQ_SHIFT 5
+#define GGML_MPI_KV_SEQ_ADD 5

 #define GGML_MPI_SHUTDOWN 6

@@ -54,6 +54,8 @@ extern "C" {

 #define GGML_MPI_BATCH_LOGITS 20

+#define GGML_MPI_KV_SEQ_DIV 21
+
 /**
diff --git a/llama.cpp b/llama.cpp
index 5d7d10318..a6392237b 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -13845,14 +13845,56 @@ int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx) {
 }

 void llama_kv_cache_clear(struct llama_context * ctx) {
+#ifdef GGML_USE_MPI
+    if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
+        int transaction_type = GGML_MPI_KV_CLEAR;
+        ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
+    }
+    ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, nullptr, 0, GGML_MPI_KV_CLEAR);
+#endif
     llama_kv_cache_clear(ctx->kv_self);
 }

 bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+#ifdef GGML_USE_MPI
+    if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
+        int transaction_type = GGML_MPI_KV_SEQ_RM;
+        ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
+    }
+    int32_t vals[3] = {seq_id, p0, p1};
+    ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, vals, 3, GGML_MPI_KV_SEQ_RM);
+    seq_id = vals[0];
+    p0 = vals[1];
+    p1 = vals[2];
+//    if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
+//        printf("\nRemoving sequence %d from %d to %d\n", seq_id, p0, p1);
+//    }
+#endif
     return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
 }

 void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+#ifdef GGML_USE_MPI
+    if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
+        int transaction_type = GGML_MPI_KV_SEQ_CP;
+        ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
+    }
+
+    int32_t vals[4] = {seq_id_src, seq_id_dst, p0, p1};
+    ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, vals, 4, GGML_MPI_KV_SEQ_CP);
+//    if(ggml_mpi_recv_trans_id(ctx->model.ctx_mpi) < ggml_mpi_trans_id(ctx->model.ctx_mpi)) {
+////        return;
+//    }
+//    ggml_mpi_inc_trans_id(ctx->model.ctx_mpi);
+    seq_id_src = vals[0];
+    seq_id_dst = vals[1];
+    p0 = vals[2];
+    p1 = vals[3];
+//    if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
+//        printf("\nCopying sequence %d to sequence %d from %d to %d\n", seq_id_src, seq_id_dst, p0, p1);
+//    }
+#endif
+
     if (seq_id_src == seq_id_dst) {
         return;
     }
@@ -13860,10 +13902,35 @@ void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src,
 }

 void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
+#ifdef GGML_USE_MPI
+    if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
+        int transaction_type = GGML_MPI_KV_SEQ_KEEP;
+        ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
+    }
+    int32_t vals[1] = {seq_id};
+    ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, vals, 1, GGML_MPI_KV_SEQ_KEEP);
+    seq_id = vals[0];
+#endif
     llama_kv_cache_seq_keep(ctx->kv_self, seq_id);
 }

 void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
+#ifdef GGML_USE_MPI
+    if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
+        int transaction_type = GGML_MPI_KV_SEQ_ADD;
+        ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
+    }
+    int32_t vals[4] = {seq_id, p0, p1, delta};
+    ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, vals, 4, GGML_MPI_KV_SEQ_ADD);
+    seq_id = vals[0];
+    p0 = vals[1];
+    p1 = vals[2];
+    delta = vals[3];
+//    if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
+//        printf("\nRemoving sequence %d from %d to %d\n", seq_id, p0, p1);
+//    }
+#endif
+
     if (delta == 0) {
         return;
     }
@@ -13872,6 +13939,22 @@ void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, lla
 }

 void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+#ifdef GGML_USE_MPI
+    if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
+        int transaction_type = GGML_MPI_KV_SEQ_DIV;
+        ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
+    }
+    int32_t vals[4] = {seq_id, p0, p1, d};
+    ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, vals, 4, GGML_MPI_KV_SEQ_DIV);
+    seq_id = vals[0];
+    p0 = vals[1];
+    p1 = vals[2];
+    d = vals[3];
+//    if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
+//        printf("\nRemoving sequence %d from %d to %d\n", seq_id, p0, p1);
+//    }
+#endif
+
     if (d == 1) {
         return;
     }
@@ -14393,7 +14476,7 @@ int llama_process_mpi_transaction(
         struct llama_context * ctx,
         struct llama_batch & batch,
         int tag) {
-//    if (ggml_mpi_rank(ctx->ctx_mpi) == ggml_mpi_size(ctx->ctx_mpi) - 1) {
+//    if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1) {
 //        printf("\nBeginning transaction type %d\n", tag);
 //    }

@@ -14414,12 +14497,15 @@ int llama_process_mpi_transaction(
 //        case GGML_MPI_KV_SEQ_CP_BACK:
 //            llama_kv_cache_seq_cp_back(ctx, 0, 0, 0, 0);
 //            break;
-//        case GGML_MPI_KV_SEQ_KEEP:
-//            llama_kv_cache_seq_keep(ctx, 0);
-//            break;
-//        case GGML_MPI_KV_SEQ_SHIFT:
-//            llama_kv_cache_seq_shift(ctx, 0, 0, 0, 0);
-//            break;
+        case GGML_MPI_KV_SEQ_KEEP:
+            llama_kv_cache_seq_keep(ctx, 0);
+            break;
+        case GGML_MPI_KV_SEQ_ADD:
+            llama_kv_cache_seq_add(ctx, 0, 0, 0, 0);
+            break;
+        case GGML_MPI_KV_SEQ_DIV:
+            llama_kv_cache_seq_div(ctx, 0, 0, 0, 0);
+            break;
         default:
             printf("Unknown operation, exiting\n");
             exit(1);
@@ -14435,7 +14521,7 @@ int llama_process_mpi_worker(
     int tag = ggml_mpi_status_tag(ctx->model.ctx_mpi);
     int32_t count;
     int32_t trans_type;
-//    if (ggml_mpi_rank(ctx->ctx_mpi) == ggml_mpi_size(ctx->ctx_mpi) - 1) {
+//    if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1) {
 //        printf("\nReceived command %d\n", tag);
 //    }
     switch (tag) {
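
For context: every KV-cache hook in this patch follows the same pattern. Rank 0 announces a transaction type with GGML_MPI_BEGIN_TRANSACTION, then all ranks sync the operation's integer arguments down the node pipeline before applying the same edit to their local KV cache. Below is a minimal, stand-alone sketch of that pipelined sync; it is not part of the patch. Only the MPI_Bsend-to-next-node send side is visible in the ggml-mpi.cpp hunk above, so the receive-from-predecessor step and the use of plain MPI_Send here are illustrative assumptions, not the actual ggml-mpi implementation.

// pipelined_sync_sketch.cpp -- illustrative only, not part of the patch.
// Sketch of the pipelined broadcast pattern: rank 0 originates the values,
// each other rank receives them from its predecessor, and every rank except
// the last forwards them to its successor.
// Build/run (example): mpicxx pipelined_sync_sketch.cpp && mpirun -np 4 ./a.out
#include <mpi.h>
#include <cstdio>
#include <cstdint>

static void sync_ints_pipelined(int32_t * vals, int count, int tag, MPI_Comm comm) {
    int rank, size;
    MPI_Comm_rank(comm, &rank);
    MPI_Comm_size(comm, &size);
    if (rank > 0) {
        // Receive the values from the previous node in the pipeline.
        MPI_Recv(vals, count, MPI_INT32_T, rank - 1, tag, comm, MPI_STATUS_IGNORE);
    }
    if (rank < size - 1) {
        // Forward to the next node. The patch uses buffered MPI_Bsend here so
        // the sender does not block; plain MPI_Send keeps this sketch self-contained.
        MPI_Send(vals, count, MPI_INT32_T, rank + 1, tag, comm);
    }
}

int main(int argc, char ** argv) {
    MPI_Init(&argc, &argv);
    int rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    // Mirrors llama_kv_cache_seq_add above: pack {seq_id, p0, p1, delta},
    // sync the arguments down the pipeline, then each rank applies the same
    // edit to its local slice of the KV cache.
    int32_t vals[4] = {0, 0, 128, 32};
    sync_ints_pipelined(vals, 4, /*tag=*/5 /* GGML_MPI_KV_SEQ_ADD */, MPI_COMM_WORLD);
    printf("rank %d applies seq_add(seq=%d, p0=%d, p1=%d, delta=%d)\n",
           rank, (int) vals[0], (int) vals[1], (int) vals[2], (int) vals[3]);

    MPI_Finalize();
    return 0;
}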