mpi : extend API to allow usage with outer backends (e.g. Metal)

Georgi Gerganov · 2023-07-10 18:47:24 +03:00
commit eaef2d0e76 · parent c3c3ef11a6
5 changed files with 47 additions and 41 deletions
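Summary of the change: the old single entry point ggml_mpi_graph_compute() both sliced the compute graph across MPI ranks and executed it on the CPU. This commit splits it into a pre/post pair so that any backend can run the graph in between. A minimal sketch of the resulting calling convention (the setup of ctx, ctx_mpi, gf, n_layer and n_threads is assumed, not part of this diff):

    // 1) slice the graph for this MPI rank and receive the required input
    ggml_mpi_graph_compute_pre(ctx_mpi, &gf, n_layer);

    // 2) execute the sliced graph on any backend (CPU shown; Metal works the same way)
    ggml_graph_compute_with_ctx(ctx, &gf, n_threads);

    // 3) forward this rank's output to the next node
    ggml_mpi_graph_compute_post(ctx_mpi, &gf, n_layer);

The llama.cpp hunks below show exactly this pattern wrapped around the existing Metal/CPU dispatch.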

.gitignore

@@ -20,6 +20,7 @@ build-static/
 build-cublas/
 build-opencl/
 build-metal/
+build-mpi/
 build-no-accel/
 build-sanitize-addr/
 build-sanitize-thread/

ggml-metal.m

@@ -450,6 +450,7 @@ void ggml_metal_graph_compute(
                 //}
                 switch (dst->op) {
+                    case GGML_OP_NONE:
                     case GGML_OP_RESHAPE:
                     case GGML_OP_VIEW:
                     case GGML_OP_TRANSPOSE:
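All of these cases fall through to a no-op in the Metal backend: such nodes either perform no computation at all (GGML_OP_NONE, which the MPI pre-pass presumably leaves in the sliced graph) or only create a view that aliases the source buffer, so there is no kernel to launch. A standalone illustration of that aliasing, assuming the mid-2023 ggml API:

    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 8);
        struct ggml_tensor * t = ggml_transpose(ctx, a); // t->data aliases a->data: nothing to compute
        (void) t;

        ggml_free(ctx);
        return 0;
    }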

ggml-mpi.c

@@ -102,12 +102,10 @@ static void ggml_mpi_tensor_recv(struct ggml_tensor * t, int mpi_rank_src) {
 }

 // TODO: there are many improvements that can be done to this implementation
-void ggml_mpi_graph_compute(
+void ggml_mpi_graph_compute_pre(
         struct ggml_mpi_context * ctx_mpi,
-            struct ggml_context * ctx,
              struct ggml_cgraph * gf,
-                            int   n_layers,
-                            int   n_threads) {
+                            int   n_layers) {
     const int mpi_rank = ctx_mpi->rank;
     const int mpi_size = ctx_mpi->size;
@@ -200,10 +198,16 @@ void ggml_mpi_graph_compute(
         //fprintf(stderr, "%s: node %d: processing %d nodes [%d, %d)\n", __func__, mpi_rank, gf->n_nodes, il0, il1);
     }
+}

-    ggml_graph_compute_with_ctx(ctx, gf, n_threads);
+void ggml_mpi_graph_compute_post(
+        struct ggml_mpi_context * ctx_mpi,
+             struct ggml_cgraph * gf,
+                            int   n_layers) {
+    UNUSED(n_layers);

-    //fprintf(stderr, "%s: node %d: done\n", __func__, mpi_rank);
+    const int mpi_rank = ctx_mpi->rank;
+    const int mpi_size = ctx_mpi->size;

     // send the output data to the next node
     if (mpi_rank > 0) {
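The hunk is cut off here, but the remaining tail of ggml_mpi_graph_compute_post is just the send logic carried over from the old function. A sketch of that tail, assuming a ggml_mpi_tensor_send helper symmetric to the ggml_mpi_tensor_recv referenced above (the helper itself is not shown in this excerpt):

    // every rank except rank 0 forwards its last graph node to the next rank in the ring
    if (mpi_rank > 0) {
        ggml_mpi_tensor_send(gf->nodes[gf->n_nodes - 1], (mpi_rank + 1) % mpi_size);
    }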

ggml-mpi.h

@@ -24,12 +24,15 @@ void ggml_mpi_eval_init(
                               int * n_past,
                               int * n_threads);

-void ggml_mpi_graph_compute(
+void ggml_mpi_graph_compute_pre(
         struct ggml_mpi_context * ctx_mpi,
-            struct ggml_context * ctx,
              struct ggml_cgraph * gf,
-                            int   n_layers,
-                            int   n_threads);
+                            int   n_layers);
+
+void ggml_mpi_graph_compute_post(
+        struct ggml_mpi_context * ctx_mpi,
+             struct ggml_cgraph * gf,
+                            int   n_layers);

 #ifdef __cplusplus
 }

llama.cpp

@@ -1632,6 +1632,10 @@ static bool llama_eval_internal(
     // run the computation
     ggml_build_forward_expand(&gf, cur);

+#if GGML_USE_MPI
+    ggml_mpi_graph_compute_pre(lctx.ctx_mpi, &gf, n_layer);
+#endif
+
 #ifdef GGML_USE_METAL
     if (lctx.ctx_metal && N == 1) {
         ggml_metal_set_n_cb(lctx.ctx_metal, n_threads);
@@ -1656,14 +1660,19 @@
         ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads);
     }
-#elif GGML_USE_MPI
-    ggml_mpi_graph_compute(lctx.ctx_mpi, ctx0, &gf, n_layer, n_threads);
-
-    cur = gf.nodes[gf.n_nodes - 1];
 #else
     ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads);
 #endif

+#if GGML_USE_MPI
+    ggml_mpi_graph_compute_post(lctx.ctx_mpi, &gf, n_layer);
+#endif
+
+    // update kv token count
+    lctx.kv_self.n = n_past + N;
+
+    struct ggml_tensor * res = gf.nodes[gf.n_nodes - 1];
+
     if (cgraph_fname) {
         ggml_graph_export(&gf, cgraph_fname);
     }
@@ -1679,38 +1688,26 @@ static bool llama_eval_internal(
     // ggml_graph_dump_dot(&gf, NULL, "llama.dot");
     //}

     //embd_w.resize(n_vocab*N);
     //memcpy(embd_w.data(), ggml_get_data(cur), sizeof(float)*n_vocab*N);

-    // update kv token count
-    lctx.kv_self.n = n_past + N;
-
-    // extract logits
-    {
-        auto & logits_out = lctx.logits;
-
-        if (lctx.logits_all) {
-            logits_out.resize(n_vocab * N);
-            memcpy(logits_out.data(), (float *) ggml_get_data(cur), sizeof(float)*n_vocab*N);
-        } else {
-            // return result for just the last token
-            logits_out.resize(n_vocab);
-            memcpy(logits_out.data(), (float *) ggml_get_data(cur) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
-        }
-    }
-
-    // extract embeddings
-    if (!lctx.embedding.empty()) {
-        auto & embedding_out = lctx.embedding;
-
-        embedding_out.resize(n_embd);
-        memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
-    }
+#ifdef GGML_USE_MPI
+    if (ggml_mpi_rank(lctx.ctx_mpi) == 0) {
+#else
+    {
+#endif
+        // extract logits
+        {
+            auto & logits_out = lctx.logits;
+
+            if (lctx.logits_all) {
+                logits_out.resize(n_vocab * N);
+                memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*N);
+            } else {
+                // return result for just the last token
+                logits_out.resize(n_vocab);
+                memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+            }
+        }
+
+        // extract embeddings
+        if (!lctx.embedding.empty()) {
+            auto & embedding_out = lctx.embedding;
+
+            embedding_out.resize(n_embd);
+            memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
+        }
+    }

     if (mem_per_token == 0) {