mpi : extend API to allow usage with outer backends (e.g. Metal)

parent c3c3ef11a6
commit eaef2d0e76

5 changed files with 47 additions and 41 deletions
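The commit splits the old all-in-one ggml_mpi_graph_compute() — which executed the graph itself via ggml_graph_compute_with_ctx() — into a pre/post pair, so the graph execution in between can be handed to any backend. A minimal sketch of the resulting call pattern, using the names from the llama.cpp hunks below (pre roughly reduces the graph to this rank's layers and receives its input; post forwards the output to the next rank):

#if GGML_USE_MPI
    // reduce the graph to this rank's slice and receive its inputs (sketch)
    ggml_mpi_graph_compute_pre(lctx.ctx_mpi, &gf, n_layer);
#endif

    // execute the local slice on whichever backend is active (CPU, Metal, ...)
    ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads);

#if GGML_USE_MPI
    // send the output data to the next node
    ggml_mpi_graph_compute_post(lctx.ctx_mpi, &gf, n_layer);
#endif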

.gitignore (vendored) | 1 +

diff --git a/.gitignore b/.gitignore
@@ -20,6 +20,7 @@ build-static/
 build-cublas/
 build-opencl/
 build-metal/
+build-mpi/
 build-no-accel/
 build-sanitize-addr/
 build-sanitize-thread/

ggml-metal.m | 1 +

diff --git a/ggml-metal.m b/ggml-metal.m
@@ -450,6 +450,7 @@ void ggml_metal_graph_compute(
                 //}

                 switch (dst->op) {
+                    case GGML_OP_NONE:
                     case GGML_OP_RESHAPE:
                     case GGML_OP_VIEW:
                     case GGML_OP_TRANSPOSE:

ggml-mpi.c | 16 ++++++++++------

diff --git a/ggml-mpi.c b/ggml-mpi.c
@@ -102,12 +102,10 @@ static void ggml_mpi_tensor_recv(struct ggml_tensor * t, int mpi_rank_src) {
 }

 // TODO: there are many improvements that can be done to this implementation
-void ggml_mpi_graph_compute(
+void ggml_mpi_graph_compute_pre(
         struct ggml_mpi_context * ctx_mpi,
-        struct ggml_context * ctx,
         struct ggml_cgraph * gf,
-        int   n_layers,
-        int   n_threads) {
+        int   n_layers) {
     const int mpi_rank = ctx_mpi->rank;
     const int mpi_size = ctx_mpi->size;

@@ -200,10 +198,16 @@ void ggml_mpi_graph_compute(

         //fprintf(stderr, "%s: node %d: processing %d nodes [%d, %d)\n", __func__, mpi_rank, gf->n_nodes, il0, il1);
     }
+}

-    ggml_graph_compute_with_ctx(ctx, gf, n_threads);
+void ggml_mpi_graph_compute_post(
+        struct ggml_mpi_context * ctx_mpi,
+        struct ggml_cgraph * gf,
+        int   n_layers) {
+    UNUSED(n_layers);

-    //fprintf(stderr, "%s: node %d: done\n", __func__, mpi_rank);
+    const int mpi_rank = ctx_mpi->rank;
+    const int mpi_size = ctx_mpi->size;

     // send the output data to the next node
     if (mpi_rank > 0) {

ggml-mpi.h | 11 +++++++----

diff --git a/ggml-mpi.h b/ggml-mpi.h
@@ -24,12 +24,15 @@ void ggml_mpi_eval_init(
         int * n_past,
         int * n_threads);

-void ggml_mpi_graph_compute(
+void ggml_mpi_graph_compute_pre(
         struct ggml_mpi_context * ctx_mpi,
-        struct ggml_context * ctx,
         struct ggml_cgraph * gf,
-        int   n_layers,
-        int   n_threads);
+        int   n_layers);
+
+void ggml_mpi_graph_compute_post(
+        struct ggml_mpi_context * ctx_mpi,
+        struct ggml_cgraph * gf,
+        int   n_layers);

 #ifdef __cplusplus
 }
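For drivers outside llama.cpp, the header now exposes the two halves explicitly. Below is a minimal sketch of one evaluation step built on this API — build_graph() is a hypothetical helper standing in for the caller's own graph construction, and the CPU path uses ggml_graph_compute_with_ctx() just as the old combined function did; any other backend could execute the slice instead:

#include "ggml.h"
#include "ggml-mpi.h"

// hypothetical helper: builds the full model graph for this evaluation
extern struct ggml_cgraph * build_graph(struct ggml_context * ctx);

static void eval_once(struct ggml_context * ctx,
                      struct ggml_mpi_context * ctx_mpi,
                      int n_layers, int n_threads) {
    struct ggml_cgraph * gf = build_graph(ctx);

    // reduce the graph to this rank's slice and exchange inputs
    ggml_mpi_graph_compute_pre(ctx_mpi, gf, n_layers);

    // the backend-agnostic part: CPU shown here, Metal etc. work too
    ggml_graph_compute_with_ctx(ctx, gf, n_threads);

    // hand this rank's output to the next rank in the pipeline
    ggml_mpi_graph_compute_post(ctx_mpi, gf, n_layers);
}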

llama.cpp | 59 ++++++++++++++++++++++++++++-------------------------------

diff --git a/llama.cpp b/llama.cpp
@@ -1632,6 +1632,10 @@ static bool llama_eval_internal(
     // run the computation
     ggml_build_forward_expand(&gf, cur);

+#if GGML_USE_MPI
+    ggml_mpi_graph_compute_pre(lctx.ctx_mpi, &gf, n_layer);
+#endif
+
 #ifdef GGML_USE_METAL
     if (lctx.ctx_metal && N == 1) {
         ggml_metal_set_n_cb   (lctx.ctx_metal, n_threads);
@@ -1656,14 +1660,19 @@ static bool llama_eval_internal(

         ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads);
     }
-#elif GGML_USE_MPI
-    ggml_mpi_graph_compute(lctx.ctx_mpi, ctx0, &gf, n_layer, n_threads);
-
-    cur = gf.nodes[gf.n_nodes - 1];
 #else
     ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads);
 #endif

+#if GGML_USE_MPI
+    ggml_mpi_graph_compute_post(lctx.ctx_mpi, &gf, n_layer);
+#endif
+
+    // update kv token count
+    lctx.kv_self.n = n_past + N;
+
+    struct ggml_tensor * res = gf.nodes[gf.n_nodes - 1];
+
     if (cgraph_fname) {
         ggml_graph_export(&gf, cgraph_fname);
     }
@@ -1679,38 +1688,26 @@ static bool llama_eval_internal(
     // ggml_graph_dump_dot(&gf, NULL, "llama.dot");
     //}

-    //embd_w.resize(n_vocab*N);
-    //memcpy(embd_w.data(), ggml_get_data(cur), sizeof(float)*n_vocab*N);
-
-    // update kv token count
-    lctx.kv_self.n = n_past + N;
-
-#ifdef GGML_USE_MPI
-    if (ggml_mpi_rank(lctx.ctx_mpi) == 0) {
-#else
+    // extract logits
     {
-#endif
-        // extract logits
-        {
-            auto & logits_out = lctx.logits;
+        auto & logits_out = lctx.logits;

-            if (lctx.logits_all) {
-                logits_out.resize(n_vocab * N);
-                memcpy(logits_out.data(), (float *) ggml_get_data(cur), sizeof(float)*n_vocab*N);
-            } else {
-                // return result for just the last token
-                logits_out.resize(n_vocab);
-                memcpy(logits_out.data(), (float *) ggml_get_data(cur) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
-            }
-        }
+        if (lctx.logits_all) {
+            logits_out.resize(n_vocab * N);
+            memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*N);
+        } else {
+            // return result for just the last token
+            logits_out.resize(n_vocab);
+            memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+        }
+    }

-        // extract embeddings
-        if (!lctx.embedding.empty()) {
-            auto & embedding_out = lctx.embedding;
+    // extract embeddings
+    if (!lctx.embedding.empty()) {
+        auto & embedding_out = lctx.embedding;

-            embedding_out.resize(n_embd);
-            memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
-        }
+        embedding_out.resize(n_embd);
+        memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
     }

     if (mem_per_token == 0) {