Resize seq_ids by n_seq_max, port over sync_pipelined instead of using Bcast
This commit is contained in:
parent c6280bc3f4
commit 72dcd66c0f
3 changed files with 207 additions and 61 deletions
198 ggml-mpi.cpp
@@ -14,6 +14,10 @@
#define UNUSED GGML_UNUSED

static bool have_init = false;

static void* send_buffer;

struct ggml_mpi_context {
    int rank;
    int size;
@@ -26,18 +30,37 @@ struct ggml_mpi_context {
    std::vector<ggml_backend_t> backends;
    ggml_backend_sched_t scheduler;
    bool remote;
    void* send_buffer;
};

void ggml_mpi_backend_init(void) {
    int ret;
    MPI_Init_thread(NULL, NULL, MPI_THREAD_MULTIPLE, &ret);
    GGML_ASSERT(MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &ret) == MPI_SUCCESS);
    have_init = true;
    const int buffer_size = 128*1024*1024*8;
    send_buffer = calloc(1, buffer_size); // 128MB buffer
    // fprintf(stderr, "BUFFER ATTACH RETCODE=%d\n", MPI_Buffer_attach(send_buffer, buffer_size));
}

void ggml_mpi_sync_pipelined(
        struct ggml_mpi_context * ctx_mpi,
        void * val,
        int count,
        MPI_Datatype datatype,
        int tag
);

void ggml_mpi_backend_free(void) {
    MPI_Finalize();
}

struct ggml_mpi_context * ggml_mpi_init(void) {

    if (!have_init) {
        ggml_mpi_backend_init();
    }

    auto * ctx = new ggml_mpi_context;

    MPI_Comm_rank(MPI_COMM_WORLD, &ctx->rank);
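Note on the buffered-send setup: ggml_mpi_backend_init now allocates a large send_buffer but leaves the MPI_Buffer_attach call commented out, while ggml_mpi_sync_pipelined below relies on MPI_Bsend, which only completes out of a previously attached buffer. Below is a minimal sketch of the standard attach/Bsend/detach pattern from the MPI API; it is not code from this commit, and demo_buffered_send and its message are illustrative only.

```cpp
// Sketch only: the standard MPI buffered-send setup that MPI_Bsend relies on.
// Not code from this commit; demo_buffered_send and its payload are illustrative.
#include <mpi.h>
#include <cstdlib>

static void demo_buffered_send(void) {
    int rank = 0, size = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    // Room for one int message plus MPI's per-message bookkeeping overhead.
    int    buf_size = (int) sizeof(int) + MPI_BSEND_OVERHEAD;
    void * buf      = malloc(buf_size);
    MPI_Buffer_attach(buf, buf_size);            // MPI_Bsend fails without this

    int payload = 42;
    if (rank == 0 && size > 1) {
        MPI_Bsend(&payload, 1, MPI_INT, 1, 0, MPI_COMM_WORLD);
    } else if (rank == 1) {
        MPI_Recv(&payload, 1, MPI_INT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
    }

    // Detach blocks until buffered messages have drained, then returns the memory.
    MPI_Buffer_detach(&buf, &buf_size);
    free(buf);
}
```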
@@ -45,6 +68,8 @@ struct ggml_mpi_context * ggml_mpi_init(void) {
    ctx->comm = MPI_COMM_WORLD;
    ctx->remote = false;

    ctx->send_buffer = send_buffer;

    return ctx;
}

@@ -69,77 +94,146 @@ size_t ggml_mpi_size(struct ggml_mpi_context * ctx) {
    return ctx->size;
}

int ggml_mpi_next_node(struct ggml_mpi_context * ctx_mpi) {
    return (ctx_mpi->rank + 1) % ctx_mpi->size;
}

int ggml_mpi_prev_node(struct ggml_mpi_context * ctx_mpi) {
    int temp = (ctx_mpi->rank - 1);
    return (temp >= 0) ? temp : ctx_mpi->size - 1;
}

void ggml_mpi_sync_pipelined(
        struct ggml_mpi_context * ctx_mpi,
        void * val,
        int count,
        MPI_Datatype datatype,
        int tag
) {
    if (ctx_mpi->comm == MPI_COMM_NULL) {
        return;
    }

    // printf("Rank %d sync pipelined with tag %d\n", ctx_mpi->rank, tag);

    if (ctx_mpi->rank != 0) {
        MPI_Recv(val, count, datatype, ggml_mpi_prev_node(ctx_mpi), tag, ctx_mpi->comm, MPI_STATUS_IGNORE);
    }
    if (ctx_mpi->rank < ctx_mpi->size - 1) {
        GGML_ASSERT(ctx_mpi->send_buffer != nullptr);
        const int retval = MPI_Bsend(val, count, datatype, ggml_mpi_next_node(ctx_mpi), tag, ctx_mpi->comm);
        GGML_ASSERT(retval == MPI_SUCCESS);
    }
}
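ggml_mpi_sync_pipelined replaces the collective MPI_Bcast calls with a store-and-forward pass along the node chain: every rank except 0 first receives the value from its predecessor, and every rank except the last then forwards it to its successor with a buffered send. A minimal standalone sketch of the same hop-by-hop pattern follows; it uses a plain MPI_Send and an illustrative payload rather than the buffered send and tags used here.

```cpp
// Sketch only: the hop-by-hop propagation behind ggml_mpi_sync_pipelined,
// reduced to a standalone program (plain MPI_Send instead of the buffered send).
#include <mpi.h>
#include <cstdio>

int main(int argc, char ** argv) {
    MPI_Init(&argc, &argv);

    int rank = 0, size = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    int value = (rank == 0) ? 42 : 0;   // only rank 0 starts with the payload

    // Every rank except 0 waits for its predecessor; every rank except the
    // last then forwards the value, so it flows 0 -> 1 -> ... -> size-1.
    if (rank != 0) {
        MPI_Recv(&value, 1, MPI_INT, rank - 1, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
    }
    if (rank < size - 1) {
        MPI_Send(&value, 1, MPI_INT, rank + 1, 0, MPI_COMM_WORLD);
    }

    std::printf("rank %d has value %d\n", rank, value);

    MPI_Finalize();
    return 0;
}
```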
void ggml_mpi_eval_init(
        struct ggml_mpi_context * ctx_mpi,
        int32_t * n_tokens,
        int32_t ** pos,
        int32_t ** n_seq_ids,
        int32_t *** seq_id,
        int8_t ** logits) {
        int8_t ** logits,
        uint32_t n_seq_max) {

    // fprintf(stderr, "Beginning eval init on rank %d\n", ctx_mpi->rank);
    MPI_Barrier(ctx_mpi->comm);
    if (ctx_mpi->comm == MPI_COMM_NULL) {
        return;
    }
    int32_t old_n_tokens = *n_tokens;
    MPI_Bcast(n_tokens, 1, MPI_INT32_T, 0, ctx_mpi->comm);

    // fprintf(stderr, "Node %d, old_n_tokens: %d, new n_tokens: %d\n", ctx_mpi->rank, old_n_tokens, *n_tokens);

    // If what was passed in differs from what was broadcast,
    // we can't guarantee the allocated sizes are correct
    // TODO check how often this is done and if it's a problem,
    // try to allocate ahead of time
    if (old_n_tokens != *n_tokens) {
        *pos = static_cast<int32_t *>(realloc(*pos, *n_tokens * sizeof(int32_t)));
        *n_seq_ids = static_cast<int32_t *>(realloc(*n_seq_ids, *n_tokens * sizeof(int32_t)));
        *logits = static_cast<int8_t *>(realloc(*logits, *n_tokens * sizeof(int32_t)));
    ggml_mpi_sync_pipelined(ctx_mpi, n_tokens, 1, MPI_INT, GGML_MPI_N_TOKENS);
    int8_t* temp_logits = (int8_t*) calloc(*n_tokens, sizeof(int8_t));

    if (ctx_mpi->rank == 0 && *logits != nullptr) {
        ggml_mpi_sync_pipelined(ctx_mpi, *logits, *n_tokens, MPI_INT8_T, GGML_MPI_BATCH_LOGITS);
    } else {
        ggml_mpi_sync_pipelined(ctx_mpi, temp_logits, *n_tokens, MPI_INT8_T, GGML_MPI_BATCH_LOGITS);
    }

    // MPI_Bcast(&total_n_seq_ids, 1, MPI_INT32_T, 0, ctx_mpi->comm);
    MPI_Bcast(*n_seq_ids, *n_tokens, MPI_INT32_T, 0, ctx_mpi->comm);
    if (ctx_mpi->rank != 0) {
        bool should_set_batch_logits = false;
        for (int i = 0; i < *n_tokens; i++) {
            if (temp_logits[i]) {
                should_set_batch_logits = true;
                break;
            }
        }
        if (should_set_batch_logits) {
            if (*logits != NULL) {
                free(*logits);
                *logits = NULL;
            }
            *logits = temp_logits;
        } else {
            if (*logits != NULL) {
                free(*logits);
                *logits = NULL;
            }
            free(temp_logits);
        }
    } else {
        free(temp_logits);
    }

    // For now, we assume that the pos, seq_ids, tokens, etc have been
    // pre-allocated for the largest possible sizes, even on worker nodes.
    //if (old_n_tokens != *n_tokens) {
    //    *pos = realloc(*pos, *n_tokens * sizeof(int32_t));
    //    *n_seq_ids = realloc(*n_seq_ids, *n_tokens * sizeof(int32_t ));
    //    *tokens = realloc(*tokens, *n_tokens * sizeof(int32_t ));
    //}

    GGML_ASSERT(n_seq_ids != nullptr);
    GGML_ASSERT(n_tokens != nullptr);

    // FIXME Syncing n_seq_ids causes MPI to throw an invalid buffer error in Bsend
    // ggml_mpi_sync_pipelined(ctx_mpi, *n_seq_ids, *n_tokens, MPI_INT32_T, GGML_MPI_N_SEQ_IDS);

    // We need to know the total number of sequence
    // ids, so we count them all up
    int32_t total_n_seq_ids = 0;
    for (int32_t i = 0; i < *n_tokens; i++) {
        total_n_seq_ids += (*n_seq_ids)[i];
//    int32_t total_n_seq_ids = 0;
//    for (int32_t i = 0; i < *n_tokens; i++) {
//        total_n_seq_ids += (*n_seq_ids)[i];
//    }
//
//    // MPI can't chase the pointers for multidimensional arrays, so we flatten them first
//    // for transit
//    int32_t * flattened_seq_ids = static_cast<int32_t *>(calloc(total_n_seq_ids, sizeof(int32_t)));
//
//    int32_t current_index = 0;
//
//    // Only rank 0 needs to flatten since the others don't have the real seq_id
//    if (ctx_mpi->rank == 0) {
//        for (int32_t i = 0; i < *n_tokens; i++) {
//            for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
//                flattened_seq_ids[current_index] = (*seq_id)[i][j];
//                current_index++;
//            }
//        }
//    }
//
//
//
//    ggml_mpi_sync_pipelined(ctx_mpi, *pos, *n_tokens, MPI_INT32_T, GGML_MPI_POS);
//    ggml_mpi_sync_pipelined(ctx_mpi, flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, GGML_MPI_SEQ_IDS);
//
//    current_index = 0;
//    for (int32_t i = 0; i < *n_tokens; i++) {
//        for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
//            (*seq_id)[i][j] = flattened_seq_ids[current_index];
//            current_index++;
//        }
//
//    }
//    free(flattened_seq_ids);
    }

    // MPI can't chase the pointers for multidimensional arrays, so we flatten them first
    // for transit
    auto * flattened_seq_ids = static_cast<int32_t *>(calloc(total_n_seq_ids, sizeof(int32_t)));

    int32_t current_index = 0;

    // Only rank 0 needs to flatten since the others don't have the real seq_id
    if (ctx_mpi->rank == 0) {
        for (int32_t i = 0; i < *n_tokens; i++) {
            for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
                flattened_seq_ids[current_index] = (*seq_id)[i][j];
                current_index++;
            }
        }
    }

    MPI_Bcast( *pos, *n_tokens, MPI_INT32_T, 0, ctx_mpi->comm);
    MPI_Bcast(flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, 0, ctx_mpi->comm);
    //MPI_Bcast(*logits, *n_tokens, MPI_INT8_T, 0, ctx_mpi->comm);
    auto ** new_seq_id = static_cast<int32_t **>(calloc(*n_tokens, sizeof(int32_t *)));
    current_index = 0;
    for (int32_t i = 0; i < *n_tokens; i++) {
        new_seq_id[i] = static_cast<int32_t *>(calloc((*n_seq_ids)[i], sizeof(int32_t)));
        for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
            new_seq_id[i][j] = flattened_seq_ids[current_index];
            current_index++;
        }
    }
    free(flattened_seq_ids);
    //free(*seq_id); // <- something is still holding onto this, need to investigate
    *seq_id = new_seq_id;
}

void ggml_mpi_synch_int(
        struct ggml_mpi_context * ctx_mpi,
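Both the Bcast-based path above and the commented-out pipelined variant flatten the per-token seq_id arrays before sending, because MPI cannot follow the nested pointers of an int32_t **. A self-contained sketch of that flatten, broadcast, and unflatten round trip follows; the helper name bcast_seq_ids and its use of MPI_Bcast are illustrative, not taken from the commit.

```cpp
// Sketch only: flattening a ragged int32_t ** (one row of sequence ids per token)
// into a contiguous buffer for MPI transit, then rebuilding it after the transfer.
// bcast_seq_ids is a hypothetical helper, not part of ggml-mpi.
#include <mpi.h>
#include <cstdint>
#include <cstdlib>

static void bcast_seq_ids(int32_t ** seq_id, const int32_t * n_seq_ids,
                          int32_t n_tokens, MPI_Comm comm) {
    int rank = 0;
    MPI_Comm_rank(comm, &rank);

    int32_t total = 0;
    for (int32_t i = 0; i < n_tokens; i++) {
        total += n_seq_ids[i];
    }

    auto * flat = static_cast<int32_t *>(calloc(total, sizeof(int32_t)));

    // Only rank 0 holds the real ids, so only it fills the flat buffer.
    if (rank == 0) {
        int32_t k = 0;
        for (int32_t i = 0; i < n_tokens; i++) {
            for (int32_t j = 0; j < n_seq_ids[i]; j++) {
                flat[k++] = seq_id[i][j];
            }
        }
    }

    MPI_Bcast(flat, total, MPI_INT32_T, 0, comm);

    // Every rank unpacks the flat buffer back into the per-token rows.
    int32_t k = 0;
    for (int32_t i = 0; i < n_tokens; i++) {
        for (int32_t j = 0; j < n_seq_ids[i]; j++) {
            seq_id[i][j] = flat[k++];
        }
    }
    free(flat);
}
```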
47 ggml-mpi.h
@@ -12,6 +12,50 @@ struct ggml_cgraph;
extern "C" {
#endif

#define GGML_MPI_DECODE 0

#define GGML_MPI_KV_CLEAR 1

#define GGML_MPI_KV_SEQ_RM 2

#define GGML_MPI_KV_SEQ_CP 3

#define GGML_MPI_KV_SEQ_KEEP 4

#define GGML_MPI_KV_SEQ_SHIFT 5

#define GGML_MPI_SHUTDOWN 6

#define GGML_MPI_TRANSFER_TENSORS 7

#define GGML_MPI_SYNC_LOGITS 8

#define GGML_MPI_CANCEL_RUN 9

#define GGML_MPI_KV_SEQ_CP_BACK 10

#define GGML_MPI_TRANS_ID 11

#define GGML_MPI_BATCH_ID 12

#define GGML_MPI_N_TOKENS 13

#define GGML_MPI_TOKENS 14

#define GGML_MPI_N_SEQ_IDS 15

#define GGML_MPI_SEQ_IDS 16

#define GGML_MPI_POS 17

#define GGML_MPI_BEGIN_TRANSACTION 18

#define GGML_MPI_MAX_N_SEQ 19

#define GGML_MPI_BATCH_LOGITS 20

/**
 * The context used for MPI operations,
 * a program may make use of more than one
@@ -131,7 +175,8 @@ void ggml_mpi_eval_init(
        int32_t ** pos,
        int32_t ** n_seq_ids,
        int32_t *** seq_id,
        int8_t ** logits);
        int8_t ** logits,
        uint32_t n_seq_max);

void ggml_mpi_synch_int(
        struct ggml_mpi_context * ctx_mpi,
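The new GGML_MPI_* defines give every pipelined message a distinct tag: the lower values (GGML_MPI_DECODE through GGML_MPI_BEGIN_TRANSACTION) read like command identifiers, while the later ones (GGML_MPI_N_TOKENS, GGML_MPI_POS, GGML_MPI_BATCH_LOGITS, ...) label individual batch fields. This diff does not show how worker ranks consume the command tags, so the following is a purely hypothetical sketch of a dispatch loop over them; worker_command_loop and its receive pattern are assumptions, not part of the repository.

```cpp
// Hypothetical sketch only: one way a worker rank might dispatch on command tags
// like the ones defined above. The real dispatch loop is not shown in this diff;
// worker_command_loop and its receive pattern are assumptions.
#include <mpi.h>
#include <cstdint>
#include "ggml-mpi.h"   // for the GGML_MPI_* values added above

static void worker_command_loop(MPI_Comm comm, int prev_rank) {
    bool running = true;
    while (running) {
        int32_t command = -1;
        // Assume each command arrives from the previous node in the pipeline.
        MPI_Recv(&command, 1, MPI_INT32_T, prev_rank, GGML_MPI_BEGIN_TRANSACTION,
                 comm, MPI_STATUS_IGNORE);
        switch (command) {
            case GGML_MPI_DECODE:   /* run the local part of the decode */ break;
            case GGML_MPI_KV_CLEAR: /* clear the local KV cache         */ break;
            case GGML_MPI_SHUTDOWN: running = false;                       break;
            default:                /* remaining commands elided         */ break;
        }
    }
}
```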
21 llama.cpp
@@ -1750,6 +1750,7 @@ struct llama_cparams {

    ggml_backend_sched_eval_callback cb_eval;
    void * cb_eval_user_data;
    uint32_t n_seq_max;
};

struct llama_layer {
@@ -8812,7 +8813,7 @@ static int llama_decode_internal(
    uint32_t n_tokens_all = batch_all.n_tokens;

#ifdef GGML_USE_MPI
    ggml_mpi_eval_init(lctx.ctx_mpi, &(batch_all.n_tokens), &(batch_all.pos), &(batch_all.n_seq_id), &(batch_all.seq_id), &(batch_all.logits));
    ggml_mpi_eval_init(lctx.ctx_mpi, &(batch_all.n_tokens), &(batch_all.pos), &(batch_all.n_seq_id), &(batch_all.seq_id), &(batch_all.logits), lctx.cparams.n_seq_max);
    n_tokens_all = batch_all.n_tokens;
#endif

@@ -8896,7 +8897,7 @@ static int llama_decode_internal(
        seq_id_arr.resize(n_tokens);
        for (uint32_t i = 0; i < n_tokens; i++) {
            n_seq_id[i] = 1;
            seq_id[i].resize(1);
            seq_id[i].resize(lctx.cparams.n_seq_max);
            seq_id[i][0] = u_batch.all_seq_id;
            seq_id_arr[i] = seq_id[i].data();
        }
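This is the headline change: each token's seq_id vector was previously sized to a single entry and is now sized to cparams.n_seq_max, so ggml_mpi_eval_init can copy in however many sequence ids arrive from rank 0 without reallocating. A small self-contained sketch of that pre-sizing follows; n_tokens, n_seq_max, and the default id are illustrative values.

```cpp
// Sketch only: pre-sizing every per-token seq_id row to n_seq_max so a later sync
// can fill in up to n_seq_max ids without reallocating. Values are illustrative.
#include <cstdint>
#include <vector>

int main() {
    const uint32_t n_tokens  = 4;
    const uint32_t n_seq_max = 8;   // comes from llama_cparams in the real code

    std::vector<std::vector<int32_t>> seq_id(n_tokens);
    std::vector<int32_t *>            seq_id_arr(n_tokens);

    for (uint32_t i = 0; i < n_tokens; i++) {
        // Before this commit: seq_id[i].resize(1) — room for one sequence id.
        seq_id[i].resize(n_seq_max);       // now: room for every id a batch may carry
        seq_id[i][0]  = 0;                 // slot used for all_seq_id by default
        seq_id_arr[i] = seq_id[i].data();  // raw pointers handed to the batch
    }
    return 0;
}
```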
@@ -12753,6 +12754,9 @@ void llama_backend_init(void) {
    ggml_free(ctx);
}

#ifdef GGML_USE_MPI
    ggml_mpi_backend_init();
#endif
}

@@ -12760,10 +12764,6 @@ void llama_numa_init(enum ggml_numa_strategy numa) {
    if (numa != GGML_NUMA_STRATEGY_DISABLED) {
        ggml_numa_init(numa);
    }

#ifdef GGML_USE_MPI
    ggml_mpi_backend_init();
#endif
}

void llama_backend_free(void) {
@@ -12844,7 +12844,7 @@ struct llama_context * llama_new_context_with_model(
    const auto & hparams = model->hparams;
    auto & cparams = ctx->cparams;

    // TODO: maybe add n_seq_max here too
    cparams.n_seq_max = params.n_seq_max;
    cparams.n_threads = params.n_threads;
    cparams.n_threads_batch = params.n_threads_batch;
    cparams.yarn_ext_factor = params.yarn_ext_factor;

@@ -13984,6 +13984,13 @@ void llama_batch_free(struct llama_batch batch) {
        free(batch.seq_id);
    }
    if (batch.logits) free(batch.logits);

    batch.token = nullptr;
    batch.embd = nullptr;
    batch.pos = nullptr;
    batch.n_seq_id = nullptr;
    batch.seq_id = nullptr;
    batch.logits = nullptr;
}

int32_t llama_decode(
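The llama_batch_free hunk adds a free for batch.logits and then nulls every pointer field. Since the function receives the batch by value, those assignments only clear the local copy; the sketch below shows the same free-and-null idiom applied to a reference parameter, where the caller's pointers are cleared as well. toy_batch and toy_batch_free are illustrative names, not part of llama.cpp.

```cpp
// Sketch only: the free-and-null idiom, applied to a reference parameter so the
// caller's pointers are cleared too and a repeated free stays harmless.
#include <cstdint>
#include <cstdlib>

struct toy_batch {               // illustrative stand-in for llama_batch
    int32_t * pos    = nullptr;
    int8_t  * logits = nullptr;
};

static void toy_batch_free(toy_batch & batch) {
    free(batch.pos);             // free(nullptr) is a harmless no-op
    free(batch.logits);
    batch.pos    = nullptr;      // the caller now holds no dangling pointers
    batch.logits = nullptr;
}

int main() {
    toy_batch b;
    b.pos    = (int32_t *) calloc(4, sizeof(int32_t));
    b.logits = (int8_t  *) calloc(4, sizeof(int8_t));
    toy_batch_free(b);
    toy_batch_free(b);           // safe: everything is already nullptr
    return 0;
}
```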