Synchronize batch sequence info, fixing MPI for llama_decode()
parent ede7ff0c66
commit bcfb190c28

4 changed files with 93 additions and 22 deletions
@@ -1403,7 +1403,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     LOG("warming up the model with an empty run\n");

 #ifndef GGML_USE_MPI
-    // When using MPI, llama_eval() enters into an infinite loop
+    // When using MPI, llama_decode() enters into an infinite loop
     // on non-head nodes. Thus, we only want to warmup the model here
     // if we aren't using MPI.
     // FIXME have a way to terminate the infinite loop so we can warmup the model
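For context, the comment above guards a one-off dummy decode that pages the weights in. A rough sketch of that guard follows; the exact warmup tokens and calls are illustrative assumptions, not lines from this diff. With MPI enabled the warmup is compiled out, because a follower rank calling llama_decode() here would drop into the blocking eval loop added at the end of this commit and never return.

    // Sketch only: hypothetical shape of the guarded warmup (token choice and batch
    // call are assumptions for illustration, not part of this commit).
    #ifndef GGML_USE_MPI
        {
            std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model) };
            llama_decode(lctx, llama_batch_get_one(tmp.data(), (int32_t) tmp.size(), 0, 0));
            llama_kv_cache_clear(lctx);
        }
    #endif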
ggml-mpi.c (63 lines changed)
@@ -60,16 +60,67 @@ int ggml_mpi_size(struct ggml_mpi_context * ctx) {
 }

 void ggml_mpi_eval_init(
-        struct ggml_mpi_context * ctx_mpi,
-                            int * n_tokens,
-                            int * n_past,
-                            int * n_threads) {
+        struct ggml_mpi_context * ctx_mpi,
+                        int32_t * n_tokens,
+                       int32_t ** pos,
+                       int32_t ** n_seq_ids,
+                      int32_t *** seq_id,
+                        int8_t ** logits) {
+

     MPI_Barrier(ctx_mpi->comm);

-    MPI_Bcast(n_tokens, 1, MPI_INT, 0, ctx_mpi->comm);
-    MPI_Bcast(n_past,   1, MPI_INT, 0, ctx_mpi->comm);
+    MPI_Bcast(n_tokens, 1, MPI_INT, 0, ctx_mpi->comm);
+
+    if (ctx_mpi->rank != 0) {
+        *pos       = calloc(*n_tokens, sizeof(int32_t));
+        *n_seq_ids = calloc(*n_tokens, sizeof(int32_t));
+        *logits    = calloc(*n_tokens, sizeof(int8_t));
+    }
+
+    int32_t total_n_seq_ids = 0;
+    for (size_t i = 0; i < *n_tokens; i++) {
+        total_n_seq_ids += (*n_seq_ids)[i];
+    }
+
+    MPI_Bcast(&total_n_seq_ids, 1, MPI_INT32_T, 0, ctx_mpi->comm);
+    MPI_Bcast(*n_seq_ids, *n_tokens, MPI_INT32_T, 0, ctx_mpi->comm);
+
+    int32_t * flattened_seq_ids = calloc(total_n_seq_ids, sizeof(int32_t));
+
+    int32_t current_index = 0;
+
+    if (ctx_mpi->rank == 0) {
+        for (size_t i = 0; i < *n_tokens; i++) {
+            for (size_t j = 0; j < (*n_seq_ids)[i]; j++) {
+                flattened_seq_ids[current_index] = (*seq_id)[i][j];
+                current_index++;
+            }
+        }
+    }
+
+    MPI_Bcast(*pos, *n_tokens, MPI_INT32_T, 0, ctx_mpi->comm);
+    MPI_Bcast(flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, 0, ctx_mpi->comm);
+    //MPI_Bcast(*logits, *n_tokens, MPI_INT8_T, 0, ctx_mpi->comm);
+    int32_t ** new_seq_id = calloc(*n_tokens, sizeof(int32_t*));
+    current_index = 0;
+    for (size_t i = 0; i < *n_tokens; i++) {
+        new_seq_id[i] = calloc((*n_seq_ids)[i], sizeof(int32_t));
+        for (size_t j = 0; j < (*n_seq_ids)[i]; j++) {
+            new_seq_id[i][j] = flattened_seq_ids[current_index];
+            current_index++;
+        }
+    }
+    free(flattened_seq_ids);
+    *seq_id = new_seq_id;
+}
+
+void ggml_mpi_synch_int(
+        struct ggml_mpi_context * ctx_mpi,
+                        int32_t * val
+) {
+    MPI_Bcast(val, 1, MPI_INT32_T, 0, ctx_mpi->comm);
 }

 static int ggml_graph_get_node_idx(struct ggml_cgraph * gf, const char * name) {
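The heart of the new ggml_mpi_eval_init() is a flatten -> broadcast -> rebuild round trip: a jagged int32_t ** table cannot be sent with a single MPI_Bcast, so rank 0 packs it into one contiguous buffer, every rank receives the per-row lengths plus the flat buffer, and followers rebuild the 2D shape locally. The standalone sketch below shows the same pattern in isolation; it is not part of the commit, and the file name and sample data are made up.

    // Standalone sketch of the flatten -> broadcast -> rebuild pattern used above.
    // Build and run, for example:
    //   mpicxx -o bcast_jagged bcast_jagged.cpp && mpirun -np 2 ./bcast_jagged
    #include <mpi.h>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main(int argc, char ** argv) {
        MPI_Init(&argc, &argv);
        int rank; MPI_Comm_rank(MPI_COMM_WORLD, &rank);

        // Rank 0 owns a jagged 2D array: one row of seq ids per token.
        std::vector<std::vector<int32_t>> seq_id;
        if (rank == 0) {
            seq_id = {{0}, {0, 1}, {2}};
        }

        // 1. Broadcast the row count and the per-row lengths.
        int32_t n_rows = (int32_t) seq_id.size();
        MPI_Bcast(&n_rows, 1, MPI_INT32_T, 0, MPI_COMM_WORLD);
        std::vector<int32_t> row_len(n_rows, 0);
        if (rank == 0) {
            for (int32_t i = 0; i < n_rows; i++) row_len[i] = (int32_t) seq_id[i].size();
        }
        MPI_Bcast(row_len.data(), n_rows, MPI_INT32_T, 0, MPI_COMM_WORLD);

        // 2. Flatten on rank 0 and broadcast one contiguous buffer.
        int32_t total = 0;
        for (int32_t len : row_len) total += len;
        std::vector<int32_t> flat(total, 0);
        if (rank == 0) {
            int32_t k = 0;
            for (auto & row : seq_id) for (int32_t v : row) flat[k++] = v;
        }
        MPI_Bcast(flat.data(), total, MPI_INT32_T, 0, MPI_COMM_WORLD);

        // 3. Rebuild the jagged structure on every rank.
        std::vector<std::vector<int32_t>> rebuilt(n_rows);
        int32_t k = 0;
        for (int32_t i = 0; i < n_rows; i++) {
            rebuilt[i].assign(flat.begin() + k, flat.begin() + k + row_len[i]);
            k += row_len[i];
        }

        printf("rank %d: row 1 has %zu seq ids\n", rank, rebuilt[1].size());
        MPI_Finalize();
        return 0;
    }

Run with two or more ranks, every rank should report the same row length, which is exactly the property llama_decode_internal() relies on when follower ranks reconstruct u_batch.seq_id.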
ggml-mpi.h (21 lines changed)
@@ -110,14 +110,23 @@ int ggml_mpi_size(struct ggml_mpi_context * ctx);
  *
  * @param ctx_mpi The context in which to prepare for evaluation.
  * @param n_tokens A pointer to the n_tokens, which will be synchronized after this function.
- * @param n_past A pointer to the n_past, which will be synchronized after this function.
- * @param n_threads A pointer to the n_threads, which is unused currently.
+ * @param pos A pointer to the pos array, which will be synchronized after this function.
+ * @param n_seq_ids A pointer to the n_seq_ids array, which will be synchronized after this function.
+ * @param seq_id A pointer to the seq_id 2D array, which will be synchronized after this function.
+ * @param logits A pointer to the logits array, which is unused currently since only node 0 needs them.
  */
 void ggml_mpi_eval_init(
-        struct ggml_mpi_context * ctx_mpi,
-                            int * n_tokens,
-                            int * n_past,
-                            int * n_threads);
+        struct ggml_mpi_context * ctx_mpi,
+                        int32_t * n_tokens,
+                       int32_t ** pos,
+                       int32_t ** n_seq_ids,
+                      int32_t *** seq_id,
+                        int8_t ** logits);
+
+void ggml_mpi_synch_int(
+        struct ggml_mpi_context * ctx_mpi,
+                        int32_t * val
+);

 /**
  * Split a range across all nodes within the given
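With the new prototype, a follower rank can pass empty fields and let ggml_mpi_eval_init() allocate and fill them. Below is a minimal sketch of such a call site, assuming ggml-mpi.h is on the include path and ctx_mpi came from ggml_mpi_init() at startup; follower_sync_batch is a made-up helper name, not part of this commit.

    // Sketch only: hypothetical call site on a follower rank (rank != 0).
    #include <cstdint>
    #include "ggml-mpi.h"

    static void follower_sync_batch(struct ggml_mpi_context * ctx_mpi) {
        int32_t    n_tokens  = 0;
        int32_t  * pos       = nullptr; // allocated inside ggml_mpi_eval_init()
        int32_t  * n_seq_ids = nullptr; // allocated inside ggml_mpi_eval_init()
        int32_t ** seq_id    = nullptr; // rebuilt inside ggml_mpi_eval_init()
        int8_t   * logits    = nullptr; // allocated but not broadcast; only node 0 needs logits

        ggml_mpi_eval_init(ctx_mpi, &n_tokens, &pos, &n_seq_ids, &seq_id, &logits);

        // The follower now sees rank 0's n_tokens, pos[], n_seq_ids[] and seq_id[][],
        // which is what llama_decode_internal() needs to build its local u_batch.
    }

Note that logits is allocated but deliberately not broadcast, matching the commented-out MPI_Bcast in ggml-mpi.c, since only node 0 consumes the logits.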
llama.cpp (29 lines changed)
@@ -8776,8 +8776,7 @@ static int llama_decode_internal(
          llama_context & lctx,
            llama_batch   batch_all) { // TODO: rename back to batch

-    const uint32_t n_tokens_all = batch_all.n_tokens;
-
+    uint32_t n_tokens_all = batch_all.n_tokens;
     if (n_tokens_all == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
         return -1;
@@ -8798,11 +8797,7 @@ static int llama_decode_internal(
     }
     lctx.n_queued_tokens += n_tokens_all;

-#ifdef GGML_USE_MPI
-    // TODO: needs fix after #3228
-    GGML_ASSERT(false && "not implemented");
-    //ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads);
-#endif

     auto & kv_self = lctx.kv_self;
@@ -8828,7 +8823,7 @@ static int llama_decode_internal(
     std::vector<std::vector<llama_seq_id>> seq_id;

     for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
-        const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
+        uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
         llama_batch u_batch = {
             /* .n_tokens   = */ (int32_t) n_tokens,
             /* .token      = */ batch_all.token ? batch_all.token + cur_token : nullptr,
@@ -8881,7 +8876,12 @@ static int llama_decode_internal(
             kv_self.head = 0;
         }

-        if (!llama_kv_cache_find_slot(kv_self, u_batch)) {
+#ifdef GGML_USE_MPI
+        // TODO: needs fix after #3228
+        ggml_mpi_eval_init(lctx.ctx_mpi, &(u_batch.n_tokens), &(u_batch.pos), &(u_batch.n_seq_id), &(u_batch.seq_id), &(u_batch.logits));
+        n_tokens = u_batch.n_tokens;
+#endif
+        if (!llama_kv_cache_find_slot(kv_self, u_batch)) {
             return 1;
         }
@@ -13923,6 +13923,17 @@ void llama_batch_free(struct llama_batch batch) {
 int32_t llama_decode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
+#ifdef GGML_USE_MPI
+    if (ggml_mpi_rank(ctx->ctx_mpi) > 0) {
+        // Enter a blocking eval loop with dummy input, letting rank=0 drive the process
+        const int n_ctx = llama_n_ctx(ctx);
+        std::vector<llama_token> tmp(n_ctx, llama_token_bos(&ctx->model));
+        while (llama_decode_internal(*ctx, batch) >= 0){};
+        llama_backend_free();
+        exit(1);
+    }
+#endif
     const int ret = llama_decode_internal(*ctx, batch);
     if (ret < 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
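Taken together, the intended control flow is that every rank calls llama_decode() with whatever batch it has: followers (rank > 0) immediately drop into the blocking llama_decode_internal() loop above and are driven by rank 0's broadcasts, while rank 0 returns normally and keeps sampling. A rough driver sketch under that assumption follows; the init and loading calls are taken from the llama.h API of this period and are not part of the commit, and older revisions of llama_backend_init() took a bool numa flag.

    // Sketch only: hypothetical MPI-aware driver, not part of this commit.
    #include <cstdio>
    #include <vector>
    #include "llama.h"

    int main(int argc, char ** argv) {
        llama_backend_init();   // with GGML_USE_MPI this also initializes the MPI backend

        llama_model   * model = llama_load_model_from_file(argv[1], llama_model_default_params());
        llama_context * ctx   = llama_new_context_with_model(model, llama_context_default_params());

        std::vector<llama_token> prompt = { llama_token_bos(model) };
        llama_batch batch = llama_batch_get_one(prompt.data(), (int32_t) prompt.size(), 0, 0);

        // Followers (rank > 0) never return from this call: they block in the eval loop
        // added above and are driven by rank 0's broadcasts. Rank 0 returns normally.
        if (llama_decode(ctx, batch) != 0) {
            fprintf(stderr, "llama_decode() failed\n");
        }

        // Only rank 0 reaches the sampling/printing code and this cleanup.
        llama_free(ctx);
        llama_free_model(model);
        llama_backend_free();
        return 0;
    }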