Resize seq_ids by n_seq_max, port over sync_pipelined instead of using Bcast

Branden Butler 2024-03-12 14:20:03 -05:00
parent c6280bc3f4
commit 72dcd66c0f
3 changed files with 207 additions and 61 deletions
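The ring-pipelined pattern this commit ports over (ggml_mpi_sync_pipelined, shown in the diff below) replaces a root-to-all MPI_Bcast: rank 0 originates the value, each later rank receives it from its predecessor, and every rank except the last forwards it to its successor with a buffered send. A minimal standalone sketch of that pattern, assuming a buffer has already been attached for MPI_Bsend; the function name here is illustrative and not part of the commit:

// Illustrative sketch of the ring-pipelined sync, not part of this commit.
#include <mpi.h>

static void sync_pipelined_sketch(void * val, int count, MPI_Datatype dtype, int tag, MPI_Comm comm) {
    int rank, size;
    MPI_Comm_rank(comm, &rank);
    MPI_Comm_size(comm, &size);

    const int prev = (rank - 1 + size) % size;
    const int next = (rank + 1) % size;

    if (rank != 0) {
        // every rank except the pipeline head waits for the value from its predecessor
        MPI_Recv(val, count, dtype, prev, tag, comm, MPI_STATUS_IGNORE);
    }
    if (rank < size - 1) {
        // every rank except the tail forwards the value onward; MPI_Bsend requires
        // a buffer previously attached with MPI_Buffer_attach
        MPI_Bsend(val, count, dtype, next, tag, comm);
    }
}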


@@ -14,6 +14,10 @@
#define UNUSED GGML_UNUSED
static bool have_init = false;
static void* send_buffer;
struct ggml_mpi_context {
int rank;
int size;
@@ -26,18 +30,37 @@ struct ggml_mpi_context {
std::vector<ggml_backend_t> backends;
ggml_backend_sched_t scheduler;
bool remote;
void* send_buffer;
};
void ggml_mpi_backend_init(void) {
int ret;
MPI_Init_thread(NULL, NULL, MPI_THREAD_MULTIPLE, &ret);
GGML_ASSERT(MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &ret) == MPI_SUCCESS);
have_init = true;
const int buffer_size = 128*1024*1024*8;
send_buffer = calloc(1, buffer_size); // 1 GiB buffer (8 * 128 MiB)
// fprintf(stderr, "BUFFER ATTACH RETCODE=%d\n", MPI_Buffer_attach(send_buffer, buffer_size));
}
void ggml_mpi_sync_pipelined(
struct ggml_mpi_context * ctx_mpi,
void * val,
int count,
MPI_Datatype datatype,
int tag
);
void ggml_mpi_backend_free(void) {
MPI_Finalize();
}
struct ggml_mpi_context * ggml_mpi_init(void) {
if (!have_init) {
ggml_mpi_backend_init();
}
auto * ctx = new ggml_mpi_context;
MPI_Comm_rank(MPI_COMM_WORLD, &ctx->rank);
@@ -45,6 +68,8 @@ struct ggml_mpi_context * ggml_mpi_init(void) {
ctx->comm = MPI_COMM_WORLD;
ctx->remote = false;
ctx->send_buffer = send_buffer;
return ctx;
}
@@ -69,77 +94,146 @@ size_t ggml_mpi_size(struct ggml_mpi_context * ctx) {
return ctx->size;
}
int ggml_mpi_next_node(struct ggml_mpi_context * ctx_mpi) {
return (ctx_mpi->rank + 1) % ctx_mpi->size;
}
int ggml_mpi_prev_node(struct ggml_mpi_context * ctx_mpi) {
int temp = (ctx_mpi->rank - 1);
return (temp >= 0) ? temp : ctx_mpi->size - 1;
}
void ggml_mpi_sync_pipelined(
struct ggml_mpi_context * ctx_mpi,
void * val,
int count,
MPI_Datatype datatype,
int tag
) {
if(ctx_mpi->comm == MPI_COMM_NULL) {
return;
}
// printf("Rank %d sync pipelined with tag %d\n", ctx_mpi->rank, tag);
if (ctx_mpi->rank != 0) {
MPI_Recv(val, count, datatype, ggml_mpi_prev_node(ctx_mpi), tag, ctx_mpi->comm, MPI_STATUS_IGNORE);
}
if(ctx_mpi->rank < ctx_mpi->size - 1) {
GGML_ASSERT(ctx_mpi->send_buffer != nullptr);
const int retval = MPI_Bsend(val, count, datatype, ggml_mpi_next_node(ctx_mpi), tag, ctx_mpi->comm);
GGML_ASSERT(retval == MPI_SUCCESS);
}
}
void ggml_mpi_eval_init(
struct ggml_mpi_context * ctx_mpi,
int32_t * n_tokens,
int32_t ** pos,
int32_t ** n_seq_ids,
int32_t *** seq_id,
int8_t ** logits) {
int8_t ** logits,
uint32_t n_seq_max) {
// fprintf(stderr, "Beginning eval init on rank %d\n", ctx_mpi->rank);
MPI_Barrier(ctx_mpi->comm);
if(ctx_mpi->comm == MPI_COMM_NULL) {
return;
}
int32_t old_n_tokens = *n_tokens;
MPI_Bcast(n_tokens, 1, MPI_INT32_T, 0, ctx_mpi->comm);
// fprintf(stderr, "Node %d, old_n_tokens: %d, new n_tokens: %d\n", ctx_mpi->rank, old_n_tokens, *n_tokens);
// If what was passed in differs from what was broadcast,
// we can't guarantee the allocated sizes are correct
// TODO check how often this is done and if it's a problem,
// try to allocate ahead of time
if (old_n_tokens != *n_tokens) {
*pos = static_cast<int32_t *>(realloc(*pos, *n_tokens * sizeof(int32_t)));
*n_seq_ids = static_cast<int32_t *>(realloc(*n_seq_ids, *n_tokens * sizeof(int32_t)));
*logits = static_cast<int8_t *>(realloc(*logits, *n_tokens * sizeof(int32_t)));
ggml_mpi_sync_pipelined(ctx_mpi, n_tokens, 1, MPI_INT, GGML_MPI_N_TOKENS);
int8_t* temp_logits = (int8_t*) calloc(*n_tokens, sizeof(int8_t));
if (ctx_mpi->rank == 0 && *logits != nullptr) {
ggml_mpi_sync_pipelined(ctx_mpi, *logits, *n_tokens, MPI_INT8_T, GGML_MPI_BATCH_LOGITS);
} else {
ggml_mpi_sync_pipelined(ctx_mpi, temp_logits, *n_tokens, MPI_INT8_T, GGML_MPI_BATCH_LOGITS);
}
// MPI_Bcast(&total_n_seq_ids, 1, MPI_INT32_T, 0, ctx_mpi->comm);
MPI_Bcast(*n_seq_ids, *n_tokens, MPI_INT32_T, 0, ctx_mpi->comm);
if (ctx_mpi->rank != 0) {
bool should_set_batch_logits = false;
for (int i = 0; i < *n_tokens; i++) {
if (temp_logits[i]) {
should_set_batch_logits = true;
break;
}
}
if (should_set_batch_logits) {
if (*logits != NULL) {
free(*logits);
*logits = NULL;
}
*logits = temp_logits;
} else {
if (*logits != NULL) {
free(*logits);
*logits = NULL;
}
free(temp_logits);
}
} else {
free(temp_logits);
}
// For now, we assume that the pos, seq_ids, tokens, etc have been
// pre-allocated for the largest possible sizes, even on worker nodes.
//if (old_n_tokens != *n_tokens) {
// *pos = realloc(*pos, *n_tokens * sizeof(int32_t));
// *n_seq_ids = realloc(*n_seq_ids, *n_tokens * sizeof(int32_t ));
// *tokens = realloc(*tokens, *n_tokens * sizeof(int32_t ));
//}
GGML_ASSERT(n_seq_ids != nullptr);
GGML_ASSERT(n_tokens != nullptr);
// FIXME Syncing n_seq_ids causes MPI to throw an invalid buffer error in Bsend
// ggml_mpi_sync_pipelined(ctx_mpi, *n_seq_ids, *n_tokens, MPI_INT32_T, GGML_MPI_N_SEQ_IDS);
// We need to know the total number of sequence
// ids, so we count them all up
int32_t total_n_seq_ids = 0;
for (int32_t i = 0; i < *n_tokens; i++) {
total_n_seq_ids += (*n_seq_ids)[i];
// int32_t total_n_seq_ids = 0;
// for (int32_t i = 0; i < *n_tokens; i++) {
// total_n_seq_ids += (*n_seq_ids)[i];
// }
//
// // MPI can't chase the pointers for multidimensional arrays, so we flatten them first
// // for transit
// int32_t * flattened_seq_ids = static_cast<int32_t *>(calloc(total_n_seq_ids, sizeof(int32_t)));
//
// int32_t current_index = 0;
//
// // Only rank 0 needs to flatten since the others don't have the real seq_id
// if (ctx_mpi->rank == 0) {
// for (int32_t i = 0; i < *n_tokens; i++) {
// for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
// flattened_seq_ids[current_index] = (*seq_id)[i][j];
// current_index++;
// }
// }
// }
//
//
//
// ggml_mpi_sync_pipelined(ctx_mpi, *pos, *n_tokens, MPI_INT32_T, GGML_MPI_POS);
// ggml_mpi_sync_pipelined(ctx_mpi, flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, GGML_MPI_SEQ_IDS);
//
// current_index = 0;
// for (int32_t i = 0; i < *n_tokens; i++) {
// for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
// (*seq_id)[i][j] = flattened_seq_ids[current_index];
// current_index++;
// }
//
// }
// free(flattened_seq_ids);
}
// MPI can't chase the pointers for multidimensional arrays, so we flatten them first
// for transit
auto * flattened_seq_ids = static_cast<int32_t *>(calloc(total_n_seq_ids, sizeof(int32_t)));
int32_t current_index = 0;
// Only rank 0 needs to flatten since the others don't have the real seq_id
if (ctx_mpi->rank == 0) {
for (int32_t i = 0; i < *n_tokens; i++) {
for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
flattened_seq_ids[current_index] = (*seq_id)[i][j];
current_index++;
}
}
}
MPI_Bcast( *pos, *n_tokens, MPI_INT32_T, 0, ctx_mpi->comm);
MPI_Bcast(flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, 0, ctx_mpi->comm);
//MPI_Bcast(*logits, *n_tokens, MPI_INT8_T, 0, ctx_mpi->comm);
auto ** new_seq_id = static_cast<int32_t **>(calloc(*n_tokens, sizeof(int32_t *)));
current_index = 0;
for (int32_t i = 0; i < *n_tokens; i++) {
new_seq_id[i] = static_cast<int32_t *>(calloc((*n_seq_ids)[i], sizeof(int32_t)));
for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
new_seq_id[i][j] = flattened_seq_ids[current_index];
current_index++;
}
}
free(flattened_seq_ids);
//free(*seq_id); // <- something is still holding onto this, need to investigate
*seq_id = new_seq_id;
}
void ggml_mpi_synch_int(
struct ggml_mpi_context * ctx_mpi,


@@ -12,6 +12,50 @@ struct ggml_cgraph;
extern "C" {
#endif
#define GGML_MPI_DECODE 0
#define GGML_MPI_KV_CLEAR 1
#define GGML_MPI_KV_SEQ_RM 2
#define GGML_MPI_KV_SEQ_CP 3
#define GGML_MPI_KV_SEQ_KEEP 4
#define GGML_MPI_KV_SEQ_SHIFT 5
#define GGML_MPI_SHUTDOWN 6
#define GGML_MPI_TRANSFER_TENSORS 7
#define GGML_MPI_SYNC_LOGITS 8
#define GGML_MPI_CANCEL_RUN 9
#define GGML_MPI_KV_SEQ_CP_BACK 10
#define GGML_MPI_TRANS_ID 11
#define GGML_MPI_BATCH_ID 12
#define GGML_MPI_N_TOKENS 13
#define GGML_MPI_TOKENS 14
#define GGML_MPI_N_SEQ_IDS 15
#define GGML_MPI_SEQ_IDS 16
#define GGML_MPI_POS 17
#define GGML_MPI_BEGIN_TRANSACTION 18
#define GGML_MPI_MAX_N_SEQ 19
#define GGML_MPI_BATCH_LOGITS 20
/**
* The context used for MPI operations,
* a program may make use of more than one
@@ -131,7 +175,8 @@ void ggml_mpi_eval_init(
int32_t ** pos,
int32_t ** n_seq_ids,
int32_t *** seq_id,
int8_t ** logits);
int8_t ** logits,
uint32_t n_seq_max);
void ggml_mpi_synch_int(
struct ggml_mpi_context * ctx_mpi,


@@ -1750,6 +1750,7 @@ struct llama_cparams {
ggml_backend_sched_eval_callback cb_eval;
void * cb_eval_user_data;
uint32_t n_seq_max;
};
struct llama_layer {
@@ -8812,7 +8813,7 @@ static int llama_decode_internal(
uint32_t n_tokens_all = batch_all.n_tokens;
#ifdef GGML_USE_MPI
ggml_mpi_eval_init(lctx.ctx_mpi, &(batch_all.n_tokens), &(batch_all.pos), &(batch_all.n_seq_id), &(batch_all.seq_id), &(batch_all.logits));
ggml_mpi_eval_init(lctx.ctx_mpi, &(batch_all.n_tokens), &(batch_all.pos), &(batch_all.n_seq_id), &(batch_all.seq_id), &(batch_all.logits), lctx.cparams.n_seq_max);
n_tokens_all = batch_all.n_tokens;
#endif
@@ -8896,7 +8897,7 @@ static int llama_decode_internal(
seq_id_arr.resize(n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) {
n_seq_id[i] = 1;
seq_id[i].resize(1);
seq_id[i].resize(lctx.cparams.n_seq_max);
seq_id[i][0] = u_batch.all_seq_id;
seq_id_arr[i] = seq_id[i].data();
}
@@ -12753,6 +12754,9 @@ void llama_backend_init(void) {
ggml_free(ctx);
}
#ifdef GGML_USE_MPI
ggml_mpi_backend_init();
#endif
}
@@ -12760,10 +12764,6 @@ void llama_numa_init(enum ggml_numa_strategy numa) {
if (numa != GGML_NUMA_STRATEGY_DISABLED) {
ggml_numa_init(numa);
}
#ifdef GGML_USE_MPI
ggml_mpi_backend_init();
#endif
}
void llama_backend_free(void) {
@@ -12844,7 +12844,7 @@ struct llama_context * llama_new_context_with_model(
const auto & hparams = model->hparams;
auto & cparams = ctx->cparams;
// TODO: maybe add n_seq_max here too
cparams.n_seq_max = params.n_seq_max;
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor;
@@ -13984,6 +13984,13 @@ void llama_batch_free(struct llama_batch batch) {
free(batch.seq_id);
}
if (batch.logits) free(batch.logits);
batch.token = nullptr;
batch.embd = nullptr;
batch.pos = nullptr;
batch.n_seq_id = nullptr;
batch.seq_id = nullptr;
batch.logits = nullptr;
}
int32_t llama_decode(