Pipeline KV operations
parent d2de181d95
commit 9419190533
3 changed files with 98 additions and 10 deletions
@@ -126,7 +126,7 @@ void ggml_mpi_sync_pipelined(
     }
     if(ctx_mpi->rank < ctx_mpi->size - 1) {
         GGML_ASSERT(ctx_mpi->send_buffer != nullptr);
-        GGML_ASSERT(val != nullptr);
+        GGML_ASSERT(val != nullptr || count == 0);
         GGML_ASSERT(count < 128*1024*1024);

         const int retval = MPI_Bsend(val, count, datatype, ggml_mpi_next_node(ctx_mpi), tag, ctx_mpi->comm);
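A note on the MPI_Bsend call above: buffered sends only succeed if the process has attached enough buffer space beforehand. The sketch below shows the standard MPI_Buffer_attach pattern sized to match the 128*1024*1024 assertion; where and how ggml-mpi actually attaches its send_buffer is not shown in this hunk, so the helper name and sizing here are assumptions.

    // Minimal sketch (assumption, not part of this commit): attach a pool for
    // buffered sends once at startup so MPI_Bsend does not fail for lack of space.
    #include <mpi.h>
    #include <stdlib.h>

    static void attach_bsend_pool(void) {
        const int pool_size = 128*1024*1024 + MPI_BSEND_OVERHEAD; // mirrors the assert above
        void * pool = malloc(pool_size);
        MPI_Buffer_attach(pool, pool_size);
    }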
@@ -22,7 +22,7 @@ extern "C" {

 #define GGML_MPI_KV_SEQ_KEEP 4

-#define GGML_MPI_KV_SEQ_SHIFT 5
+#define GGML_MPI_KV_SEQ_ADD 5

 #define GGML_MPI_SHUTDOWN 6
@@ -54,6 +54,8 @@ extern "C" {

 #define GGML_MPI_BATCH_LOGITS 20

+#define GGML_MPI_KV_SEQ_DIV 21
+

 /**
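The llama.cpp hunks below all follow one framing pattern for these tags: rank 0 first announces the operation type under GGML_MPI_BEGIN_TRANSACTION, then every rank syncs the operation's arguments under the operation's own tag. A condensed sketch of that pattern, using names that appear in this diff (the helper function itself and the ggml_mpi_context type spelling are assumptions, not part of the commit):

    // Hedged sketch of the transaction framing used by the KV wrappers below.
    static void mpi_kv_transaction(struct ggml_mpi_context * ctx_mpi,
                                   int op_tag, int32_t * args, int n_args) {
        // rank 0 tells the pipeline which KV operation is coming
        if (ggml_mpi_rank(ctx_mpi) == 0 && ggml_mpi_size(ctx_mpi) > 1) {
            int transaction_type = op_tag;
            ggml_mpi_sync_ints_pipelined(ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
        }
        // every rank then syncs the arguments, so workers end up with rank 0's values
        ggml_mpi_sync_ints_pipelined(ctx_mpi, args, n_args, op_tag);
    }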
llama.cpp (102 changes)
@@ -13845,14 +13845,56 @@ int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx) {
 }

 void llama_kv_cache_clear(struct llama_context * ctx) {
+#ifdef GGML_USE_MPI
+    if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
+        int transaction_type = GGML_MPI_KV_CLEAR;
+        ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
+    }
+    ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, nullptr, 0, GGML_MPI_KV_CLEAR);
+#endif
     llama_kv_cache_clear(ctx->kv_self);
 }

 bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+#ifdef GGML_USE_MPI
+    if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
+        int transaction_type = GGML_MPI_KV_SEQ_RM;
+        ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
+    }
+    int32_t vals[3] = {seq_id, p0, p1};
+    ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, vals, 3, GGML_MPI_KV_SEQ_RM);
+    seq_id = vals[0];
+    p0 = vals[1];
+    p1 = vals[2];
+//    if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
+//        printf("\nRemoving sequence %d from %d to %d\n", seq_id, p0, p1);
+//    }
+#endif
     return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
 }

 void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+#ifdef GGML_USE_MPI
+    if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
+        int transaction_type = GGML_MPI_KV_SEQ_CP;
+        ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
+    }
+
+    int32_t vals[4] = {seq_id_src, seq_id_dst, p0, p1};
+    ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, vals, 4, GGML_MPI_KV_SEQ_CP);
+//    if(ggml_mpi_recv_trans_id(ctx->model.ctx_mpi) < ggml_mpi_trans_id(ctx->model.ctx_mpi)) {
+////        return;
+//    }
+//    ggml_mpi_inc_trans_id(ctx->model.ctx_mpi);
+    seq_id_src = vals[0];
+    seq_id_dst = vals[1];
+    p0 = vals[2];
+    p1 = vals[3];
+//    if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
+//        printf("\nCopying sequence %d to sequence %d from %d to %d\n", seq_id_src, seq_id_dst, p0, p1);
+//    }
+#endif
+
     if (seq_id_src == seq_id_dst) {
         return;
     }
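For reference, a hedged usage sketch of the wrappers above as seen from rank 0 (the sequence ids and positions are illustrative and this caller is not part of the commit); each wrapper syncs its arguments down the pipeline before mutating the local cache, so worker ranks apply the same edit with the values sent by rank 0.

    static void example_kv_edits(struct llama_context * ctx) {
        llama_kv_cache_clear(ctx);                                              // wipe the cache on every rank
        llama_kv_cache_seq_rm(ctx, /*seq_id=*/0, /*p0=*/128, /*p1=*/-1);        // drop the tail of sequence 0
        llama_kv_cache_seq_cp(ctx, /*src=*/0, /*dst=*/1, /*p0=*/0, /*p1=*/128); // fork its prefix into sequence 1
    }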
@@ -13860,10 +13902,35 @@ void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
 }

 void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
+#ifdef GGML_USE_MPI
+    if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
+        int transaction_type = GGML_MPI_KV_SEQ_KEEP;
+        ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
+    }
+    int32_t vals[1] = {seq_id};
+    ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, vals, 1, GGML_MPI_KV_SEQ_KEEP);
+    seq_id = vals[0];
+#endif
     llama_kv_cache_seq_keep(ctx->kv_self, seq_id);
 }

 void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
+#ifdef GGML_USE_MPI
+    if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
+        int transaction_type = GGML_MPI_KV_SEQ_ADD;
+        ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
+    }
+    int32_t vals[4] = {seq_id, p0, p1, delta};
+    ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, vals, 4, GGML_MPI_KV_SEQ_ADD);
+    seq_id = vals[0];
+    p0 = vals[1];
+    p1 = vals[2];
+    delta = vals[3];
+//    if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
+//        printf("\nRemoving sequence %d from %d to %d\n", seq_id, p0, p1);
+//    }
+#endif
+
     if (delta == 0) {
         return;
     }
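llama_kv_cache_seq_add is what context shifting relies on; a hedged sketch of the usual caller pattern follows (values are illustrative and this caller is not part of the commit):

    static void example_context_shift(struct llama_context * ctx) {
        const int n_keep    = 32;   // tokens to preserve at the start
        const int n_discard = 256;  // tokens to drop after them
        // remove a block, then slide the remaining positions back over the gap
        llama_kv_cache_seq_rm (ctx, 0, n_keep,             n_keep + n_discard);
        llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, -1,       -n_discard);
    }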
@@ -13872,6 +13939,22 @@ void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
 }

 void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+#ifdef GGML_USE_MPI
+    if (ggml_mpi_rank(ctx->model.ctx_mpi) == 0 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
+        int transaction_type = GGML_MPI_KV_SEQ_DIV;
+        ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
+    }
+    int32_t vals[4] = {seq_id, p0, p1, d};
+    ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, vals, 4, GGML_MPI_KV_SEQ_DIV);
+    seq_id = vals[0];
+    p0 = vals[1];
+    p1 = vals[2];
+    d = vals[3];
+//    if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1 && ggml_mpi_size(ctx->model.ctx_mpi) > 1) {
+//        printf("\nRemoving sequence %d from %d to %d\n", seq_id, p0, p1);
+//    }
+#endif
+
     if (d == 1) {
         return;
     }
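llama_kv_cache_seq_div compresses the positions in the affected range by an integer factor, which is what grouped self-extend style callers use. A hedged usage sketch with illustrative values:

    static void example_position_compress(struct llama_context * ctx) {
        // divide the positions of sequence 0 in [0, 1024) by 4
        llama_kv_cache_seq_div(ctx, /*seq_id=*/0, /*p0=*/0, /*p1=*/1024, /*d=*/4);
    }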
@@ -14393,7 +14476,7 @@ int llama_process_mpi_transaction(
         struct llama_context * ctx,
         struct llama_batch & batch,
         int tag) {
-//    if (ggml_mpi_rank(ctx->ctx_mpi) == ggml_mpi_size(ctx->ctx_mpi) - 1) {
+//    if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1) {
 //        printf("\nBeginning transaction type %d\n", tag);
 //    }

@@ -14414,12 +14497,15 @@ int llama_process_mpi_transaction(
 //        case GGML_MPI_KV_SEQ_CP_BACK:
 //            llama_kv_cache_seq_cp_back(ctx, 0, 0, 0, 0);
 //            break;
-//        case GGML_MPI_KV_SEQ_KEEP:
-//            llama_kv_cache_seq_keep(ctx, 0);
-//            break;
-//        case GGML_MPI_KV_SEQ_SHIFT:
-//            llama_kv_cache_seq_shift(ctx, 0, 0, 0, 0);
-//            break;
+        case GGML_MPI_KV_SEQ_KEEP:
+            llama_kv_cache_seq_keep(ctx, 0);
+            break;
+        case GGML_MPI_KV_SEQ_ADD:
+            llama_kv_cache_seq_add(ctx, 0, 0, 0, 0);
+            break;
+        case GGML_MPI_KV_SEQ_DIV:
+            llama_kv_cache_seq_div(ctx, 0, 0, 0, 0);
+            break;
         default:
             printf("Unknown operation, exiting\n");
             exit(1);
@@ -14435,7 +14521,7 @@ int llama_process_mpi_worker(
     int tag = ggml_mpi_status_tag(ctx->model.ctx_mpi);
     int32_t count;
     int32_t trans_type;
-//    if (ggml_mpi_rank(ctx->ctx_mpi) == ggml_mpi_size(ctx->ctx_mpi) - 1) {
+//    if (ggml_mpi_rank(ctx->model.ctx_mpi) == ggml_mpi_size(ctx->model.ctx_mpi) - 1) {
 //        printf("\nReceived command %d\n", tag);
 //    }
     switch (tag) {
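Pieced together from the hunks above, a hedged sketch of the worker-side dispatch for one KV transaction; the wrapper function and the exact call site are assumptions, only llama_process_mpi_transaction's signature and the tag constants come from this diff:

    static void serve_one_kv_transaction(struct llama_context * ctx, struct llama_batch & batch) {
        int32_t trans_type = 0;
        // receive the operation type that rank 0 announced under GGML_MPI_BEGIN_TRANSACTION
        ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &trans_type, 1, GGML_MPI_BEGIN_TRANSACTION);
        // re-enter the public API with dummy arguments; each wrapper's own sync call
        // then overwrites them with the values sent by rank 0
        llama_process_mpi_transaction(ctx, batch, trans_type);
    }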