Port transactions from mpi-speculative, fix incorrect seq_id syncing (not tested)

Branden Butler 2024-03-19 13:42:49 -05:00
parent 57ac2e7268
commit d2de181d95
3 changed files with 248 additions and 47 deletions

ggml-mpi.cpp

@@ -6,8 +6,8 @@
#include <mpi.h>
#include <stdio.h>
#include <stdlib.h>
#include <cstdio>
#include <cstdlib>
#include <vector>
#define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -24,6 +24,8 @@ struct ggml_mpi_context {
MPI_Comm comm;
int layer_start;
int layer_end;
MPI_Status status;
struct ggml_tensor *inp0;
std::string name;
struct ggml_backend * wrapped_backend;
@@ -31,6 +33,8 @@ struct ggml_mpi_context {
ggml_backend_sched_t scheduler;
bool remote;
void* send_buffer;
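// Transaction ids exchanged by ggml_mpi_sync_ints_pipelined(): trans_id is
// the locally tracked id, recv_trans_id the last id received from upstream.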
int trans_id;
int recv_trans_id;
};
void ggml_mpi_backend_init(void) {
@@ -122,12 +126,43 @@ void ggml_mpi_sync_pipelined(
}
if(ctx_mpi->rank < ctx_mpi->size - 1) {
GGML_ASSERT(ctx_mpi->send_buffer != nullptr);
GGML_ASSERT(val != nullptr);
GGML_ASSERT(count < 128*1024*1024);
const int retval = MPI_Bsend(val, count, datatype, ggml_mpi_next_node(ctx_mpi), tag, ctx_mpi->comm);
GGML_ASSERT(retval == MPI_SUCCESS);
}
}
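// For context (not part of this commit): MPI_Bsend only succeeds if the
// process has previously attached a buffer large enough for the message plus
// MPI's bookkeeping overhead, which is why the asserts above check
// ctx_mpi->send_buffer and bound the message size. A minimal sketch of
// attaching such a buffer; the helper name and sizing are illustrative:
static void * ggml_mpi_attach_send_buffer(size_t payload_bytes) {
    const size_t size = payload_bytes + MPI_BSEND_OVERHEAD; // payload + per-send overhead
    void * buf = malloc(size);
    MPI_Buffer_attach(buf, (int) size);
    return buf; // must stay attached for the lifetime of all buffered sends
}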
void ggml_mpi_barrier(struct ggml_mpi_context * ctx_mpi) {
MPI_Barrier(ctx_mpi->comm);
}
void ggml_mpi_probe(struct ggml_mpi_context * ctx_mpi, int src, int tag) {
MPI_Probe((src >= 0) ? src : MPI_ANY_SOURCE, (tag >= 0) ? tag : MPI_ANY_TAG, ctx_mpi->comm, &(ctx_mpi->status));
}
int ggml_mpi_iprobe(struct ggml_mpi_context * ctx_mpi, int src, int tag) {
if(ctx_mpi->comm == MPI_COMM_NULL) {
return 0;
}
int ret;
MPI_Iprobe((src >= 0) ? src : MPI_ANY_SOURCE, (tag >= 0) ? tag : MPI_ANY_TAG, ctx_mpi->comm, &ret, &(ctx_mpi->status));
return ret;
}
int ggml_mpi_status_tag(struct ggml_mpi_context * ctx_mpi) {
return ctx_mpi->status.MPI_TAG;
}
int ggml_mpi_status_count_int32(struct ggml_mpi_context * ctx_mpi) {
int32_t count;
MPI_Get_count(&ctx_mpi->status, MPI_INT32_T, &count);
return count;
}
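// Taken together, these helpers form a peek-then-dispatch pattern: block in
// MPI_Probe, inspect the pending message's tag, and only then receive it.
// A hedged usage sketch built only from the functions above:
static int ggml_mpi_wait_for_command(struct ggml_mpi_context * ctx) {
    ggml_mpi_probe(ctx, /*src=*/-1, /*tag=*/-1);        // -1 selects MPI_ANY_SOURCE / MPI_ANY_TAG
    const int tag   = ggml_mpi_status_tag(ctx);         // tag of the probed message
    const int count = ggml_mpi_status_count_int32(ctx); // payload length in int32 elements
    (void) count;                                       // a real caller would size its recv buffer
    return tag;                                         // caller switches on the tag
}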
void ggml_mpi_eval_init(
struct ggml_mpi_context * ctx_mpi,
int32_t * n_tokens,
@@ -142,8 +177,15 @@ void ggml_mpi_eval_init(
return;
}
int32_t old_n_tokens = *n_tokens;
ggml_mpi_sync_pipelined(ctx_mpi, n_tokens, 1, MPI_INT, GGML_MPI_N_TOKENS);
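// On worker ranks the broadcast token count can differ from the local one,
// so grow the per-token arrays before receiving into them.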
if (old_n_tokens != *n_tokens) {
*pos = static_cast<int32_t *>(realloc(*pos, *n_tokens * sizeof(int32_t)));
*n_seq_ids = static_cast<int32_t *>(realloc(*n_seq_ids, *n_tokens * sizeof(int32_t)));
*logits = static_cast<int8_t *>(realloc(*logits, *n_tokens * sizeof(int8_t))); // logits are int8_t, so size by int8_t
}
int8_t* temp_logits = (int8_t*) calloc(*n_tokens, sizeof(int8_t));
if (ctx_mpi->rank == 0 && *logits != nullptr) {
@@ -183,49 +225,51 @@ void ggml_mpi_eval_init(
// pre-allocated for the largest possible sizes, even on worker nodes.
GGML_ASSERT(n_seq_ids != nullptr);
GGML_ASSERT(*n_seq_ids != nullptr);
GGML_ASSERT(n_tokens != nullptr);
// FIXME Syncing n_seq_ids causes MPI to throw an invalid buffer error in Bsend
// ggml_mpi_sync_pipelined(ctx_mpi, *n_seq_ids, *n_tokens, MPI_INT32_T, GGML_MPI_N_SEQ_IDS);
ggml_mpi_sync_pipelined(ctx_mpi, *n_seq_ids, *n_tokens, MPI_INT32_T, GGML_MPI_N_SEQ_IDS);
// We need to know the total number of sequence
// ids, so we count them all up
// int32_t total_n_seq_ids = 0;
// for (int32_t i = 0; i < *n_tokens; i++) {
// total_n_seq_ids += (*n_seq_ids)[i];
// }
//
// // MPI can't chase the pointers for multidimensional arrays, so we flatten them first
// // for transit
// int32_t * flattened_seq_ids = static_cast<int32_t *>(calloc(total_n_seq_ids, sizeof(int32_t)));
//
// int32_t current_index = 0;
//
// // Only rank 0 needs to flatten since the others don't have the real seq_id
// if (ctx_mpi->rank == 0) {
// for (int32_t i = 0; i < *n_tokens; i++) {
// for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
// flattened_seq_ids[current_index] = (*seq_id)[i][j];
// current_index++;
// }
// }
// }
//
//
//
// ggml_mpi_sync_pipelined(ctx_mpi, *pos, *n_tokens, MPI_INT32_T, GGML_MPI_POS);
// ggml_mpi_sync_pipelined(ctx_mpi, flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, GGML_MPI_SEQ_IDS);
//
// current_index = 0;
// for (int32_t i = 0; i < *n_tokens; i++) {
// for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
// (*seq_id)[i][j] = flattened_seq_ids[current_index];
// current_index++;
// }
//
// }
// free(flattened_seq_ids);
int32_t total_n_seq_ids = 0;
for (int32_t i = 0; i < *n_tokens; i++) {
total_n_seq_ids += (*n_seq_ids)[i];
}
// MPI can't chase the pointers for multidimensional arrays, so we flatten them first
// for transit
int32_t * flattened_seq_ids = static_cast<int32_t *>(calloc(total_n_seq_ids, sizeof(int32_t)));
int32_t current_index = 0;
// Only rank 0 needs to flatten since the others don't have the real seq_id
if (ctx_mpi->rank == 0) {
for (int32_t i = 0; i < *n_tokens; i++) {
for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
flattened_seq_ids[current_index] = (*seq_id)[i][j];
current_index++;
}
}
}
ggml_mpi_sync_pipelined(ctx_mpi, *pos, *n_tokens, MPI_INT32_T, GGML_MPI_POS);
ggml_mpi_sync_pipelined(ctx_mpi, flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, GGML_MPI_SEQ_IDS);
current_index = 0;
for (int32_t i = 0; i < *n_tokens; i++) {
for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) {
(*seq_id)[i][j] = flattened_seq_ids[current_index];
current_index++;
}
}
free(flattened_seq_ids);
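// Illustrative round-trip of the flattening above (values invented):
//   n_tokens  = 2
//   n_seq_ids = {2, 1}          token 0 is in two sequences, token 1 in one
//   seq_id    = {{4, 7}, {9}}   jagged per-token sequence ids on rank 0
//   flattened = {4, 7, 9}       what actually travels over MPI
// Receivers rebuild the jagged array by walking n_seq_ids over flattened.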
}
@@ -236,6 +280,19 @@ void ggml_mpi_sync_int(
MPI_Bcast(val, 1, MPI_INT32_T, 0, ctx_mpi->comm);
}
void ggml_mpi_sync_ints_pipelined(
struct ggml_mpi_context * ctx_mpi,
int32_t * vals,
int count,
int tag
) {
ggml_mpi_sync_pipelined(ctx_mpi, vals, count, MPI_INT32_T, tag);
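// The pipelined sync below overwrites trans_id in place with the id received
// from the upstream node; stash the local id first, record the received one
// in recv_trans_id, then restore the local id.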
int old_trans = ctx_mpi->trans_id;
ggml_mpi_sync_pipelined(ctx_mpi, &ctx_mpi->trans_id, 1, MPI_INT32_T, GGML_MPI_TRANS_ID);
ctx_mpi->recv_trans_id = ctx_mpi->trans_id;
ctx_mpi->trans_id = old_trans;
}
static void ggml_mpi_tensor_send(const struct ggml_tensor * t, const void* data, int mpi_rank_dst, MPI_Comm comm) {
MPI_Datatype mpi_type;
@@ -549,6 +606,8 @@ GGML_CALL static enum ggml_status ggml_backend_mpi_graph_compute(ggml_backend_t
}
}
// TODO exploding memory usage cause we replace the buffer with the wrapped buffer,
// but don't free the contexts, and then create new ones when we re-wrap
if (!ctx->remote) {

ggml-mpi.h

@@ -100,6 +100,26 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_mpi_wrap_buffer_type(
GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_mpi_wrap_buffer(ggml_backend_buffer_t buf);
void ggml_mpi_sync_ints_pipelined(
struct ggml_mpi_context * ctx_mpi,
int32_t * vals,
int count,
int tag
);
void ggml_mpi_sync_ints_pipelined_back(
struct ggml_mpi_context * ctx_mpi,
int32_t * vals,
int count,
int tag
);
// clear = 1, rm = 2, cp = 3, keep = 4, seq_shift = 5
void ggml_mpi_probe(struct ggml_mpi_context * ctx_mpi, int src, int tag);
int ggml_mpi_status_tag(struct ggml_mpi_context * ctx_mpi);
int ggml_mpi_iprobe(struct ggml_mpi_context * ctx_mpi, int src, int tag);
int ggml_mpi_status_count_int32(struct ggml_mpi_context * ctx_mpi);
/**
* Create a new context by splitting the given context's
* communicator, creating a "sub-communicator." This is a collective

llama.cpp

@@ -9064,11 +9064,58 @@ static void llama_graph_compute(
//
static int llama_decode_internal(
llama_context & lctx,
llama_batch batch_all) { // TODO: rename back to batch
llama_batch & batch_all) { // TODO: rename back to batch
#ifdef GGML_USE_MPI
if (ggml_mpi_rank(lctx.model.ctx_mpi) == 0 && ggml_mpi_size(lctx.model.ctx_mpi) > 1) {
int transaction_type = GGML_MPI_DECODE;
ggml_mpi_sync_ints_pipelined(lctx.model.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION);
}
// ggml_mpi_sync_ints_pipelined(lctx.model.ctx_mpi, &batch_all.batch_id, 1, GGML_MPI_BATCH_ID);
int old_tokens = batch_all.n_tokens;
ggml_mpi_sync_ints_pipelined(lctx.model.ctx_mpi, &batch_all.n_tokens, 1, GGML_MPI_N_TOKENS);
ggml_mpi_sync_ints_pipelined(lctx.model.ctx_mpi, reinterpret_cast<int32_t *>(&lctx.cparams.n_seq_max), 1, GGML_MPI_MAX_N_SEQ);
if (ggml_mpi_rank(lctx.model.ctx_mpi) > 0) {
int new_n_tokens = batch_all.n_tokens;
llama_batch_free(batch_all);
batch_all = llama_batch_init(new_n_tokens, 0, (int32_t)lctx.cparams.n_seq_max);
}
#endif
uint32_t n_tokens_all = batch_all.n_tokens;
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id_arr;
std::vector<std::vector<llama_seq_id>> seq_id;
if (batch_all.pos == nullptr) {
pos.resize(n_tokens_all);
for (uint32_t i = 0; i < n_tokens_all; i++) {
pos[i] = batch_all.all_pos_0 + i*batch_all.all_pos_1;
}
batch_all.pos = pos.data();
}
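// Example (illustrative values): with all_pos_0 = 10 and all_pos_1 = 1,
// a 4-token batch gets the default positions {10, 11, 12, 13}.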
if (batch_all.seq_id == nullptr) {
n_seq_id.resize(n_tokens_all);
seq_id.resize(n_tokens_all);
seq_id_arr.resize(n_tokens_all);
for (uint32_t i = 0; i < n_tokens_all; i++) {
n_seq_id[i] = 1;
seq_id[i].resize(lctx.cparams.n_seq_max);
seq_id[i][0] = batch_all.all_seq_id;
seq_id_arr[i] = seq_id[i].data();
}
batch_all.n_seq_id = n_seq_id.data();
batch_all.seq_id = seq_id_arr.data();
}
#ifdef GGML_USE_MPI
ggml_mpi_eval_init(lctx.model.ctx_mpi, &(batch_all.n_tokens), &(batch_all.pos), &(batch_all.n_seq_id), &(batch_all.seq_id), &(batch_all.logits), lctx.cparams.n_seq_max);
n_tokens_all = batch_all.n_tokens;
@@ -9114,10 +9161,6 @@ static int llama_decode_internal(
const auto n_ubatch = cparams.n_ubatch;
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id_arr;
std::vector<std::vector<llama_seq_id>> seq_id;
for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
@@ -14344,6 +14387,87 @@ void llama_batch_free(struct llama_batch batch) {
batch.logits = nullptr;
}
#ifdef GGML_USE_MPI
int llama_process_mpi_transaction(
struct llama_context * ctx,
struct llama_batch & batch,
int tag) {
// if (ggml_mpi_rank(ctx->ctx_mpi) == ggml_mpi_size(ctx->ctx_mpi) - 1) {
// printf("\nBeginning transaction type %d\n", tag);
// }
switch (tag) {
case GGML_MPI_DECODE:
// llama_batch_free(batch);
return llama_decode_internal(*ctx, batch);
break;
case GGML_MPI_KV_CLEAR:
llama_kv_cache_clear(ctx);
break;
case GGML_MPI_KV_SEQ_RM:
llama_kv_cache_seq_rm(ctx, 1, -1, -1);
break;
case GGML_MPI_KV_SEQ_CP:
llama_kv_cache_seq_cp(ctx, 0, 0, 0, 0);
break;
// case GGML_MPI_KV_SEQ_CP_BACK:
// llama_kv_cache_seq_cp_back(ctx, 0, 0, 0, 0);
// break;
// case GGML_MPI_KV_SEQ_KEEP:
// llama_kv_cache_seq_keep(ctx, 0);
// break;
// case GGML_MPI_KV_SEQ_SHIFT:
// llama_kv_cache_seq_shift(ctx, 0, 0, 0, 0);
// break;
default:
printf("Unknown operation, exiting\n");
exit(1);
break;
}
return 0;
}
int llama_process_mpi_worker(
struct llama_context * ctx,
struct llama_batch & batch) {
ggml_mpi_probe(ctx->model.ctx_mpi, -1, -1);
int tag = ggml_mpi_status_tag(ctx->model.ctx_mpi);
int32_t count;
int32_t trans_type;
// if (ggml_mpi_rank(ctx->ctx_mpi) == ggml_mpi_size(ctx->ctx_mpi) - 1) {
// printf("\nReceived command %d\n", tag);
// }
switch (tag) {
case GGML_MPI_BEGIN_TRANSACTION:
ggml_mpi_sync_ints_pipelined(ctx->model.ctx_mpi, &trans_type, 1, GGML_MPI_BEGIN_TRANSACTION);
return llama_process_mpi_transaction(ctx, batch, trans_type);
break;
case GGML_MPI_SHUTDOWN:
llama_free(ctx);
llama_backend_free();
exit(0);
break;
case GGML_MPI_CANCEL_RUN:
// count = ggml_mpi_status_count_int32(ctx->model.ctx_mpi);
//// printf("Received cancel run\n");
// {
// std::vector<int32_t> canceled(count, -1);
// llama_cancel_run(ctx, canceled.data(), canceled.size());
//
// }
// break;
default:
printf("Unknown operation, exiting\n");
exit(1);
break;
}
return 0;
}
#endif
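// Sketch of the transaction protocol introduced above (illustrative, not
// normative): rank 0 drives, worker ranks block in llama_process_mpi_worker.
//
//   rank 0                                      workers (rank > 0)
//   ------                                      ------------------
//   trans_type = GGML_MPI_DECODE;               probe for any source/tag
//   ggml_mpi_sync_ints_pipelined(&trans_type,   tag == GGML_MPI_BEGIN_TRANSACTION:
//       1, GGML_MPI_BEGIN_TRANSACTION);           sync trans_type, then dispatch to
//   ...decode-specific syncs follow...            llama_process_mpi_transaction()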
int32_t llama_decode(
struct llama_context * ctx,
struct llama_batch batch) {
@@ -14351,9 +14475,7 @@ int32_t llama_decode(
#ifdef GGML_USE_MPI
if (ggml_mpi_rank(ctx->model.ctx_mpi) > 0) {
// Enter a blocking eval loop with dummy input, letting rank=0 drive the process
const int n_ctx = llama_n_ctx(ctx);
std::vector<llama_token> tmp(n_ctx, llama_token_bos(&ctx->model));
while (llama_decode_internal(*ctx, batch) >= 0){};
while (llama_process_mpi_worker(ctx, batch) >= 0){};
llama_backend_free();
exit(1);
}