Pipeline KV operations

This commit is contained in:
Branden Butler 2024-03-19 15:01:39 -05:00
parent d2de181d95
commit 9419190533
3 changed files with 98 additions and 10 deletions

View file

@ -126,7 +126,7 @@ void ggml_mpi_sync_pipelined(
}
if(ctx_mpi->rank < ctx_mpi->size - 1) {
GGML_ASSERT(ctx_mpi->send_buffer != nullptr);
GGML_ASSERT(val != nullptr);
GGML_ASSERT(val != nullptr || count == 0);
GGML_ASSERT(count < 128*1024*1024);
const int retval = MPI_Bsend(val, count, datatype, ggml_mpi_next_node(ctx_mpi), tag, ctx_mpi->comm);

View file

@ -22,7 +22,7 @@ extern "C" {
#define GGML_MPI_KV_SEQ_KEEP 4
#define GGML_MPI_KV_SEQ_SHIFT 5
#define GGML_MPI_KV_SEQ_ADD 5
#define GGML_MPI_SHUTDOWN 6
@ -54,6 +54,8 @@ extern "C" {
#define GGML_MPI_BATCH_LOGITS 20
#define GGML_MPI_KV_SEQ_DIV 21
/**

102
llama.cpp
View file

@ -13845,14 +13845,56 @@ int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx) {
}
void llama_kv_cache_clear(struct llama_context * ctx) {
#ifdef GGML_USE_MPI
if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
int transaction_type = GGML_MPI_KV_CLEAR;
ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
}
ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, nullptr, 0, GGML_MPI_KV_CLEAR);
#endif
llama_kv_cache_clear(ctx->kv_self);
}
bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
#ifdef GGML_USE_MPI
if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
int transaction_type = GGML_MPI_KV_SEQ_RM;
ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
}
int32_t vals[3] = {seq_id, p0, p1};
ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, vals, 3, GGML_MPI_KV_SEQ_RM);
seq_id = vals[0];
p0 = vals[1];
p1 = vals[2];
// if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
// printf("\nRemoving sequence %d from %d to %d\n", seq_id, p0, p1);
// }
#endif
return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
}
void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
#ifdef GGML_USE_MPI
if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
int transaction_type = GGML_MPI_KV_SEQ_CP;
ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
}
int32_t vals[4] = {seq_id_src, seq_id_dst, p0, p1};
ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, vals, 4, GGML_MPI_KV_SEQ_CP);
// if(ggml_mpi_recv_trans_id(ctx->model.ctx_mpi) < ggml_mpi_trans_id(ctx->model.ctx_mpi)) {
//// return;
// }
// ggml_mpi_inc_trans_id(ctx->model.ctx_mpi);
seq_id_src = vals[0];
seq_id_dst = vals[1];
p0 = vals[2];
p1 = vals[3];
// if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
// printf("\nCopying sequence %d to sequence %d from %d to %d\n", seq_id_src, seq_id_dst, p0, p1);
// }
#endif
if (seq_id_src == seq_id_dst) {
return;
}
@ -13860,10 +13902,35 @@ void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src,
}
void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
#ifdef GGML_USE_MPI
if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
int transaction_type = GGML_MPI_KV_SEQ_KEEP;
ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
}
int32_t vals[1] = {seq_id};
ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, vals, 1, GGML_MPI_KV_SEQ_KEEP);
seq_id = vals[0];
#endif
llama_kv_cache_seq_keep(ctx->kv_self, seq_id);
}
void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
#ifdef GGML_USE_MPI
if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
int transaction_type = GGML_MPI_KV_SEQ_ADD;
ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
}
int32_t vals[4] = {seq_id, p0, p1, delta};
ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, vals, 4, GGML_MPI_KV_SEQ_ADD);
seq_id = vals[0];
p0 = vals[1];
p1 = vals[2];
delta = vals[3];
// if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
// printf("\nRemoving sequence %d from %d to %d\n", seq_id, p0, p1);
// }
#endif
if (delta == 0) {
return;
}
@ -13872,6 +13939,22 @@ void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, lla
}
void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
#ifdef GGML_USE_MPI
if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
int transaction_type = GGML_MPI_KV_SEQ_DIV;
ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
}
int32_t vals[4] = {seq_id, p0, p1, d};
ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, vals, 4, GGML_MPI_KV_SEQ_DIV);
seq_id = vals[0];
p0 = vals[1];
p1 = vals[2];
d = vals[3];
// if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
// printf("\nRemoving sequence %d from %d to %d\n", seq_id, p0, p1);
// }
#endif
if (d == 1) {
return;
}
@ -14393,7 +14476,7 @@ int llama_process_mpi_transaction(
struct llama_context * ctx,
struct llama_batch & batch,
int tag) {
// if (ggml_mpi_rank(ctx->ctx_mpi) == ggml_mpi_size(ctx->ctx_mpi) - 1) {
// if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1) {
// printf("\nBeginning transaction type %d\n", tag);
// }
@ -14414,12 +14497,15 @@ int llama_process_mpi_transaction(
// case GGML_MPI_KV_SEQ_CP_BACK:
// llama_kv_cache_seq_cp_back(ctx, 0, 0, 0, 0);
// break;
// case GGML_MPI_KV_SEQ_KEEP:
// llama_kv_cache_seq_keep(ctx, 0);
// break;
// case GGML_MPI_KV_SEQ_SHIFT:
// llama_kv_cache_seq_shift(ctx, 0, 0, 0, 0);
// break;
case GGML_MPI_KV_SEQ_KEEP:
llama_kv_cache_seq_keep(ctx, 0);
break;
case GGML_MPI_KV_SEQ_ADD:
llama_kv_cache_seq_add(ctx, 0, 0, 0, 0);
break;
case GGML_MPI_KV_SEQ_DIV:
llama_kv_cache_seq_div(ctx, 0, 0, 0, 0);
break;
default:
printf("Unknown operation, exiting\n");
exit(1);
@ -14435,7 +14521,7 @@ int llama_process_mpi_worker(
int tag = ggml_mpi_status_tag(ctx->model.ctx_mpi);
int32_t count;
int32_t trans_type;
// if (ggml_mpi_rank(ctx->ctx_mpi) == ggml_mpi_size(ctx->ctx_mpi) - 1) {
// if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1) {
// printf("\nReceived command %d\n", tag);
// }
switch (tag) {