Synchronize batch sequence info, fixing MPI for llama_decode()

Branden Butler 2023-10-29 15:16:16 -05:00
parent ede7ff0c66
commit bcfb190c28
4 changed files with 93 additions and 22 deletions


@@ -1403,7 +1403,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
         LOG("warming up the model with an empty run\n");
 #ifndef GGML_USE_MPI
-        // When using MPI, llama_eval() enters into an infinite loop
+        // When using MPI, llama_decode() enters into an infinite loop
         // on non-head nodes. Thus, we only want to warmup the model here
         // if we aren't using MPI.
         // FIXME have a way to terminate the infinite loop so we can warmup the model
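For context, the guarded warmup that this comment refers to looks roughly like the following; a minimal sketch with assumed names (`model`, `lctx`, `warmup_model`) and the era's `llama_batch_get_one(tokens, n_tokens, pos_0, seq_id)` helper, not code from this commit:

```cpp
#include <vector>
#include "llama.h"

// Sketch of the warmup that gets skipped under MPI: ranks > 0 would block
// inside llama_decode() and never return, so the whole call is compiled out.
static void warmup_model(llama_model * model, llama_context * lctx) {
#ifndef GGML_USE_MPI
    std::vector<llama_token> tmp = { llama_token_bos(model) };
    llama_decode(lctx, llama_batch_get_one(tmp.data(), (int32_t) tmp.size(), 0, 0));
#endif
}
```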


@@ -60,16 +60,67 @@ int ggml_mpi_size(struct ggml_mpi_context * ctx) {
 }
 
 void ggml_mpi_eval_init(
-        struct ggml_mpi_context * ctx_mpi,
-        int * n_tokens,
-        int * n_past,
-        int * n_threads) {
+        struct ggml_mpi_context * ctx_mpi,
+        int32_t * n_tokens,
+        int32_t ** pos,
+        int32_t ** n_seq_ids,
+        int32_t *** seq_id,
+        int8_t ** logits) {
     MPI_Barrier(ctx_mpi->comm);
 
-    MPI_Bcast(n_tokens, 1, MPI_INT, 0, ctx_mpi->comm);
-    MPI_Bcast(n_past, 1, MPI_INT, 0, ctx_mpi->comm);
+    MPI_Bcast(n_tokens, 1, MPI_INT, 0, ctx_mpi->comm);
+
+    if (ctx_mpi->rank != 0) {
+        *pos       = calloc(*n_tokens, sizeof(int32_t));
+        *n_seq_ids = calloc(*n_tokens, sizeof(int32_t));
+        *logits    = calloc(*n_tokens, sizeof(int8_t));
+    }
+
+    int32_t total_n_seq_ids = 0;
+    for (size_t i = 0; i < *n_tokens; i++) {
+        total_n_seq_ids += (*n_seq_ids)[i];
+    }
+
+    MPI_Bcast(&total_n_seq_ids, 1, MPI_INT32_T, 0, ctx_mpi->comm);
+    MPI_Bcast(*n_seq_ids, *n_tokens, MPI_INT32_T, 0, ctx_mpi->comm);
+
+    int32_t * flattened_seq_ids = calloc(total_n_seq_ids, sizeof(int32_t));
+
+    int32_t current_index = 0;
+    if (ctx_mpi->rank == 0) {
+        for (size_t i = 0; i < *n_tokens; i++) {
+            for (size_t j = 0; j < (*n_seq_ids)[i]; j++) {
+                flattened_seq_ids[current_index] = (*seq_id)[i][j];
+                current_index++;
+            }
+        }
+    }
+
+    MPI_Bcast(*pos, *n_tokens, MPI_INT32_T, 0, ctx_mpi->comm);
+    MPI_Bcast(flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, 0, ctx_mpi->comm);
+    //MPI_Bcast(*logits, *n_tokens, MPI_INT8_T, 0, ctx_mpi->comm);
+
+    int32_t ** new_seq_id = calloc(*n_tokens, sizeof(int32_t*));
+    current_index = 0;
+    for (size_t i = 0; i < *n_tokens; i++) {
+        new_seq_id[i] = calloc((*n_seq_ids)[i], sizeof(int32_t));
+        for (size_t j = 0; j < (*n_seq_ids)[i]; j++) {
+            new_seq_id[i][j] = flattened_seq_ids[current_index];
+            current_index++;
+        }
+    }
+    free(flattened_seq_ids);
+
+    *seq_id = new_seq_id;
+}
+
+void ggml_mpi_synch_int(
+        struct ggml_mpi_context * ctx_mpi,
+        int32_t * val
+    ) {
+    MPI_Bcast(val, 1, MPI_INT32_T, 0, ctx_mpi->comm);
+}
 
 static int ggml_graph_get_node_idx(struct ggml_cgraph * gf, const char * name) {
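The ragged seq_id arrays cannot be broadcast in a single call, so rank 0 flattens them into one contiguous buffer, broadcasts it together with n_seq_ids, and every rank rebuilds the per-token arrays locally. A standalone sketch of just that flatten/rebuild step (plain C++, no MPI, illustrative values):

```cpp
#include <cstdint>
#include <vector>

int main() {
    // A 3-token batch where token 1 belongs to two sequences:
    std::vector<int32_t>              n_seq_ids = {1, 2, 1};
    std::vector<std::vector<int32_t>> seq_id    = {{0}, {0, 1}, {1}};

    // rank 0 side: flatten to {0, 0, 1, 1}; this flat buffer plus n_seq_ids
    // is what gets broadcast (total_n_seq_ids == 4 here)
    std::vector<int32_t> flat;
    for (const auto & ids : seq_id) {
        flat.insert(flat.end(), ids.begin(), ids.end());
    }

    // receiving side: rebuild the ragged structure from n_seq_ids + flat
    std::vector<std::vector<int32_t>> rebuilt;
    size_t k = 0;
    for (int32_t n : n_seq_ids) {
        rebuilt.emplace_back(flat.begin() + k, flat.begin() + k + n);
        k += n;
    }

    return rebuilt == seq_id ? 0 : 1;  // 0: the round trip preserved the batch
}
```

The broadcast of logits stays commented out because only the head node consumes them.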


@@ -110,14 +110,23 @@ int ggml_mpi_size(struct ggml_mpi_context * ctx);
  *
  * @param ctx_mpi The context in which to prepare for evaluation.
  * @param n_tokens A pointer to the n_tokens, which will be synchronized after this function.
- * @param n_past A pointer to the n_past, which will be synchronized after this function.
- * @param n_threads A pointer to the n_threads, which is unused currently.
+ * @param pos A pointer to the pos array, which will be synchronized after this function.
+ * @param n_seq_ids A pointer to the n_seq_ids array, which will be synchronized after this function.
+ * @param seq_id A pointer to the seq_id 2D array, which will be synchronized after this function.
+ * @param logits A pointer to the logits array, which is unused currently since only node 0 needs them.
  */
 void ggml_mpi_eval_init(
-        struct ggml_mpi_context * ctx_mpi,
-        int * n_tokens,
-        int * n_past,
-        int * n_threads);
+        struct ggml_mpi_context * ctx_mpi,
+        int32_t * n_tokens,
+        int32_t ** pos,
+        int32_t ** n_seq_ids,
+        int32_t *** seq_id,
+        int8_t ** logits);
+
+void ggml_mpi_synch_int(
+        struct ggml_mpi_context * ctx_mpi,
+        int32_t * val
+    );
 
 /**
  * Split a range across all nodes within the given
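Taken together, the two declarations give callers a simple contract: pass the addresses of the batch fields, and on every rank other than 0 they come back pointing at freshly allocated copies of rank 0's values. A hedged usage sketch (the wrapper names are illustrative; the field layout follows struct llama_batch):

```cpp
#include <cstdint>
#include "ggml-mpi.h"
#include "llama.h"

// Illustrative wrappers, not part of the commit.
static void sync_batch_metadata(struct ggml_mpi_context * ctx_mpi, llama_batch & batch) {
    // On rank 0 the fields are inputs; on all other ranks they are replaced
    // with the values broadcast from rank 0.
    ggml_mpi_eval_init(ctx_mpi,
                       &batch.n_tokens,
                       &batch.pos,
                       &batch.n_seq_id,
                       &batch.seq_id,
                       &batch.logits);
}

static void sync_scalar(struct ggml_mpi_context * ctx_mpi, int32_t & value) {
    // Broadcast a single int32_t from rank 0, e.g. a counter that must stay
    // identical across ranks.
    ggml_mpi_synch_int(ctx_mpi, &value);
}
```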


@@ -8776,8 +8776,7 @@ static int llama_decode_internal(
          llama_context & lctx,
            llama_batch   batch_all) { // TODO: rename back to batch
 
-    const uint32_t n_tokens_all = batch_all.n_tokens;
+    uint32_t n_tokens_all = batch_all.n_tokens;
 
     if (n_tokens_all == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
         return -1;
@@ -8798,11 +8797,7 @@ static int llama_decode_internal(
     }
 
     lctx.n_queued_tokens += n_tokens_all;
 
-#ifdef GGML_USE_MPI
-    // TODO: needs fix after #3228
-    GGML_ASSERT(false && "not implemented");
-    //ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads);
-#endif
 
     auto & kv_self = lctx.kv_self;
@@ -8828,7 +8823,7 @@ static int llama_decode_internal(
     std::vector<std::vector<llama_seq_id>> seq_id;
 
     for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
-        const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
+        uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
         llama_batch u_batch = {
             /* .n_tokens = */ (int32_t) n_tokens,
             /* .token = */ batch_all.token ? batch_all.token + cur_token : nullptr,
@@ -8881,7 +8876,12 @@ static int llama_decode_internal(
             kv_self.head = 0;
         }
 
-        if (!llama_kv_cache_find_slot(kv_self, u_batch)) {
+#ifdef GGML_USE_MPI
+        // TODO: needs fix after #3228
+        ggml_mpi_eval_init(lctx.ctx_mpi, &(u_batch.n_tokens), &(u_batch.pos), &(u_batch.n_seq_id), &(u_batch.seq_id), &(u_batch.logits));
+        n_tokens = u_batch.n_tokens;
+#endif
+
+        if (!llama_kv_cache_find_slot(kv_self, u_batch)) {
             return 1;
         }
@@ -13923,6 +13923,17 @@ void llama_batch_free(struct llama_batch batch) {
 int32_t llama_decode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
+#ifdef GGML_USE_MPI
+    if (ggml_mpi_rank(ctx->ctx_mpi) > 0) {
+        // Enter a blocking eval loop with dummy input, letting rank=0 drive the process
+        const int n_ctx = llama_n_ctx(ctx);
+        std::vector<llama_token> tmp(n_ctx, llama_token_bos(&ctx->model));
+        while (llama_decode_internal(*ctx, batch) >= 0){};
+        llama_backend_free();
+        exit(1);
+    }
+#endif
     const int ret = llama_decode_internal(*ctx, batch);
     if (ret < 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
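With this change, an application built with GGML_USE_MPI runs the same code on every rank (launched with, say, mpirun), and the split happens inside llama_decode(): ranks above 0 never return from it, they keep serving rank 0 until decoding fails, then free the backend and exit, while rank 0 gets a normal return value. A hedged sketch of the calling pattern, assuming `ctx` and `model` were set up in the usual way and using the era's four-argument llama_batch_get_one():

```cpp
#include <vector>
#include "llama.h"

// Illustrative only; error handling and sampling omitted.
static int run_prompt(llama_context * ctx, llama_model * model) {
    std::vector<llama_token> prompt = { llama_token_bos(model) };
    llama_batch batch = llama_batch_get_one(prompt.data(), (int32_t) prompt.size(), 0, 0);

    // Every rank reaches this call. Ranks > 0 block here for the lifetime of
    // the process; rank 0 gets a normal return value and carries on.
    if (llama_decode(ctx, batch) != 0) {
        return 1;
    }

    // Only rank 0 ever gets here: read logits, sample, feed tokens back in.
    float * logits = llama_get_logits(ctx);
    (void) logits;
    return 0;
}
```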