Port transactions from mpi-speculative, fix incorrect seq_id syncing (not tested)
parent 57ac2e7268
commit d2de181d95
3 changed files with 248 additions and 47 deletions
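The seq_id fix referenced in the commit message re-enables the flatten-and-rebuild sync shown in the ggml-mpi.cpp hunks below: per-token sequence-id lists are packed into one contiguous buffer for transit and unpacked again on the receiving ranks. A minimal standalone sketch of that pattern (plain C++, no MPI; the names and the simulated transfer are illustrative only, not the ggml-mpi API):

// flatten_seq_ids_sketch.cpp -- illustrative only; the real logic lives in ggml_mpi_eval_init()
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // Ragged per-token sequence ids, as rank 0 would have them.
    std::vector<std::vector<int32_t>> seq_id = {{0}, {0, 1}, {2}};

    std::vector<int32_t> n_seq_ids;  // per-token counts, synced first
    std::vector<int32_t> flattened;  // contiguous buffer MPI can send in a single call
    for (const auto & ids : seq_id) {
        n_seq_ids.push_back((int32_t) ids.size());
        flattened.insert(flattened.end(), ids.begin(), ids.end());
    }

    // ... in the real code, n_seq_ids and flattened travel via ggml_mpi_sync_pipelined() here ...

    // Receiving side: rebuild the ragged layout from the counts plus the flat buffer.
    std::vector<std::vector<int32_t>> rebuilt(n_seq_ids.size());
    size_t idx = 0;
    for (size_t i = 0; i < n_seq_ids.size(); i++) {
        rebuilt[i].assign(flattened.begin() + idx, flattened.begin() + idx + n_seq_ids[i]);
        idx += n_seq_ids[i];
    }
    printf("token 1 has %zu seq ids\n", rebuilt[1].size());  // prints 2
    return 0;
}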
137 ggml-mpi.cpp

@@ -6,8 +6,8 @@
#include <mpi.h>

#include <stdio.h>
#include <stdlib.h>
#include <cstdio>
#include <cstdlib>
#include <vector>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

@@ -24,6 +24,8 @@ struct ggml_mpi_context {
    MPI_Comm comm;
    int layer_start;
    int layer_end;
    MPI_Status status;

    struct ggml_tensor *inp0;
    std::string name;
    struct ggml_backend * wrapped_backend;
@@ -31,6 +33,8 @@ struct ggml_mpi_context {
    ggml_backend_sched_t scheduler;
    bool remote;
    void* send_buffer;
    int trans_id;
    int recv_trans_id;
};

void ggml_mpi_backend_init(void) {
@@ -122,12 +126,43 @@ void ggml_mpi_sync_pipelined(
    }
    if(ctx_mpi->rank < ctx_mpi->size - 1) {
        GGML_ASSERT(ctx_mpi->send_buffer != nullptr);
        GGML_ASSERT(val != nullptr);
        GGML_ASSERT(count < 128*1024*1024);

        const int retval = MPI_Bsend(val, count, datatype, ggml_mpi_next_node(ctx_mpi), tag, ctx_mpi->comm);
        GGML_ASSERT(retval == MPI_SUCCESS);

    }
}

void ggml_mpi_barrier(struct ggml_mpi_context * ctx_mpi) {
    MPI_Barrier(ctx_mpi->comm);
}

void ggml_mpi_probe(struct ggml_mpi_context * ctx_mpi, int src, int tag) {
    MPI_Probe((src >= 0) ? src : MPI_ANY_SOURCE, (tag >= 0) ? tag : MPI_ANY_TAG, ctx_mpi->comm, &(ctx_mpi->status));
}

int ggml_mpi_iprobe(struct ggml_mpi_context * ctx_mpi, int src, int tag) {
    if(ctx_mpi->comm == MPI_COMM_NULL) {
        return 0;
    }

    int ret;
    MPI_Iprobe((src >= 0) ? src : MPI_ANY_SOURCE, (tag >= 0) ? tag : MPI_ANY_TAG, ctx_mpi->comm, &ret, &(ctx_mpi->status));
    return ret;
}

int ggml_mpi_status_tag(struct ggml_mpi_context * ctx_mpi) {
    return ctx_mpi->status.MPI_TAG;
}

int ggml_mpi_status_count_int32(struct ggml_mpi_context * ctx_mpi) {
    int32_t count;
    MPI_Get_count(&ctx_mpi->status, MPI_INT32_T, &count);
    return count;
}

void ggml_mpi_eval_init(
        struct ggml_mpi_context * ctx_mpi,
        int32_t * n_tokens,
@@ -142,8 +177,15 @@ void ggml_mpi_eval_init(
        return;
    }

    int32_t old_n_tokens = *n_tokens;
    ggml_mpi_sync_pipelined(ctx_mpi, n_tokens, 1, MPI_INT, GGML_MPI_N_TOKENS);

    if (old_n_tokens != *n_tokens) {
        *pos = static_cast<int32_t *>(realloc(*pos, *n_tokens * sizeof(int32_t)));
        *n_seq_ids = static_cast<int32_t *>(realloc(*n_seq_ids, *n_tokens * sizeof(int32_t)));
        *logits = static_cast<int8_t *>(realloc(*logits, *n_tokens * sizeof(int32_t)));
    }

    int8_t* temp_logits = (int8_t*) calloc(*n_tokens, sizeof(int8_t));

    if (ctx_mpi->rank == 0 && *logits != nullptr) {
@@ -183,49 +225,51 @@ void ggml_mpi_eval_init(
    // pre-allocated for the largest possible sizes, even on worker nodes.

    GGML_ASSERT(n_seq_ids != nullptr);
    GGML_ASSERT(*n_seq_ids != nullptr);

    GGML_ASSERT(n_tokens != nullptr);

    // FIXME Syncing n_seq_ids causes MPI to throw an invalid buffer error in Bsend
    // ggml_mpi_sync_pipelined(ctx_mpi, *n_seq_ids, *n_tokens, MPI_INT32_T, GGML_MPI_N_SEQ_IDS);
    ggml_mpi_sync_pipelined(ctx_mpi, *n_seq_ids, *n_tokens, MPI_INT32_T, GGML_MPI_N_SEQ_IDS);

    // We need to know the total number of sequence
    // ids, so we count them all up
    // int32_t total_n_seq_ids = 0;
    // for (int32_t i = 0; i < *n_tokens; i++) {
    //     total_n_seq_ids += (*n_seq_ids)[i];
    // }
    //
    // // MPI can't chase the pointers for multidimensional arrays, so we flatten them first
    // // for transit
    // int32_t * flattened_seq_ids = static_cast<int32_t *>(calloc(total_n_seq_ids, sizeof(int32_t)));
    //
    // int32_t current_index = 0;
    //
    // // Only rank 0 needs to flatten since the others don't have the real seq_id
    // if (ctx_mpi->rank == 0) {
    //     for (int32_t i = 0; i < *n_tokens; i++) {
    //         for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
    //             flattened_seq_ids[current_index] = (*seq_id)[i][j];
    //             current_index++;
    //         }
    //     }
    // }
    //
    //
    //
    // ggml_mpi_sync_pipelined(ctx_mpi, *pos, *n_tokens, MPI_INT32_T, GGML_MPI_POS);
    // ggml_mpi_sync_pipelined(ctx_mpi, flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, GGML_MPI_SEQ_IDS);
    //
    // current_index = 0;
    // for (int32_t i = 0; i < *n_tokens; i++) {
    //     for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
    //         (*seq_id)[i][j] = flattened_seq_ids[current_index];
    //         current_index++;
    //     }
    //
    // }
    // free(flattened_seq_ids);
    int32_t total_n_seq_ids = 0;
    for (int32_t i = 0; i < *n_tokens; i++) {
        total_n_seq_ids += (*n_seq_ids)[i];
    }

    // MPI can't chase the pointers for multidimensional arrays, so we flatten them first
    // for transit
    int32_t * flattened_seq_ids = static_cast<int32_t *>(calloc(total_n_seq_ids, sizeof(int32_t)));

    int32_t current_index = 0;

    // Only rank 0 needs to flatten since the others don't have the real seq_id
    if (ctx_mpi->rank == 0) {
        for (int32_t i = 0; i < *n_tokens; i++) {
            for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
                flattened_seq_ids[current_index] = (*seq_id)[i][j];
                current_index++;
            }
        }
    }

    ggml_mpi_sync_pipelined(ctx_mpi, *pos, *n_tokens, MPI_INT32_T, GGML_MPI_POS);
    ggml_mpi_sync_pipelined(ctx_mpi, flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, GGML_MPI_SEQ_IDS);

    current_index = 0;
    for (int32_t i = 0; i < *n_tokens; i++) {
        for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
            (*seq_id)[i][j] = flattened_seq_ids[current_index];
            current_index++;
        }

    }
    free(flattened_seq_ids);
}

@@ -236,6 +280,19 @@ void ggml_mpi_sync_int(
    MPI_Bcast(val, 1, MPI_INT32_T, 0, ctx_mpi->comm);
}

void ggml_mpi_sync_ints_pipelined(
        struct ggml_mpi_context * ctx_mpi,
        int32_t * vals,
        int count,
        int tag
) {
    ggml_mpi_sync_pipelined(ctx_mpi, vals, count, MPI_INT32_T, tag);
    int old_trans = ctx_mpi->trans_id;
    ggml_mpi_sync_pipelined(ctx_mpi, &ctx_mpi->trans_id, 1, MPI_INT32_T, GGML_MPI_TRANS_ID);
    ctx_mpi->recv_trans_id = ctx_mpi->trans_id;
    ctx_mpi->trans_id = old_trans;
}

static void ggml_mpi_tensor_send(const struct ggml_tensor * t, const void* data, int mpi_rank_dst, MPI_Comm comm) {
    MPI_Datatype mpi_type;

@@ -549,6 +606,8 @@ GGML_CALL static enum ggml_status ggml_backend_mpi_graph_compute(ggml_backend_t
        }
    }

    // TODO exploding memory usage cause we replace the buffer with the wrapped buffer,
    // but don't free the contexts, and then create new ones when we re-wrap

    if (!ctx->remote) {
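A note on the ggml_mpi_sync_pipelined change above: MPI_Bsend only succeeds when the application has attached a buffer for buffered sends, which is why the new code asserts that ctx_mpi->send_buffer is non-null before sending. A minimal standalone sketch of that requirement (the buffer size and the ring-style exchange here are illustrative only):

// bsend_sketch.cpp -- run with: mpirun -np 2 ./bsend_sketch
#include <mpi.h>
#include <cstdio>
#include <cstdlib>

int main(int argc, char ** argv) {
    MPI_Init(&argc, &argv);
    int rank, size;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    // Buffered sends fail without an attached buffer; ggml-mpi asserts send_buffer != nullptr
    // for the same reason. MPI_BSEND_OVERHEAD must be budgeted per buffered message.
    const int buf_size = 1024 * 1024 + MPI_BSEND_OVERHEAD;
    void * send_buffer = malloc(buf_size);
    MPI_Buffer_attach(send_buffer, buf_size);

    int token = 42;
    if (rank < size - 1) {
        MPI_Bsend(&token, 1, MPI_INT, rank + 1, /*tag=*/0, MPI_COMM_WORLD);  // returns immediately
    }
    if (rank > 0) {
        MPI_Recv(&token, 1, MPI_INT, rank - 1, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
        printf("rank %d received %d\n", rank, token);
    }

    // Detach blocks until all buffered messages have been delivered.
    int detached_size;
    MPI_Buffer_detach(&send_buffer, &detached_size);
    free(send_buffer);
    MPI_Finalize();
    return 0;
}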
20 ggml-mpi.h

@@ -100,6 +100,26 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_mpi_wrap_buffer_type(
GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_mpi_wrap_buffer(ggml_backend_buffer_t buf);

void ggml_mpi_sync_ints_pipelined(
        struct ggml_mpi_context * ctx_mpi,
        int32_t * vals,
        int count,
        int tag
);

void ggml_mpi_sync_ints_pipelined_back(
        struct ggml_mpi_context * ctx_mpi,
        int32_t * vals,
        int count,
        int tag
);
// clear = 1, rm = 2, cp = 3, keep = 4, seq_shift = 5
void ggml_mpi_probe(struct ggml_mpi_context * ctx_mpi, int src, int tag);
int ggml_mpi_status_tag(struct ggml_mpi_context * ctx_mpi);

int ggml_mpi_iprobe(struct ggml_mpi_context * ctx_mpi, int src, int tag);
int ggml_mpi_status_count_int32(struct ggml_mpi_context * ctx_mpi);

/**
 * Create a new context by splitting the given context's
 * communicator, creating a "sub-communicator." This is a collective
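The ggml_mpi_iprobe and ggml_mpi_status_count_int32 declarations above wrap MPI's non-blocking message inspection: a rank polls for an incoming message and sizes its receive buffer from the returned status. A standalone sketch of the underlying calls (the tag and payload here are arbitrary examples, not ggml-mpi values):

// iprobe_sketch.cpp -- run with: mpirun -np 2 ./iprobe_sketch
#include <mpi.h>
#include <cstdint>
#include <cstdio>
#include <vector>

int main(int argc, char ** argv) {
    MPI_Init(&argc, &argv);
    int rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    const int TAG = 7;  // arbitrary example tag
    if (rank == 0) {
        int32_t payload[3] = {10, 20, 30};
        MPI_Send(payload, 3, MPI_INT32_T, 1, TAG, MPI_COMM_WORLD);
    } else if (rank == 1) {
        // Poll without blocking until a message shows up, then size the buffer
        // from the status -- the pattern behind ggml_mpi_iprobe and
        // ggml_mpi_status_count_int32.
        MPI_Status status;
        int flag = 0;
        while (!flag) {
            MPI_Iprobe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &flag, &status);
        }
        int count;
        MPI_Get_count(&status, MPI_INT32_T, &count);
        std::vector<int32_t> buf(count);
        MPI_Recv(buf.data(), count, MPI_INT32_T, status.MPI_SOURCE, status.MPI_TAG, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
        printf("rank 1: tag=%d count=%d first=%d\n", status.MPI_TAG, count, (int) buf[0]);
    }

    MPI_Finalize();
    return 0;
}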
138 llama.cpp

@@ -9064,11 +9064,58 @@ static void llama_graph_compute(
//
static int llama_decode_internal(
        llama_context & lctx,
        llama_batch batch_all) { // TODO: rename back to batch
        llama_batch & batch_all) { // TODO: rename back to batch

#ifdef GGML_USE_MPI
    if (ggml_mpi_rank(lctx.model.ctx_mpi) == 0 && ggml_mpi_size(lctx.model.ctx_mpi) > 1) {
        int transaction_type = GGML_MPI_DECODE;
        ggml_mpi_sync_ints_pipelined(lctx.model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
    }
//    ggml_mpi_sync_ints_pipelined(lctx.model.ctx_mpi, &batch_all.batch_id, 1, GGML_MPI_BATCH_ID);
    int old_tokens = batch_all.n_tokens;

    ggml_mpi_sync_ints_pipelined(lctx.model.ctx_mpi, &batch_all.n_tokens, 1, GGML_MPI_N_TOKENS);

    ggml_mpi_sync_ints_pipelined(lctx.model.ctx_mpi, reinterpret_cast<int32_t *>(&lctx.cparams.n_seq_max), 1, GGML_MPI_MAX_N_SEQ);
    if (ggml_mpi_rank(lctx.model.ctx_mpi) > 0) {
        int new_n_tokens = batch_all.n_tokens;
        llama_batch_free(batch_all);
        batch_all = llama_batch_init(new_n_tokens, 0, (int32_t)lctx.cparams.n_seq_max);
    }
#endif

    uint32_t n_tokens_all = batch_all.n_tokens;

    std::vector<llama_pos> pos;
    std::vector<int32_t> n_seq_id;
    std::vector<llama_seq_id *> seq_id_arr;
    std::vector<std::vector<llama_seq_id>> seq_id;

    if (batch_all.pos == nullptr) {
        pos.resize(n_tokens_all);
        for (uint32_t i = 0; i < n_tokens_all; i++) {
            pos[i] = batch_all.all_pos_0 + i*batch_all.all_pos_1;
        }

        batch_all.pos = pos.data();
    }

    if (batch_all.seq_id == nullptr) {
        n_seq_id.resize(n_tokens_all);
        seq_id.resize(n_tokens_all);
        seq_id_arr.resize(n_tokens_all);
        for (uint32_t i = 0; i < n_tokens_all; i++) {
            n_seq_id[i] = 1;
            seq_id[i].resize(lctx.cparams.n_seq_max);
            seq_id[i][0] = batch_all.all_seq_id;
            seq_id_arr[i] = seq_id[i].data();
        }

        batch_all.n_seq_id = n_seq_id.data();
        batch_all.seq_id = seq_id_arr.data();
    }

#ifdef GGML_USE_MPI
    ggml_mpi_eval_init(lctx.model.ctx_mpi, &(batch_all.n_tokens), &(batch_all.pos), &(batch_all.n_seq_id), &(batch_all.seq_id), &(batch_all.logits), lctx.cparams.n_seq_max);
    n_tokens_all = batch_all.n_tokens;

@@ -9114,10 +9161,6 @@ static int llama_decode_internal(

    const auto n_ubatch = cparams.n_ubatch;

    std::vector<llama_pos> pos;
    std::vector<int32_t> n_seq_id;
    std::vector<llama_seq_id *> seq_id_arr;
    std::vector<std::vector<llama_seq_id>> seq_id;

    for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
        uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);

@@ -14344,6 +14387,87 @@ void llama_batch_free(struct llama_batch batch) {
    batch.logits = nullptr;
}

#ifdef GGML_USE_MPI

int llama_process_mpi_transaction(
        struct llama_context * ctx,
        struct llama_batch & batch,
        int tag) {
//    if (ggml_mpi_rank(ctx->ctx_mpi) == ggml_mpi_size(ctx->ctx_mpi) - 1) {
//        printf("\nBeginning transaction type %d\n", tag);
//    }

    switch (tag) {
        case GGML_MPI_DECODE:
//            llama_batch_free(batch);
            return llama_decode_internal(*ctx, batch);
            break;
        case GGML_MPI_KV_CLEAR:
            llama_kv_cache_clear(ctx);
            break;
        case GGML_MPI_KV_SEQ_RM:
            llama_kv_cache_seq_rm(ctx, 1, -1, -1);
            break;
        case GGML_MPI_KV_SEQ_CP:
            llama_kv_cache_seq_cp(ctx, 0, 0, 0, 0);
            break;
//        case GGML_MPI_KV_SEQ_CP_BACK:
//            llama_kv_cache_seq_cp_back(ctx, 0, 0, 0, 0);
//            break;
//        case GGML_MPI_KV_SEQ_KEEP:
//            llama_kv_cache_seq_keep(ctx, 0);
//            break;
//        case GGML_MPI_KV_SEQ_SHIFT:
//            llama_kv_cache_seq_shift(ctx, 0, 0, 0, 0);
//            break;
        default:
            printf("Unknown operation, exiting\n");
            exit(1);
            break;
    }
    return 0;
}

int llama_process_mpi_worker(
        struct llama_context * ctx,
        struct llama_batch & batch) {
    ggml_mpi_probe(ctx->model.ctx_mpi, -1, -1);
    int tag = ggml_mpi_status_tag(ctx->model.ctx_mpi);
    int32_t count;
    int32_t trans_type;
//    if (ggml_mpi_rank(ctx->ctx_mpi) == ggml_mpi_size(ctx->ctx_mpi) - 1) {
//        printf("\nReceived command %d\n", tag);
//    }
    switch (tag) {
        case GGML_MPI_BEGIN_TRANSACTION:

            ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &trans_type, 1, GGML_MPI_BEGIN_TRANSACTION);
            return llama_process_mpi_transaction(ctx, batch, trans_type);
            break;
        case GGML_MPI_SHUTDOWN:
            llama_free(ctx);
            llama_backend_free();
            exit(0);
            break;
        case GGML_MPI_CANCEL_RUN:
//            count = ggml_mpi_status_count_int32(ctx->model.ctx_mpi);
////            printf("Received cancel run\n");
//            {
//                std::vector<int32_t> canceled(count, -1);
//                llama_cancel_run(ctx, canceled.data(), canceled.size());
//
//            }
//            break;
        default:
            printf("Unknown operation, exiting\n");
            exit(1);
            break;
    }
    return 0;
}

#endif

int32_t llama_decode(
        struct llama_context * ctx,
        struct llama_batch batch) {

@@ -14351,9 +14475,7 @@ int32_t llama_decode(
#ifdef GGML_USE_MPI
    if (ggml_mpi_rank(ctx->model.ctx_mpi) > 0) {
        // Enter a blocking eval loop with dummy input, letting rank=0 drive the process
        const int n_ctx = llama_n_ctx(ctx);
        std::vector<llama_token> tmp(n_ctx, llama_token_bos(&ctx->model));
        while (llama_decode_internal(*ctx, batch) >= 0){};
        while (llama_process_mpi_worker(ctx, batch) >= 0){};
        llama_backend_free();
        exit(1);
    }
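The control flow introduced in llama.cpp above is: rank 0 announces a transaction type, and worker ranks sit in a probe loop, read the message tag, and dispatch to a handler (llama_process_mpi_worker / llama_process_mpi_transaction). A small standalone sketch of that tag-dispatch protocol (plain MPI; the command tags and handlers are illustrative, not the real GGML_MPI_* constants):

// worker_loop_sketch.cpp -- run with: mpirun -np 2 ./worker_loop_sketch
#include <mpi.h>
#include <cstdio>
#include <cstdlib>

// Illustrative command tags (the real code uses GGML_MPI_BEGIN_TRANSACTION, GGML_MPI_SHUTDOWN, ...)
enum { CMD_DECODE = 1, CMD_SHUTDOWN = 2 };

int main(int argc, char ** argv) {
    MPI_Init(&argc, &argv);
    int rank, size;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    if (rank == 0) {
        // Rank 0 drives: announce a decode, then tell the workers to shut down.
        int payload = 0;
        for (int dst = 1; dst < size; dst++) {
            MPI_Send(&payload, 1, MPI_INT, dst, CMD_DECODE,   MPI_COMM_WORLD);
            MPI_Send(&payload, 1, MPI_INT, dst, CMD_SHUTDOWN, MPI_COMM_WORLD);
        }
    } else {
        // Workers loop like llama_process_mpi_worker: probe, read the tag, dispatch.
        for (;;) {
            MPI_Status status;
            MPI_Probe(0, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
            int payload;
            MPI_Recv(&payload, 1, MPI_INT, 0, status.MPI_TAG, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
            if (status.MPI_TAG == CMD_DECODE) {
                printf("rank %d: would run a decode here\n", rank);
            } else if (status.MPI_TAG == CMD_SHUTDOWN) {
                printf("rank %d: shutting down\n", rank);
                break;
            } else {
                fprintf(stderr, "rank %d: unknown command %d\n", rank, status.MPI_TAG);
                MPI_Abort(MPI_COMM_WORLD, 1);
            }
        }
    }

    MPI_Finalize();
    return 0;
}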