Resize seq_ids by n_seq_max, port over sync_pipelined instead of using Bcast
This commit is contained in:
parent
c6280bc3f4
commit
72dcd66c0f
3 changed files with 207 additions and 61 deletions
200
ggml-mpi.cpp
200
ggml-mpi.cpp
|
@ -14,6 +14,10 @@
|
||||||
|
|
||||||
#define UNUSED GGML_UNUSED
|
#define UNUSED GGML_UNUSED
|
||||||
|
|
||||||
|
static bool have_init = false;
|
||||||
|
|
||||||
|
static void* send_buffer;
|
||||||
|
|
||||||
struct ggml_mpi_context {
|
struct ggml_mpi_context {
|
||||||
int rank;
|
int rank;
|
||||||
int size;
|
int size;
|
||||||
|
@ -26,18 +30,37 @@ struct ggml_mpi_context {
|
||||||
std::vector<ggml_backend_t> backends;
|
std::vector<ggml_backend_t> backends;
|
||||||
ggml_backend_sched_t scheduler;
|
ggml_backend_sched_t scheduler;
|
||||||
bool remote;
|
bool remote;
|
||||||
|
void* send_buffer;
|
||||||
};
|
};
|
||||||
|
|
||||||
void ggml_mpi_backend_init(void) {
|
void ggml_mpi_backend_init(void) {
|
||||||
int ret;
|
int ret;
|
||||||
MPI_Init_thread(NULL, NULL, MPI_THREAD_MULTIPLE, &ret);
|
|
||||||
|
GGML_ASSERT(MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &ret) == MPI_SUCCESS);
|
||||||
|
have_init = true;
|
||||||
|
const int buffer_size = 128*1024*1024*8;
|
||||||
|
send_buffer = calloc(1, buffer_size); // 128MB buffer
|
||||||
|
// fprintf(stderr, "BUFFER ATTACH RETCODE=%d\n", MPI_Buffer_attach(send_buffer, buffer_size));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_mpi_sync_pipelined(
|
||||||
|
struct ggml_mpi_context * ctx_mpi,
|
||||||
|
void * val,
|
||||||
|
int count,
|
||||||
|
MPI_Datatype datatype,
|
||||||
|
int tag
|
||||||
|
);
|
||||||
|
|
||||||
void ggml_mpi_backend_free(void) {
|
void ggml_mpi_backend_free(void) {
|
||||||
MPI_Finalize();
|
MPI_Finalize();
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_mpi_context * ggml_mpi_init(void) {
|
struct ggml_mpi_context * ggml_mpi_init(void) {
|
||||||
|
|
||||||
|
if (!have_init) {
|
||||||
|
ggml_mpi_backend_init();
|
||||||
|
}
|
||||||
|
|
||||||
auto * ctx = new ggml_mpi_context;
|
auto * ctx = new ggml_mpi_context;
|
||||||
|
|
||||||
MPI_Comm_rank(MPI_COMM_WORLD, &ctx->rank);
|
MPI_Comm_rank(MPI_COMM_WORLD, &ctx->rank);
|
||||||
|
@ -45,6 +68,8 @@ struct ggml_mpi_context * ggml_mpi_init(void) {
|
||||||
ctx->comm = MPI_COMM_WORLD;
|
ctx->comm = MPI_COMM_WORLD;
|
||||||
ctx->remote = false;
|
ctx->remote = false;
|
||||||
|
|
||||||
|
ctx->send_buffer = send_buffer;
|
||||||
|
|
||||||
return ctx;
|
return ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -69,78 +94,147 @@ size_t ggml_mpi_size(struct ggml_mpi_context * ctx) {
|
||||||
return ctx->size;
|
return ctx->size;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ggml_mpi_next_node(struct ggml_mpi_context * ctx_mpi) {
|
||||||
|
return (ctx_mpi->rank + 1) % ctx_mpi->size;
|
||||||
|
}
|
||||||
|
|
||||||
|
int ggml_mpi_prev_node(struct ggml_mpi_context * ctx_mpi) {
|
||||||
|
int temp = (ctx_mpi->rank - 1);
|
||||||
|
return (temp >= 0) ? temp : ctx_mpi->size - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_mpi_sync_pipelined(
|
||||||
|
struct ggml_mpi_context * ctx_mpi,
|
||||||
|
void * val,
|
||||||
|
int count,
|
||||||
|
MPI_Datatype datatype,
|
||||||
|
int tag
|
||||||
|
) {
|
||||||
|
if(ctx_mpi->comm == MPI_COMM_NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// printf("Rank %d sync pipelined with tag %d\n", ctx_mpi->rank, tag);
|
||||||
|
|
||||||
|
|
||||||
|
if (ctx_mpi->rank != 0) {
|
||||||
|
MPI_Recv(val, count, datatype, ggml_mpi_prev_node(ctx_mpi), tag, ctx_mpi->comm, MPI_STATUS_IGNORE);
|
||||||
|
}
|
||||||
|
if(ctx_mpi->rank < ctx_mpi->size - 1) {
|
||||||
|
GGML_ASSERT(ctx_mpi->send_buffer != nullptr);
|
||||||
|
const int retval = MPI_Bsend(val, count, datatype, ggml_mpi_next_node(ctx_mpi), tag, ctx_mpi->comm);
|
||||||
|
GGML_ASSERT(retval == MPI_SUCCESS);
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void ggml_mpi_eval_init(
|
void ggml_mpi_eval_init(
|
||||||
struct ggml_mpi_context * ctx_mpi,
|
struct ggml_mpi_context * ctx_mpi,
|
||||||
int32_t * n_tokens,
|
int32_t * n_tokens,
|
||||||
int32_t ** pos,
|
int32_t ** pos,
|
||||||
int32_t ** n_seq_ids,
|
int32_t ** n_seq_ids,
|
||||||
int32_t *** seq_id,
|
int32_t *** seq_id,
|
||||||
int8_t ** logits) {
|
int8_t ** logits,
|
||||||
|
uint32_t n_seq_max) {
|
||||||
|
|
||||||
|
|
||||||
// fprintf(stderr, "Beginning eval init on rank %d\n", ctx_mpi->rank);
|
if(ctx_mpi->comm == MPI_COMM_NULL) {
|
||||||
MPI_Barrier(ctx_mpi->comm);
|
return;
|
||||||
|
}
|
||||||
int32_t old_n_tokens = *n_tokens;
|
int32_t old_n_tokens = *n_tokens;
|
||||||
MPI_Bcast(n_tokens, 1, MPI_INT32_T, 0, ctx_mpi->comm);
|
|
||||||
|
|
||||||
// fprintf(stderr, "Node %d, old_n_tokens: %d, new n_tokens: %d\n", ctx_mpi->rank, old_n_tokens, *n_tokens);
|
|
||||||
|
|
||||||
// If what was passed in differs from what was broadcast,
|
ggml_mpi_sync_pipelined(ctx_mpi, n_tokens, 1, MPI_INT, GGML_MPI_N_TOKENS);
|
||||||
// we can't guarantee the allocated sizes are correct
|
int8_t* temp_logits = (int8_t*) calloc(*n_tokens, sizeof(int8_t));
|
||||||
// TODO check how often this is done and if it's a problem,
|
|
||||||
// try to allocate ahead of time
|
if (ctx_mpi->rank == 0 && *logits != nullptr) {
|
||||||
if (old_n_tokens != *n_tokens) {
|
ggml_mpi_sync_pipelined(ctx_mpi, *logits, *n_tokens, MPI_INT8_T, GGML_MPI_BATCH_LOGITS);
|
||||||
*pos = static_cast<int32_t *>(realloc(*pos, *n_tokens * sizeof(int32_t)));
|
} else {
|
||||||
*n_seq_ids = static_cast<int32_t *>(realloc(*n_seq_ids, *n_tokens * sizeof(int32_t)));
|
ggml_mpi_sync_pipelined(ctx_mpi, temp_logits, *n_tokens, MPI_INT8_T, GGML_MPI_BATCH_LOGITS);
|
||||||
*logits = static_cast<int8_t *>(realloc(*logits, *n_tokens * sizeof(int32_t)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// MPI_Bcast(&total_n_seq_ids, 1, MPI_INT32_T, 0, ctx_mpi->comm);
|
if (ctx_mpi->rank != 0) {
|
||||||
MPI_Bcast(*n_seq_ids, *n_tokens, MPI_INT32_T, 0, ctx_mpi->comm);
|
bool should_set_batch_logits = false;
|
||||||
|
for (int i = 0; i < *n_tokens; i++) {
|
||||||
|
if (temp_logits[i]) {
|
||||||
|
should_set_batch_logits = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (should_set_batch_logits) {
|
||||||
|
if (*logits != NULL) {
|
||||||
|
free(*logits);
|
||||||
|
*logits = NULL;
|
||||||
|
}
|
||||||
|
*logits = temp_logits;
|
||||||
|
} else {
|
||||||
|
if (*logits != NULL) {
|
||||||
|
free(*logits);
|
||||||
|
*logits = NULL;
|
||||||
|
}
|
||||||
|
free(temp_logits);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
free(temp_logits);
|
||||||
|
}
|
||||||
|
|
||||||
|
// For now, we assume that the pos, seq_ids, tokens, etc have been
|
||||||
|
// pre-allocated for the largest possible sizes, even on worker nodes.
|
||||||
|
//if (old_n_tokens != *n_tokens) {
|
||||||
|
// *pos = realloc(*pos, *n_tokens * sizeof(int32_t));
|
||||||
|
// *n_seq_ids = realloc(*n_seq_ids, *n_tokens * sizeof(int32_t ));
|
||||||
|
// *tokens = realloc(*tokens, *n_tokens * sizeof(int32_t ));
|
||||||
|
//}
|
||||||
|
|
||||||
|
GGML_ASSERT(n_seq_ids != nullptr);
|
||||||
|
GGML_ASSERT(n_tokens != nullptr);
|
||||||
|
|
||||||
|
|
||||||
|
// FIXME Syncing n_seq_ids causes MPI to throw an invalid buffer error in Bsend
|
||||||
|
// ggml_mpi_sync_pipelined(ctx_mpi, *n_seq_ids, *n_tokens, MPI_INT32_T, GGML_MPI_N_SEQ_IDS);
|
||||||
|
|
||||||
// We need to know the total number of sequence
|
// We need to know the total number of sequence
|
||||||
// ids, so we count them all up
|
// ids, so we count them all up
|
||||||
int32_t total_n_seq_ids = 0;
|
// int32_t total_n_seq_ids = 0;
|
||||||
for (int32_t i = 0; i < *n_tokens; i++) {
|
// for (int32_t i = 0; i < *n_tokens; i++) {
|
||||||
total_n_seq_ids += (*n_seq_ids)[i];
|
// total_n_seq_ids += (*n_seq_ids)[i];
|
||||||
}
|
// }
|
||||||
|
//
|
||||||
// MPI can't chase the pointers for multidimensional arrays, so we flatten them first
|
// // MPI can't chase the pointers for multidimensional arrays, so we flatten them first
|
||||||
// for transit
|
// // for transit
|
||||||
auto * flattened_seq_ids = static_cast<int32_t *>(calloc(total_n_seq_ids, sizeof(int32_t)));
|
// int32_t * flattened_seq_ids = static_cast<int32_t *>(calloc(total_n_seq_ids, sizeof(int32_t)));
|
||||||
|
//
|
||||||
int32_t current_index = 0;
|
// int32_t current_index = 0;
|
||||||
|
//
|
||||||
// Only rank 0 needs to flatten since the others don't have the real seq_id
|
// // Only rank 0 needs to flatten since the others don't have the real seq_id
|
||||||
if (ctx_mpi->rank == 0) {
|
// if (ctx_mpi->rank == 0) {
|
||||||
for (int32_t i = 0; i < *n_tokens; i++) {
|
// for (int32_t i = 0; i < *n_tokens; i++) {
|
||||||
for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
|
// for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
|
||||||
flattened_seq_ids[current_index] = (*seq_id)[i][j];
|
// flattened_seq_ids[current_index] = (*seq_id)[i][j];
|
||||||
current_index++;
|
// current_index++;
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
//
|
||||||
|
//
|
||||||
MPI_Bcast( *pos, *n_tokens, MPI_INT32_T, 0, ctx_mpi->comm);
|
//
|
||||||
MPI_Bcast(flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, 0, ctx_mpi->comm);
|
// ggml_mpi_sync_pipelined(ctx_mpi, *pos, *n_tokens, MPI_INT32_T, GGML_MPI_POS);
|
||||||
//MPI_Bcast(*logits, *n_tokens, MPI_INT8_T, 0, ctx_mpi->comm);
|
// ggml_mpi_sync_pipelined(ctx_mpi, flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, GGML_MPI_SEQ_IDS);
|
||||||
auto ** new_seq_id = static_cast<int32_t **>(calloc(*n_tokens, sizeof(int32_t *)));
|
//
|
||||||
current_index = 0;
|
// current_index = 0;
|
||||||
for (int32_t i = 0; i < *n_tokens; i++) {
|
// for (int32_t i = 0; i < *n_tokens; i++) {
|
||||||
new_seq_id[i] = static_cast<int32_t *>(calloc((*n_seq_ids)[i], sizeof(int32_t)));
|
// for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
|
||||||
for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
|
// (*seq_id)[i][j] = flattened_seq_ids[current_index];
|
||||||
new_seq_id[i][j] = flattened_seq_ids[current_index];
|
// current_index++;
|
||||||
current_index++;
|
// }
|
||||||
}
|
//
|
||||||
}
|
// }
|
||||||
free(flattened_seq_ids);
|
// free(flattened_seq_ids);
|
||||||
//free(*seq_id); // <- something is still holding onto this, need to investigate
|
|
||||||
*seq_id = new_seq_id;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void ggml_mpi_synch_int(
|
void ggml_mpi_synch_int(
|
||||||
struct ggml_mpi_context * ctx_mpi,
|
struct ggml_mpi_context * ctx_mpi,
|
||||||
int32_t * val
|
int32_t * val
|
||||||
|
|
47
ggml-mpi.h
47
ggml-mpi.h
|
@ -12,6 +12,50 @@ struct ggml_cgraph;
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// Integer identifiers shared by all ranks. GGML_MPI_N_TOKENS and
// GGML_MPI_BATCH_LOGITS are passed as the `tag` argument of
// ggml_mpi_sync_pipelined() in ggml-mpi.cpp.
// NOTE(review): the remaining values look like command / message-tag ids
// as well -- confirm against their call sites.
#define GGML_MPI_DECODE             0
#define GGML_MPI_KV_CLEAR           1
#define GGML_MPI_KV_SEQ_RM          2
#define GGML_MPI_KV_SEQ_CP          3
#define GGML_MPI_KV_SEQ_KEEP        4
#define GGML_MPI_KV_SEQ_SHIFT       5
#define GGML_MPI_SHUTDOWN           6
#define GGML_MPI_TRANSFER_TENSORS   7
#define GGML_MPI_SYNC_LOGITS        8
#define GGML_MPI_CANCEL_RUN         9
#define GGML_MPI_KV_SEQ_CP_BACK     10
#define GGML_MPI_TRANS_ID           11
#define GGML_MPI_BATCH_ID           12
#define GGML_MPI_N_TOKENS           13
#define GGML_MPI_TOKENS             14
#define GGML_MPI_N_SEQ_IDS          15
#define GGML_MPI_SEQ_IDS            16
#define GGML_MPI_POS                17
#define GGML_MPI_BEGIN_TRANSACTION  18
#define GGML_MPI_MAX_N_SEQ          19
#define GGML_MPI_BATCH_LOGITS       20
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The context used for MPI operations,
|
* The context used for MPI operations,
|
||||||
* a program may make use of more than one
|
* a program may make use of more than one
|
||||||
|
@ -131,7 +175,8 @@ void ggml_mpi_eval_init(
|
||||||
int32_t ** pos,
|
int32_t ** pos,
|
||||||
int32_t ** n_seq_ids,
|
int32_t ** n_seq_ids,
|
||||||
int32_t *** seq_id,
|
int32_t *** seq_id,
|
||||||
int8_t ** logits);
|
int8_t ** logits,
|
||||||
|
uint32_t n_seq_max);
|
||||||
|
|
||||||
void ggml_mpi_synch_int(
|
void ggml_mpi_synch_int(
|
||||||
struct ggml_mpi_context * ctx_mpi,
|
struct ggml_mpi_context * ctx_mpi,
|
||||||
|
|
21
llama.cpp
21
llama.cpp
|
@ -1750,6 +1750,7 @@ struct llama_cparams {
|
||||||
|
|
||||||
ggml_backend_sched_eval_callback cb_eval;
|
ggml_backend_sched_eval_callback cb_eval;
|
||||||
void * cb_eval_user_data;
|
void * cb_eval_user_data;
|
||||||
|
uint32_t n_seq_max;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_layer {
|
struct llama_layer {
|
||||||
|
@ -8812,7 +8813,7 @@ static int llama_decode_internal(
|
||||||
uint32_t n_tokens_all = batch_all.n_tokens;
|
uint32_t n_tokens_all = batch_all.n_tokens;
|
||||||
|
|
||||||
#ifdef GGML_USE_MPI
|
#ifdef GGML_USE_MPI
|
||||||
ggml_mpi_eval_init(lctx.ctx_mpi, &(batch_all.n_tokens), &(batch_all.pos), &(batch_all.n_seq_id), &(batch_all.seq_id), &(batch_all.logits));
|
ggml_mpi_eval_init(lctx.ctx_mpi, &(batch_all.n_tokens), &(batch_all.pos), &(batch_all.n_seq_id), &(batch_all.seq_id), &(batch_all.logits), lctx.cparams.n_seq_max);
|
||||||
n_tokens_all = batch_all.n_tokens;
|
n_tokens_all = batch_all.n_tokens;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -8896,7 +8897,7 @@ static int llama_decode_internal(
|
||||||
seq_id_arr.resize(n_tokens);
|
seq_id_arr.resize(n_tokens);
|
||||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||||
n_seq_id[i] = 1;
|
n_seq_id[i] = 1;
|
||||||
seq_id[i].resize(1);
|
seq_id[i].resize(lctx.cparams.n_seq_max);
|
||||||
seq_id[i][0] = u_batch.all_seq_id;
|
seq_id[i][0] = u_batch.all_seq_id;
|
||||||
seq_id_arr[i] = seq_id[i].data();
|
seq_id_arr[i] = seq_id[i].data();
|
||||||
}
|
}
|
||||||
|
@ -12753,6 +12754,9 @@ void llama_backend_init(void) {
|
||||||
ggml_free(ctx);
|
ggml_free(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef GGML_USE_MPI
|
||||||
|
ggml_mpi_backend_init();
|
||||||
|
#endif
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -12760,10 +12764,6 @@ void llama_numa_init(enum ggml_numa_strategy numa) {
|
||||||
if (numa != GGML_NUMA_STRATEGY_DISABLED) {
|
if (numa != GGML_NUMA_STRATEGY_DISABLED) {
|
||||||
ggml_numa_init(numa);
|
ggml_numa_init(numa);
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef GGML_USE_MPI
|
|
||||||
ggml_mpi_backend_init();
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_backend_free(void) {
|
void llama_backend_free(void) {
|
||||||
|
@ -12844,7 +12844,7 @@ struct llama_context * llama_new_context_with_model(
|
||||||
const auto & hparams = model->hparams;
|
const auto & hparams = model->hparams;
|
||||||
auto & cparams = ctx->cparams;
|
auto & cparams = ctx->cparams;
|
||||||
|
|
||||||
// TODO: maybe add n_seq_max here too
|
cparams.n_seq_max = params.n_seq_max;
|
||||||
cparams.n_threads = params.n_threads;
|
cparams.n_threads = params.n_threads;
|
||||||
cparams.n_threads_batch = params.n_threads_batch;
|
cparams.n_threads_batch = params.n_threads_batch;
|
||||||
cparams.yarn_ext_factor = params.yarn_ext_factor;
|
cparams.yarn_ext_factor = params.yarn_ext_factor;
|
||||||
|
@ -13984,6 +13984,13 @@ void llama_batch_free(struct llama_batch batch) {
|
||||||
free(batch.seq_id);
|
free(batch.seq_id);
|
||||||
}
|
}
|
||||||
if (batch.logits) free(batch.logits);
|
if (batch.logits) free(batch.logits);
|
||||||
|
|
||||||
|
batch.token = nullptr;
|
||||||
|
batch.embd = nullptr;
|
||||||
|
batch.pos = nullptr;
|
||||||
|
batch.n_seq_id = nullptr;
|
||||||
|
batch.seq_id = nullptr;
|
||||||
|
batch.logits = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t llama_decode(
|
int32_t llama_decode(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue