Resize seq_ids by n_seq_max, port over sync_pipelined instead of using Bcast

Branden Butler 2024-03-12 14:20:03 -05:00
parent c6280bc3f4
commit 72dcd66c0f
3 changed files with 207 additions and 61 deletions

ggml-mpi.cpp

@@ -14,6 +14,10 @@
 #define UNUSED GGML_UNUSED
 
+static bool have_init = false;
+static void* send_buffer;
+
 struct ggml_mpi_context {
     int rank;
     int size;
@@ -26,18 +30,37 @@ struct ggml_mpi_context {
     std::vector<ggml_backend_t> backends;
     ggml_backend_sched_t scheduler;
     bool remote;
+    void* send_buffer;
 };
 
 void ggml_mpi_backend_init(void) {
     int ret;
-    MPI_Init_thread(NULL, NULL, MPI_THREAD_MULTIPLE, &ret);
+    GGML_ASSERT(MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &ret) == MPI_SUCCESS);
+    have_init = true;
+    const int buffer_size = 128*1024*1024*8;
+    send_buffer = calloc(1, buffer_size); // 128MB buffer
+    // fprintf(stderr, "BUFFER ATTACH RETCODE=%d\n", MPI_Buffer_attach(send_buffer, buffer_size));
 }
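
Two details in this hunk are easy to miss. First, 128*1024*1024*8 is 1 GiB, so the trailing "// 128MB buffer" comment understates the allocation by a factor of eight. Second, the MPI_Buffer_attach call is commented out: MPI_Bsend stages outgoing messages in the attached buffer, and with no buffer attached the standard treats it as a zero-size buffer, so any nonzero buffered send fails with MPI_ERR_BUFFER. That is the usual cause of "invalid buffer" errors like the FIXME later in this file. A minimal sketch of the conventional pairing, with hypothetical helper names:

#include <mpi.h>
#include <stdlib.h>

// Hypothetical helper, not part of the commit: attach a buffer so later
// MPI_Bsend calls have somewhere to stage outgoing messages.
static void attach_bsend_buffer(size_t payload_bytes, int max_pending_msgs) {
    // Each pending buffered send also consumes MPI_BSEND_OVERHEAD bytes.
    int size = (int)(payload_bytes + (size_t)max_pending_msgs * MPI_BSEND_OVERHEAD);
    void * buf = malloc(size);
    MPI_Buffer_attach(buf, size);
}

// At shutdown, detach (blocks until pending buffered sends drain) and free.
static void detach_bsend_buffer(void) {
    void * buf;
    int size;
    MPI_Buffer_detach(&buf, &size);
    free(buf);
}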
+
+void ggml_mpi_sync_pipelined(
+        struct ggml_mpi_context * ctx_mpi,
+        void * val,
+        int count,
+        MPI_Datatype datatype,
+        int tag
+);
 
 void ggml_mpi_backend_free(void) {
     MPI_Finalize();
 }
 
 struct ggml_mpi_context * ggml_mpi_init(void) {
+    if (!have_init) {
+        ggml_mpi_backend_init();
+    }
     auto * ctx = new ggml_mpi_context;
     MPI_Comm_rank(MPI_COMM_WORLD, &ctx->rank);
@@ -45,6 +68,8 @@ struct ggml_mpi_context * ggml_mpi_init(void) {
     ctx->comm = MPI_COMM_WORLD;
     ctx->remote = false;
+    ctx->send_buffer = send_buffer;
     return ctx;
 }
@@ -69,78 +94,147 @@ size_t ggml_mpi_size(struct ggml_mpi_context * ctx) {
     return ctx->size;
 }
 
+int ggml_mpi_next_node(struct ggml_mpi_context * ctx_mpi) {
+    return (ctx_mpi->rank + 1) % ctx_mpi->size;
+}
+
+int ggml_mpi_prev_node(struct ggml_mpi_context * ctx_mpi) {
+    int temp = (ctx_mpi->rank - 1);
+    return (temp >= 0) ? temp : ctx_mpi->size - 1;
+}
+
+void ggml_mpi_sync_pipelined(
+        struct ggml_mpi_context * ctx_mpi,
+        void * val,
+        int count,
+        MPI_Datatype datatype,
+        int tag
+) {
+    if (ctx_mpi->comm == MPI_COMM_NULL) {
+        return;
+    }
+
+    // printf("Rank %d sync pipelined with tag %d\n", ctx_mpi->rank, tag);
+    if (ctx_mpi->rank != 0) {
+        MPI_Recv(val, count, datatype, ggml_mpi_prev_node(ctx_mpi), tag, ctx_mpi->comm, MPI_STATUS_IGNORE);
+    }
+    if (ctx_mpi->rank < ctx_mpi->size - 1) {
+        GGML_ASSERT(ctx_mpi->send_buffer != nullptr);
+        const int retval = MPI_Bsend(val, count, datatype, ggml_mpi_next_node(ctx_mpi), tag, ctx_mpi->comm);
+        GGML_ASSERT(retval == MPI_SUCCESS);
+    }
+}
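
ggml_mpi_sync_pipelined replaces collective MPI_Bcast calls with a hop-by-hop relay along the node ring: rank 0 only sends, interior ranks receive from their predecessor and forward to their successor, and the last rank only receives. Because MPI_Bsend returns as soon as the message is staged in the attached buffer, each rank can continue its own work while the data travels downstream, which is what makes the sync "pipelined". A self-contained sketch of the same relay pattern (plain MPI_Send is used here so the example does not depend on an attached buffer):

#include <mpi.h>
#include <stdio.h>

// Build: mpicc ring.c -o ring && mpirun -n 4 ./ring
int main(int argc, char ** argv) {
    MPI_Init(&argc, &argv);
    int rank, size;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    int token = 42;     // payload injected by rank 0
    const int TAG = 13; // stands in for a tag such as GGML_MPI_N_TOKENS

    if (rank != 0) {       // every rank but the head receives first
        MPI_Recv(&token, 1, MPI_INT, rank - 1, TAG, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
    }
    if (rank < size - 1) { // every rank but the tail forwards
        MPI_Send(&token, 1, MPI_INT, rank + 1, TAG, MPI_COMM_WORLD);
    }
    printf("rank %d now has token %d\n", rank, token);

    MPI_Finalize();
    return 0;
}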
 
 void ggml_mpi_eval_init(
         struct ggml_mpi_context * ctx_mpi,
         int32_t * n_tokens,
         int32_t ** pos,
         int32_t ** n_seq_ids,
         int32_t *** seq_id,
-        int8_t ** logits) {
-    // fprintf(stderr, "Beginning eval init on rank %d\n", ctx_mpi->rank);
-    MPI_Barrier(ctx_mpi->comm);
+        int8_t ** logits,
+        uint32_t n_seq_max) {
+    if (ctx_mpi->comm == MPI_COMM_NULL) {
+        return;
+    }
 
     int32_t old_n_tokens = *n_tokens;
-    MPI_Bcast(n_tokens, 1, MPI_INT32_T, 0, ctx_mpi->comm);
-    // fprintf(stderr, "Node %d, old_n_tokens: %d, new n_tokens: %d\n", ctx_mpi->rank, old_n_tokens, *n_tokens);
-    // If what was passed in differs from what was broadcast,
-    // we can't guarantee the allocated sizes are correct
-    // TODO check how often this is done and if it's a problem,
-    // try to allocate ahead of time
-    if (old_n_tokens != *n_tokens) {
-        *pos = static_cast<int32_t *>(realloc(*pos, *n_tokens * sizeof(int32_t)));
-        *n_seq_ids = static_cast<int32_t *>(realloc(*n_seq_ids, *n_tokens * sizeof(int32_t)));
-        *logits = static_cast<int8_t *>(realloc(*logits, *n_tokens * sizeof(int32_t)));
-    }
-    // MPI_Bcast(&total_n_seq_ids, 1, MPI_INT32_T, 0, ctx_mpi->comm);
-    MPI_Bcast(*n_seq_ids, *n_tokens, MPI_INT32_T, 0, ctx_mpi->comm);
+    ggml_mpi_sync_pipelined(ctx_mpi, n_tokens, 1, MPI_INT, GGML_MPI_N_TOKENS);
+
+    int8_t* temp_logits = (int8_t*) calloc(*n_tokens, sizeof(int8_t));
+
+    if (ctx_mpi->rank == 0 && *logits != nullptr) {
+        ggml_mpi_sync_pipelined(ctx_mpi, *logits, *n_tokens, MPI_INT8_T, GGML_MPI_BATCH_LOGITS);
+    } else {
+        ggml_mpi_sync_pipelined(ctx_mpi, temp_logits, *n_tokens, MPI_INT8_T, GGML_MPI_BATCH_LOGITS);
+    }
+
+    if (ctx_mpi->rank != 0) {
+        bool should_set_batch_logits = false;
+        for (int i = 0; i < *n_tokens; i++) {
+            if (temp_logits[i]) {
+                should_set_batch_logits = true;
+                break;
+            }
+        }
+        if (should_set_batch_logits) {
+            if (*logits != NULL) {
+                free(*logits);
+                *logits = NULL;
+            }
+            *logits = temp_logits;
+        } else {
+            if (*logits != NULL) {
+                free(*logits);
+                *logits = NULL;
+            }
+            free(temp_logits);
+        }
+    } else {
+        free(temp_logits);
+    }
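
On worker ranks the freshly received flags are adopted only if at least one token requests logits; otherwise they are discarded. Since both branches free the old *logits buffer first and free(NULL) is a no-op, the control flow is equivalent to this shorter form (a possible simplification, not part of the commit):

free(*logits); // safe even when *logits is already NULL
if (should_set_batch_logits) {
    *logits = temp_logits; // adopt the buffer received from rank 0
} else {
    *logits = NULL;        // no token wants logits; drop both buffers
    free(temp_logits);
}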
+
+    // For now, we assume that the pos, seq_ids, tokens, etc have been
+    // pre-allocated for the largest possible sizes, even on worker nodes.
+    //if (old_n_tokens != *n_tokens) {
+    //    *pos = realloc(*pos, *n_tokens * sizeof(int32_t));
+    //    *n_seq_ids = realloc(*n_seq_ids, *n_tokens * sizeof(int32_t ));
+    //    *tokens = realloc(*tokens, *n_tokens * sizeof(int32_t ));
+    //}
+
+    GGML_ASSERT(n_seq_ids != nullptr);
+    GGML_ASSERT(n_tokens != nullptr);
+
+    // FIXME Syncing n_seq_ids causes MPI to throw an invalid buffer error in Bsend
+    // ggml_mpi_sync_pipelined(ctx_mpi, *n_seq_ids, *n_tokens, MPI_INT32_T, GGML_MPI_N_SEQ_IDS);
 
     // We need to know the total number of sequence
     // ids, so we count them all up
-    int32_t total_n_seq_ids = 0;
-    for (int32_t i = 0; i < *n_tokens; i++) {
-        total_n_seq_ids += (*n_seq_ids)[i];
-    }
-
-    // MPI can't chase the pointers for multidimensional arrays, so we flatten them first
-    // for transit
-    auto * flattened_seq_ids = static_cast<int32_t *>(calloc(total_n_seq_ids, sizeof(int32_t)));
-
-    int32_t current_index = 0;
-
-    // Only rank 0 needs to flatten since the others don't have the real seq_id
-    if (ctx_mpi->rank == 0) {
-        for (int32_t i = 0; i < *n_tokens; i++) {
-            for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
-                flattened_seq_ids[current_index] = (*seq_id)[i][j];
-                current_index++;
-            }
-        }
-    }
-
-    MPI_Bcast( *pos, *n_tokens, MPI_INT32_T, 0, ctx_mpi->comm);
-    MPI_Bcast(flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, 0, ctx_mpi->comm);
-    //MPI_Bcast(*logits, *n_tokens, MPI_INT8_T, 0, ctx_mpi->comm);
-    auto ** new_seq_id = static_cast<int32_t **>(calloc(*n_tokens, sizeof(int32_t *)));
-    current_index = 0;
-    for (int32_t i = 0; i < *n_tokens; i++) {
-        new_seq_id[i] = static_cast<int32_t *>(calloc((*n_seq_ids)[i], sizeof(int32_t)));
-        for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
-            new_seq_id[i][j] = flattened_seq_ids[current_index];
-            current_index++;
-        }
-    }
-    free(flattened_seq_ids);
-    //free(*seq_id); // <- something is still holding onto this, need to investigate
-    *seq_id = new_seq_id;
+    // int32_t total_n_seq_ids = 0;
+    // for (int32_t i = 0; i < *n_tokens; i++) {
+    //     total_n_seq_ids += (*n_seq_ids)[i];
+    // }
+    //
+    // // MPI can't chase the pointers for multidimensional arrays, so we flatten them first
+    // // for transit
+    // int32_t * flattened_seq_ids = static_cast<int32_t *>(calloc(total_n_seq_ids, sizeof(int32_t)));
+    //
+    // int32_t current_index = 0;
+    //
+    // // Only rank 0 needs to flatten since the others don't have the real seq_id
+    // if (ctx_mpi->rank == 0) {
+    //     for (int32_t i = 0; i < *n_tokens; i++) {
+    //         for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
+    //             flattened_seq_ids[current_index] = (*seq_id)[i][j];
+    //             current_index++;
+    //         }
+    //     }
+    // }
+    //
+    // ggml_mpi_sync_pipelined(ctx_mpi, *pos, *n_tokens, MPI_INT32_T, GGML_MPI_POS);
+    // ggml_mpi_sync_pipelined(ctx_mpi, flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, GGML_MPI_SEQ_IDS);
+    //
+    // current_index = 0;
+    // for (int32_t i = 0; i < *n_tokens; i++) {
+    //     for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
+    //         (*seq_id)[i][j] = flattened_seq_ids[current_index];
+    //         current_index++;
+    //     }
+    // }
+    // free(flattened_seq_ids);
 }
 
 void ggml_mpi_synch_int(
         struct ggml_mpi_context * ctx_mpi,
         int32_t * val
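
The large commented-out region preserves the flatten/unflatten scheme for sequence IDs: MPI cannot follow an int32_t ** pointer structure, so rank 0 packs the ragged seq_id rows into one contiguous array for transit and receivers unpack it in place. The path is disabled pending the Bsend FIXME above, but the round trip it implements looks like this (a sketch with hypothetical helper names):

#include <cstdint>
#include <vector>

// Flatten the ragged seq_id table row-major, so it can travel as one
// contiguous MPI message.
static std::vector<int32_t> flatten_seq_ids(
        int32_t n_tokens, const int32_t * n_seq_ids, int32_t ** seq_id) {
    std::vector<int32_t> flat;
    for (int32_t i = 0; i < n_tokens; i++) {
        for (int32_t j = 0; j < n_seq_ids[i]; j++) {
            flat.push_back(seq_id[i][j]);
        }
    }
    return flat;
}

// Unpack a received flat array back into the preallocated ragged table.
static void unflatten_seq_ids(
        const std::vector<int32_t> & flat,
        int32_t n_tokens, const int32_t * n_seq_ids, int32_t ** seq_id) {
    size_t k = 0;
    for (int32_t i = 0; i < n_tokens; i++) {
        for (int32_t j = 0; j < n_seq_ids[i]; j++) {
            seq_id[i][j] = flat[k++];
        }
    }
}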

ggml-mpi.h

@@ -12,6 +12,50 @@ struct ggml_cgraph;
 extern "C" {
 #endif
 
+#define GGML_MPI_DECODE 0
+#define GGML_MPI_KV_CLEAR 1
+#define GGML_MPI_KV_SEQ_RM 2
+#define GGML_MPI_KV_SEQ_CP 3
+#define GGML_MPI_KV_SEQ_KEEP 4
+#define GGML_MPI_KV_SEQ_SHIFT 5
+#define GGML_MPI_SHUTDOWN 6
+#define GGML_MPI_TRANSFER_TENSORS 7
+#define GGML_MPI_SYNC_LOGITS 8
+#define GGML_MPI_CANCEL_RUN 9
+#define GGML_MPI_KV_SEQ_CP_BACK 10
+#define GGML_MPI_TRANS_ID 11
+#define GGML_MPI_BATCH_ID 12
+#define GGML_MPI_N_TOKENS 13
+#define GGML_MPI_TOKENS 14
+#define GGML_MPI_N_SEQ_IDS 15
+#define GGML_MPI_SEQ_IDS 16
+#define GGML_MPI_POS 17
+#define GGML_MPI_BEGIN_TRANSACTION 18
+#define GGML_MPI_MAX_N_SEQ 19
+#define GGML_MPI_BATCH_LOGITS 20
+
 /**
  * The context used for MPI operations,
  * a program may make use of more than one
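
These constants serve both as MPI message tags and as a small command vocabulary between the head node and the workers. None of the dispatch machinery is visible in this commit's hunks, but a worker loop built on them might plausibly look like this (a purely hypothetical sketch):

// Hypothetical worker-side dispatch: receive the next command id over the
// pipelined sync, then branch on it.
int32_t command = GGML_MPI_SHUTDOWN;
ggml_mpi_sync_pipelined(ctx, &command, 1, MPI_INT32_T, GGML_MPI_BEGIN_TRANSACTION);
switch (command) {
    case GGML_MPI_DECODE:   /* sync batch metadata, run local layers */ break;
    case GGML_MPI_KV_CLEAR: /* clear the local KV cache */             break;
    case GGML_MPI_SHUTDOWN: /* leave the worker loop */                break;
    default:                /* ignore unknown commands */              break;
}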
@@ -131,7 +175,8 @@ void ggml_mpi_eval_init(
         int32_t ** pos,
         int32_t ** n_seq_ids,
         int32_t *** seq_id,
-        int8_t ** logits);
+        int8_t ** logits,
+        uint32_t n_seq_max);
 
 void ggml_mpi_synch_int(
         struct ggml_mpi_context * ctx_mpi,

llama.cpp

@@ -1750,6 +1750,7 @@ struct llama_cparams {
     ggml_backend_sched_eval_callback cb_eval;
     void * cb_eval_user_data;
+    uint32_t n_seq_max;
 };
 
 struct llama_layer {
@@ -8812,7 +8813,7 @@ static int llama_decode_internal(
     uint32_t n_tokens_all = batch_all.n_tokens;
 
 #ifdef GGML_USE_MPI
-    ggml_mpi_eval_init(lctx.ctx_mpi, &(batch_all.n_tokens), &(batch_all.pos), &(batch_all.n_seq_id), &(batch_all.seq_id), &(batch_all.logits));
+    ggml_mpi_eval_init(lctx.ctx_mpi, &(batch_all.n_tokens), &(batch_all.pos), &(batch_all.n_seq_id), &(batch_all.seq_id), &(batch_all.logits), lctx.cparams.n_seq_max);
     n_tokens_all = batch_all.n_tokens;
 #endif
@@ -8896,7 +8897,7 @@ static int llama_decode_internal(
         seq_id_arr.resize(n_tokens);
         for (uint32_t i = 0; i < n_tokens; i++) {
             n_seq_id[i] = 1;
-            seq_id[i].resize(1);
+            seq_id[i].resize(lctx.cparams.n_seq_max);
             seq_id[i][0] = u_batch.all_seq_id;
             seq_id_arr[i] = seq_id[i].data();
         }
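
Resizing each per-token seq_id vector to n_seq_max instead of 1 is the change the commit title refers to: ggml_mpi_eval_init now assumes every rank's batch buffers are already sized for the worst case, so a later sync can write into them in place instead of reallocating. In isolation the contract looks like this (a sketch with a hypothetical helper name):

#include <cstdint>
#include <vector>

// Preallocate each token's seq_id storage to the worst-case n_seq_max so
// worker ranks never need to realloc during a sync.
static void init_batch_seq_ids(uint32_t n_tokens, uint32_t n_seq_max, int32_t all_seq_id,
                               std::vector<std::vector<int32_t>> & seq_id,
                               std::vector<int32_t *> & seq_id_arr) {
    seq_id.resize(n_tokens);
    seq_id_arr.resize(n_tokens);
    for (uint32_t i = 0; i < n_tokens; i++) {
        seq_id[i].resize(n_seq_max); // worst-case capacity, not just 1
        seq_id[i][0] = all_seq_id;   // only the first slot is used locally
        seq_id_arr[i] = seq_id[i].data();
    }
}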
@@ -12753,6 +12754,9 @@ void llama_backend_init(void) {
         ggml_free(ctx);
     }
+
+#ifdef GGML_USE_MPI
+    ggml_mpi_backend_init();
+#endif
 }
@@ -12760,10 +12764,6 @@ void llama_numa_init(enum ggml_numa_strategy numa) {
     if (numa != GGML_NUMA_STRATEGY_DISABLED) {
         ggml_numa_init(numa);
     }
-
-#ifdef GGML_USE_MPI
-    ggml_mpi_backend_init();
-#endif
 }
 
 void llama_backend_free(void) {
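
Moving ggml_mpi_backend_init out of llama_numa_init and into llama_backend_init means MPI now comes up on every startup path: llama_backend_init is mandatory for all hosts, while llama_numa_init is only called when a NUMA strategy is configured (and ggml_mpi_init additionally self-initializes through the new have_init guard). The expected startup order in an embedding application (a sketch):

#include "llama.h"

int main(void) {
    llama_backend_init();                         // also runs ggml_mpi_backend_init() under GGML_USE_MPI
    llama_numa_init(GGML_NUMA_STRATEGY_DISABLED); // optional; no longer responsible for MPI
    // ... load the model, create the context, llama_decode(...) ...
    llama_backend_free();
    return 0;
}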
@@ -12844,7 +12844,7 @@ struct llama_context * llama_new_context_with_model(
     const auto & hparams = model->hparams;
     auto & cparams = ctx->cparams;
 
-    // TODO: maybe add n_seq_max here too
+    cparams.n_seq_max = params.n_seq_max;
     cparams.n_threads = params.n_threads;
     cparams.n_threads_batch = params.n_threads_batch;
     cparams.yarn_ext_factor = params.yarn_ext_factor;
@@ -13984,6 +13984,13 @@ void llama_batch_free(struct llama_batch batch) {
         free(batch.seq_id);
     }
     if (batch.logits) free(batch.logits);
+
+    batch.token = nullptr;
+    batch.embd = nullptr;
+    batch.pos = nullptr;
+    batch.n_seq_id = nullptr;
+    batch.seq_id = nullptr;
+    batch.logits = nullptr;
 }
 
 int32_t llama_decode(
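
One caveat on the new null assignments in llama_batch_free: the batch is passed by value, so these writes clear only the callee's local copy and have no effect visible to the caller, whose llama_batch still holds the now-dangling pointers. A caller that keeps the variable around after freeing should reset its own copy (a usage sketch):

#include "llama.h"

void example(void) {
    llama_batch batch = llama_batch_init(512, 0, 1); // 512 tokens, no embeddings, 1 seq id per token
    // ... fill the batch and decode ...
    llama_batch_free(batch);
    batch = {}; // the struct was passed by value; clear the caller's copy too
}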