Working MPI backend implementation
This commit is contained in:
parent
bc93545005
commit
942ce843f8
2 changed files with 444 additions and 84 deletions
507
ggml-mpi.cpp
507
ggml-mpi.cpp
|
@ -25,6 +25,7 @@ struct ggml_mpi_context {
|
|||
struct ggml_backend * wrapped_backend;
|
||||
std::vector<ggml_backend_t> backends;
|
||||
ggml_backend_sched_t scheduler;
|
||||
bool remote;
|
||||
};
|
||||
|
||||
void ggml_mpi_backend_init(void) {
|
||||
|
@ -42,6 +43,7 @@ struct ggml_mpi_context * ggml_mpi_init(void) {
|
|||
MPI_Comm_rank(MPI_COMM_WORLD, &ctx->rank);
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &ctx->size);
|
||||
ctx->comm = MPI_COMM_WORLD;
|
||||
ctx->remote = false;
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
@ -76,10 +78,13 @@ void ggml_mpi_eval_init(
|
|||
int8_t ** logits) {
|
||||
|
||||
|
||||
// fprintf(stderr, "Beginning eval init on rank %d\n", ctx_mpi->rank);
|
||||
MPI_Barrier(ctx_mpi->comm);
|
||||
int32_t old_n_tokens = *n_tokens;
|
||||
MPI_Bcast(n_tokens, 1, MPI_INT32_T, 0, ctx_mpi->comm);
|
||||
|
||||
// fprintf(stderr, "Node %d, old_n_tokens: %d, new n_tokens: %d\n", ctx_mpi->rank, old_n_tokens, *n_tokens);
|
||||
|
||||
// If what was passed in differs from what was broadcast,
|
||||
// we can't guarantee the allocated sizes are correct
|
||||
// TODO check how often this is done and if it's a problem,
|
||||
|
@ -160,21 +165,30 @@ static int ggml_graph_get_node_idx(struct ggml_cgraph * gf, const char * name) {
|
|||
return -1;
|
||||
}
|
||||
|
||||
|
||||
static void ggml_mpi_tensor_send(struct ggml_tensor * t, int mpi_rank_dst, MPI_Comm comm) {
|
||||
static void ggml_mpi_tensor_send(const struct ggml_tensor * t, const void* data, int mpi_rank_dst, MPI_Comm comm) {
|
||||
MPI_Datatype mpi_type;
|
||||
|
||||
// fprintf(stderr, "Type: %d\n", t->type);
|
||||
|
||||
switch (t->type) {
|
||||
case GGML_TYPE_I32: mpi_type = MPI_INT32_T; break;
|
||||
case GGML_TYPE_F32: mpi_type = MPI_FLOAT; break;
|
||||
case GGML_TYPE_F16: mpi_type = MPI_INT16_T; break;
|
||||
default: GGML_ASSERT(false && "not implemented");
|
||||
}
|
||||
int rank;
|
||||
MPI_Comm_rank(comm, &rank);
|
||||
// fprintf(stderr, "Sending tensor %s (buffer %s) from %d to %d\n", t->name, ggml_backend_buffer_name(t->buffer), rank, mpi_rank_dst);
|
||||
|
||||
const int retval = MPI_Send(t->data, ggml_nelements(t), mpi_type, mpi_rank_dst, 0, comm);
|
||||
const int retval = MPI_Send(data, ggml_nelements(t), mpi_type, mpi_rank_dst, 0, comm);
|
||||
GGML_ASSERT(retval == MPI_SUCCESS);
|
||||
|
||||
}
|
||||
static void ggml_mpi_tensor_send(const struct ggml_tensor * t, int mpi_rank_dst, MPI_Comm comm) {
|
||||
ggml_mpi_tensor_send(t, t->data, mpi_rank_dst, comm);
|
||||
}
|
||||
|
||||
static void ggml_mpi_tensor_recv(struct ggml_tensor * t, int mpi_rank_src, MPI_Comm comm) {
|
||||
static void ggml_mpi_tensor_recv(const struct ggml_tensor * t, void * data, int mpi_rank_src, MPI_Comm comm) {
|
||||
MPI_Datatype mpi_type;
|
||||
|
||||
switch (t->type) {
|
||||
|
@ -184,11 +198,18 @@ static void ggml_mpi_tensor_recv(struct ggml_tensor * t, int mpi_rank_src, MPI_C
|
|||
}
|
||||
|
||||
MPI_Status status; UNUSED(status);
|
||||
fprintf(stderr, "%s: tensor receive == null: %d\n", __func__, t->data == NULL);
|
||||
const int retval = MPI_Recv(t->data, ggml_nelements(t), mpi_type, mpi_rank_src, MPI_ANY_TAG, comm, &status);
|
||||
// fprintf(stderr, "%s: tensor receive == null: %d\n", __func__, t->data == NULL);
|
||||
int rank;
|
||||
MPI_Comm_rank(comm, &rank);
|
||||
// fprintf(stderr, "Receiving tensor %s (buffer %s) from %d at %d\n", t->name, ggml_backend_buffer_name(t->buffer), mpi_rank_src, rank);
|
||||
const int retval = MPI_Recv(data, ggml_nelements(t), mpi_type, mpi_rank_src, MPI_ANY_TAG, comm, &status);
|
||||
GGML_ASSERT(retval == MPI_SUCCESS);
|
||||
}
|
||||
|
||||
static void ggml_mpi_tensor_recv(struct ggml_tensor * t, int mpi_rank_src, MPI_Comm comm) {
|
||||
ggml_mpi_tensor_recv(t, t->data, mpi_rank_src, comm);
|
||||
}
|
||||
|
||||
uint16_t** ggml_mpi_split_range(
|
||||
struct ggml_mpi_context * ctx_mpi,
|
||||
uint16_t start,
|
||||
|
@ -296,12 +317,6 @@ void ggml_mpi_graph_compute_pre(
|
|||
const int mpi_rank = ctx_mpi->rank;
|
||||
const int mpi_size = ctx_mpi->size;
|
||||
|
||||
struct ggml_tensor * inp_tokens = gf->nodes[0];
|
||||
if (inp_tokens == NULL) {
|
||||
fprintf(stderr, "%s: tensor 'inp_tokens' not found\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
struct ggml_tensor * inp0 = gf->nodes[0];
|
||||
if (inp0 == NULL) {
|
||||
fprintf(stderr, "%s: tensor 'inp0' not found\n", __func__);
|
||||
|
@ -309,23 +324,24 @@ void ggml_mpi_graph_compute_pre(
|
|||
}
|
||||
|
||||
if (mpi_rank > 0) {
|
||||
if (mpi_rank == 1) {
|
||||
// the first node (1) receives the input tokens from the main node (0)
|
||||
if (inp_tokens->data == NULL) {
|
||||
|
||||
}
|
||||
ggml_mpi_tensor_recv(inp_tokens, 0, ctx_mpi->comm);
|
||||
} else {
|
||||
// recv input data for each node into the "inp0" tensor (i.e. the first node in the compute graph)
|
||||
fprintf(stderr, "%s:%d: receiving layer inp0\n", __func__, ctx_mpi->rank);
|
||||
ggml_mpi_tensor_recv(inp0, mpi_rank - 1, ctx_mpi->comm);
|
||||
}
|
||||
// ggml_mpi_tensor_recv(inp0, mpi_rank - 1, ctx_mpi->comm);
|
||||
// if (mpi_rank == 1) {
|
||||
// // the first node (1) receives the input tokens from the main node (0)
|
||||
// if (inp_tokens->data == NULL) {
|
||||
//
|
||||
// }
|
||||
// ggml_mpi_tensor_recv(inp_tokens, 0, ctx_mpi->comm);
|
||||
// } else {
|
||||
// // recv input data for each node into the "inp0" tensor (i.e. the first node in the compute graph)
|
||||
// fprintf(stderr, "%s:%d: receiving layer inp0\n", __func__, ctx_mpi->rank);
|
||||
// ggml_mpi_tensor_recv(inp0, mpi_rank - 1, ctx_mpi->comm);
|
||||
// }
|
||||
} else if (mpi_size > 1) {
|
||||
// node 0 sends the input tokens to node 1
|
||||
ggml_mpi_tensor_send(inp_tokens, 1, ctx_mpi->comm);
|
||||
// ggml_mpi_tensor_send(inp_tokens, 1, ctx_mpi->comm);
|
||||
|
||||
// recv the output data from the last node
|
||||
ggml_mpi_tensor_recv(inp0, mpi_size - 1, ctx_mpi->comm);
|
||||
// ggml_mpi_tensor_recv(inp0, mpi_size - 1, ctx_mpi->comm);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -338,27 +354,48 @@ void ggml_mpi_graph_compute_post(
|
|||
|
||||
// send the output data to the next node
|
||||
if (mpi_rank > 0) {
|
||||
ggml_mpi_tensor_send(gf->nodes[gf->n_nodes - 1], (mpi_rank + 1) % mpi_size, ctx_mpi->comm);
|
||||
// ggml_mpi_tensor_send(gf->nodes[gf->n_nodes - 1], (mpi_rank + 1) % mpi_size, ctx_mpi->comm);
|
||||
}
|
||||
}
|
||||
|
||||
// BACKEND V2
|
||||
|
||||
struct ggml_backend_mpi_buffer_context {
|
||||
ggml_backend_buffer_t wrapped_buffer;
|
||||
int rank;
|
||||
};
|
||||
|
||||
struct ggml_backend_mpi_buffer_type_context {
|
||||
std::string name;
|
||||
ggml_backend_buffer_type_t wrapped_buffer;
|
||||
int rank;
|
||||
};
|
||||
|
||||
GGML_CALL static const char * ggml_backend_mpi_buffer_name(ggml_backend_buffer_t buffer) {
|
||||
auto * ctx = (ggml_backend_mpi_buffer_context *) buffer->context;
|
||||
|
||||
int rank;
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
||||
|
||||
return strdup((("MPI Buffer(Rank " + std::to_string(ctx->rank) + ", local rank " + std::to_string(rank) + "):") + std::string(ctx->wrapped_buffer->iface.get_name(ctx->wrapped_buffer))).c_str());
|
||||
}
|
||||
|
||||
GGML_CALL static const char * ggml_backend_mpi_buffer_type_name(ggml_backend_buffer_type_t buft);
|
||||
|
||||
GGML_CALL static bool ggml_backend_mpi_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||
|
||||
struct ggml_mpi_context * ctx = (ggml_mpi_context *) backend->context;
|
||||
|
||||
ggml_mpi_graph_compute_pre(ctx, cgraph);
|
||||
|
||||
// if (ctx->remote) {
|
||||
// return true;
|
||||
// }
|
||||
|
||||
|
||||
// ggml_mpi_graph_compute_pre(ctx, cgraph);
|
||||
|
||||
std::vector<ggml_backend_buffer_type_t> backend_buft;
|
||||
for (auto *curr_backend : ctx->backends) {
|
||||
for (auto *curr_backend: ctx->backends) {
|
||||
if (ggml_backend_is_cpu(curr_backend)) {
|
||||
// use host buffers for the CPU backend compute buffer
|
||||
backend_buft.push_back(ggml_backend_cpu_buffer_type());
|
||||
|
@ -369,83 +406,202 @@ GGML_CALL static bool ggml_backend_mpi_graph_compute(ggml_backend_t backend, ggm
|
|||
|
||||
// ggml_backend_t wrapped_backend = ctx->wrapped_backend;
|
||||
// bool ret = ggml_backend_graph_compute(wrapped_backend, cgraph);
|
||||
printf("Running MPI backend\n");
|
||||
// printf("Running MPI backend\n");
|
||||
|
||||
std::vector<std::pair<ggml_backend_buffer_type_t, std::vector<ggml_backend_buffer_type_t>> > old_buffs(cgraph->n_nodes);
|
||||
std::vector<std::pair<ggml_backend_buffer_t, std::vector<ggml_backend_buffer_type_t>>> old_buffs(
|
||||
cgraph->n_nodes);
|
||||
std::vector<ggml_backend_buffer_type_t> old_view_buffs(cgraph->n_nodes);
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
old_buffs.push_back({cgraph->nodes[i]->buffer->buft,{}});
|
||||
if (cgraph->nodes[i]->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) {
|
||||
cgraph->nodes[i]->buffer->buft = ((ggml_backend_mpi_buffer_type_context *) cgraph->nodes[i]->buffer->buft->context)->wrapped_buffer;
|
||||
printf("Unwrapped buffer: %s\n", cgraph->nodes[i]->buffer->buft->iface.get_name(cgraph->nodes[i]->buffer->buft));
|
||||
}
|
||||
|
||||
for (auto & src : cgraph->nodes[i]->src) {
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
old_buffs[i].first = cgraph->nodes[i]->buffer;
|
||||
// if (cgraph->nodes[i]->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) {
|
||||
// cgraph->nodes[i]->buffer = ((ggml_backend_mpi_buffer_context *) cgraph->nodes[i]->buffer->context)->wrapped_buffer;
|
||||
//// printf("Unwrapped buffer: %s\n", cgraph->nodes[i]->buffer->buft->iface.get_name(cgraph->nodes[i]->buffer->buft));
|
||||
// }
|
||||
|
||||
for (auto &src: cgraph->nodes[i]->src) {
|
||||
if (src == nullptr) {
|
||||
break;
|
||||
}
|
||||
old_buffs[i].second.push_back(src->buffer->buft);
|
||||
// if (src->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) {
|
||||
// src->buffer = ((ggml_backend_mpi_buffer_context *) src->buffer->context)->wrapped_buffer;
|
||||
//// printf("Unwrapped buffer src: %s\n", src->buffer->buft->iface.get_name(src->buffer->buft));
|
||||
// }
|
||||
}
|
||||
|
||||
auto *src = cgraph->nodes[i]->view_src;
|
||||
if (src != nullptr) {
|
||||
// fprintf(stderr, "View src is not null, src=%s, src buffer=%s\n", src->name, ggml_backend_buffer_name(src->buffer));
|
||||
if (src->buffer->buft != nullptr) {
|
||||
// fprintf(stderr, "View src buffer type is not null, buft=%s\n", ggml_backend_buft_name(src->buffer->buft));
|
||||
old_view_buffs[i] = src->buffer->buft;
|
||||
// if (src->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) {
|
||||
// src->buffer = ((ggml_backend_mpi_buffer_context *) src->buffer->context)->wrapped_buffer;
|
||||
// printf("Unwrapped view buffer src: %s\n", src->buffer->buft->iface.get_name(src->buffer->buft));
|
||||
// }
|
||||
} else {
|
||||
// old_view_buffs[i] = ggml_backend_mpi_wrap_buffer_type(ggml_backend_cpu_buffer_type());
|
||||
// ggml_backend_mpi_buffer_type_set_rank(old_view_buffs[i], ((ggml_backend_mpi_buffer_context*)src->buffer->context)->rank);
|
||||
}
|
||||
} else {
|
||||
// fprintf(stderr, "OLD VIEW BUFF IS NULL\n");
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
if (cgraph->nodes[i]->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) {
|
||||
auto usage = cgraph->nodes[i]->buffer->usage;
|
||||
cgraph->nodes[i]->buffer = ((ggml_backend_mpi_buffer_context *) cgraph->nodes[i]->buffer->context)->wrapped_buffer;
|
||||
cgraph->nodes[i]->buffer->usage = usage;
|
||||
// printf("Unwrapped buffer: %s\n", cgraph->nodes[i]->buffer->buft->iface.get_name(cgraph->nodes[i]->buffer->buft));
|
||||
}
|
||||
|
||||
for (auto &src: cgraph->nodes[i]->src) {
|
||||
if (src == nullptr) {
|
||||
break;
|
||||
}
|
||||
// old_buffs[i].second.push_back(src->buffer->buft);
|
||||
if (src->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) {
|
||||
src->buffer->buft = ((ggml_backend_mpi_buffer_type_context *) src->buffer->buft->context)->wrapped_buffer;
|
||||
printf("Unwrapped buffer src: %s\n", src->buffer->buft->iface.get_name(src->buffer->buft));
|
||||
auto usage = src->buffer->usage;
|
||||
src->buffer = ((ggml_backend_mpi_buffer_context *) src->buffer->context)->wrapped_buffer;
|
||||
src->buffer->usage = usage;
|
||||
// printf("Unwrapped buffer src: %s\n", src->buffer->buft->iface.get_name(src->buffer->buft));
|
||||
}
|
||||
}
|
||||
|
||||
auto *src = cgraph->nodes[i]->view_src;
|
||||
if(src != nullptr && src->buffer->buft != nullptr){
|
||||
old_view_buffs[i] = src->buffer->buft;
|
||||
if (src != nullptr) {
|
||||
// fprintf(stderr, "View src is not null, src=%s, src buffer=%s\n", src->name, ggml_backend_buffer_name(src->buffer));
|
||||
if (src->buffer->buft != nullptr) {
|
||||
// fprintf(stderr, "View src buffer type is not null, buft=%s\n", ggml_backend_buft_name(src->buffer->buft));
|
||||
// old_view_buffs[i] = src->buffer->buft;
|
||||
if (src->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) {
|
||||
src->buffer->buft = ((ggml_backend_mpi_buffer_type_context *) src->buffer->buft->context)->wrapped_buffer;
|
||||
printf("Unwrapped view buffer src: %s\n", src->buffer->buft->iface.get_name(src->buffer->buft));
|
||||
auto usage = src->buffer->usage;
|
||||
|
||||
src->buffer = ((ggml_backend_mpi_buffer_context *) src->buffer->context)->wrapped_buffer;
|
||||
src->buffer->usage = usage;
|
||||
// printf("Unwrapped view buffer src: %s\n", src->buffer->buft->iface.get_name(src->buffer->buft));
|
||||
}
|
||||
} else {
|
||||
// old_view_buffs[i] = ggml_backend_mpi_wrap_buffer_type(ggml_backend_cpu_buffer_type());
|
||||
// ggml_backend_mpi_buffer_type_set_rank(old_view_buffs[i], ((ggml_backend_mpi_buffer_context*)src->buffer->context)->rank);
|
||||
}
|
||||
} else {
|
||||
// fprintf(stderr, "OLD VIEW BUFF IS NULL\n");
|
||||
}
|
||||
}
|
||||
|
||||
// fprintf(stderr, "Original n_leafs: %d\n", cgraph->n_leafs);
|
||||
|
||||
std::vector<ggml_backend_buffer_type_t > old_buffs_leaves;
|
||||
std::vector<ggml_backend_buffer_type_t> old_buffs_leaves;
|
||||
for (int i = 0; i < cgraph->n_leafs; i++) {
|
||||
// fprintf(stderr, "Pushing leaf %s\n", cgraph->leafs[i]->name);
|
||||
old_buffs_leaves.push_back(cgraph->leafs[i]->buffer->buft);
|
||||
if (cgraph->leafs[i]->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) {
|
||||
cgraph->leafs[i]->buffer->buft = ((ggml_backend_mpi_buffer_type_context *) cgraph->leafs[i]->buffer->buft->context)->wrapped_buffer;
|
||||
printf("Unwrapped buffer: %s\n", cgraph->leafs[i]->buffer->buft->iface.get_name(cgraph->leafs[i]->buffer->buft));
|
||||
// printf("Unwrapped buffer: %s\n", cgraph->leafs[i]->buffer->buft->iface.get_name(cgraph->leafs[i]->buffer->buft));
|
||||
}
|
||||
}
|
||||
|
||||
ggml_backend_sched_t sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), cgraph->n_nodes);
|
||||
|
||||
if (!ctx->remote) {
|
||||
ggml_backend_sched_t sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(),
|
||||
ctx->backends.size(), cgraph->n_nodes);
|
||||
|
||||
printf("Created new scheduler\n");
|
||||
// printf("Created new scheduler\n");
|
||||
ggml_backend_sched_init_measure(sched, cgraph);
|
||||
printf("Beginning sched graph compute\n");
|
||||
// printf("Beginning sched graph compute\n");
|
||||
ggml_backend_sched_graph_compute(sched, cgraph);
|
||||
|
||||
ggml_backend_sched_free(sched);
|
||||
|
||||
}
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
cgraph->nodes[i]->buffer->buft = old_buffs[i].first;
|
||||
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
||||
if (cgraph->nodes[i]->src[j] == nullptr) {
|
||||
// fprintf(stderr, "Wrapping buffer %s for node %s\n", cgraph->nodes[i]->name, ggml_backend_buffer_name(cgraph->nodes[i]->buffer));
|
||||
cgraph->nodes[i]->buffer = ggml_backend_mpi_wrap_buffer(cgraph->nodes[i]->buffer);
|
||||
|
||||
// fprintf(stderr, "Setting buffer ranks for node %s with old buff %s\n", cgraph->nodes[i]->name, ggml_backend_buffer_name(old_buffs[i].first));
|
||||
|
||||
ggml_backend_mpi_buffer_set_rank(cgraph->nodes[i]->buffer, ((ggml_backend_mpi_buffer_context*)old_buffs[i].first->context)->rank);
|
||||
|
||||
// fprintf(stderr, "New buffer rank for node %s: %d\n", cgraph->nodes[i]->name, ctx->rank);
|
||||
|
||||
|
||||
for (int iter = 0; iter < GGML_MAX_SRC; iter++) {
|
||||
auto* j = cgraph->nodes[i]->src[iter];
|
||||
if (j == nullptr) {
|
||||
break;
|
||||
}
|
||||
cgraph->nodes[i]->src[j]->buffer->buft = old_buffs[i].second[j];
|
||||
|
||||
if (j->buffer->iface.get_name == ggml_backend_mpi_buffer_name) {
|
||||
// fprintf(stderr, "Skipping buffer ranks for src node %s with buffer type %s\n", j->name, j->buffer->buft->iface.get_name(j->buffer->buft));
|
||||
continue;
|
||||
}
|
||||
|
||||
// fprintf(stderr, "Setting buffer ranks for src node %s\n", j->name);
|
||||
|
||||
|
||||
// fprintf(stderr, "Source buffer name: %s, buffer type name: %s\n",
|
||||
// j->buffer->iface.get_name(j->buffer),
|
||||
// j->buffer->buft->iface.get_name(j->buffer->buft));
|
||||
|
||||
j->buffer = ggml_backend_mpi_wrap_buffer(j->buffer);
|
||||
|
||||
ggml_backend_mpi_buffer_set_rank(j->buffer, ((ggml_backend_mpi_buffer_type_context*)old_buffs[i].second[iter]->context)->rank);
|
||||
// j->buffer->buft = old_buffs[i].second[iter];
|
||||
|
||||
// fprintf(stderr, "New source buffer name: %s, buffer type name: %s\n",
|
||||
// j->buffer->iface.get_name(j->buffer),
|
||||
// j->buffer->buft->iface.get_name(j->buffer->buft));
|
||||
|
||||
// ggml_backend_mpi_buffer_type_set_rank(j->buffer->buft, ctx->rank);
|
||||
}
|
||||
if(cgraph->nodes[i]->view_src != nullptr && cgraph->nodes[i]->view_src->buffer->buft != nullptr) {
|
||||
cgraph->nodes[i]->view_src->buffer->buft = old_view_buffs[i];
|
||||
// fprintf(stderr, "View source %s (buffer name: %s), buffer type name: %s\n", cgraph->nodes[i]->view_src->name,
|
||||
// cgraph->nodes[i]->view_src->buffer->iface.get_name(cgraph->nodes[i]->view_src->buffer),
|
||||
// cgraph->nodes[i]->view_src->buffer->buft->iface.get_name(cgraph->nodes[i]->view_src->buffer->buft));
|
||||
|
||||
// fprintf(stderr, "Old view source buffer type name: %s\n",
|
||||
// old_view_buffs[i]->iface.get_name(old_view_buffs[i]));
|
||||
// cgraph->nodes[i]->view_src->buffer = ggml_backend_mpi_wrap_buffer(cgraph->nodes[i]->view_src->buffer);
|
||||
// WRONG, need to keep the source ranks from before compute
|
||||
// fprintf(stderr, "View buff %s null\n", (old_view_buffs[i]->context == nullptr) ? " is " : "is NOT ");
|
||||
if (old_view_buffs[i] != nullptr) {
|
||||
if (old_view_buffs[i]->iface.get_name == ggml_backend_mpi_buffer_type_name) {
|
||||
ggml_backend_mpi_buffer_set_rank(cgraph->nodes[i]->view_src->buffer,
|
||||
((ggml_backend_mpi_buffer_type_context *) old_view_buffs[i]->context)->rank);
|
||||
} else {
|
||||
// ggml_backend_mpi_buffer_set_rank(cgraph->nodes[i]->view_src->buffer,
|
||||
// ctx->rank);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
for (int i = 0; i < cgraph->n_leafs; i++) {
|
||||
cgraph->leafs[i]->buffer->buft = old_buffs_leaves[i];
|
||||
// fprintf(stderr, "Wrapping leaf %s...\n", cgraph->leafs[i]->name);
|
||||
cgraph->leafs[i]->buffer = ggml_backend_mpi_wrap_buffer(cgraph->leafs[i]->buffer);
|
||||
// fprintf(stderr, "Wrapped leaf buffer: %s\n", ggml_backend_buffer_name(cgraph->leafs[i]->buffer));
|
||||
ggml_backend_mpi_buffer_type_set_rank(cgraph->leafs[i]->buffer->buft, ctx->rank);
|
||||
// fprintf(stderr, "Wrapped leaf after setting rank: %s\n", ggml_backend_buffer_name(cgraph->leafs[i]->buffer));
|
||||
}
|
||||
|
||||
|
||||
ggml_mpi_graph_compute_post(ctx, cgraph);
|
||||
// ggml_mpi_graph_compute_post(ctx, cgraph);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
static const char * ggml_backend_mpi_name(ggml_backend_t backend) {
|
||||
return "MPI";
|
||||
int rank;
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
||||
|
||||
return strdup(("MPI(Rank " + std::to_string(((ggml_mpi_context*)backend->context)->rank) + ", local rank " + std::to_string(rank) + ")").c_str());
|
||||
}
|
||||
|
||||
static void ggml_backend_mpi_free(ggml_backend_t backend) {
|
||||
|
@ -459,8 +615,15 @@ static void ggml_backend_mpi_free(ggml_backend_t backend) {
|
|||
|
||||
static ggml_backend_buffer_type_t ggml_backend_mpi_get_default_buffer_type(ggml_backend_t backend) {
|
||||
auto * ctx = static_cast<ggml_mpi_context *>(backend->context);
|
||||
if (ctx->backends.empty()) {
|
||||
auto * buff = ggml_backend_mpi_wrap_buffer_type(ggml_backend_cpu_buffer_type());
|
||||
ggml_backend_mpi_buffer_type_set_rank(buff, ctx->rank);
|
||||
return buff;
|
||||
}
|
||||
|
||||
return ggml_backend_mpi_wrap_buffer(ctx->backends.back()->iface.get_default_buffer_type(ctx->backends.back()));
|
||||
auto * buff = ggml_backend_mpi_wrap_buffer_type(ctx->backends.back()->iface.get_default_buffer_type(ctx->backends.back()));
|
||||
ggml_backend_mpi_buffer_type_set_rank(buff, ctx->rank);
|
||||
return buff;
|
||||
}
|
||||
|
||||
GGML_CALL static bool ggml_backend_mpi_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
|
||||
|
@ -506,17 +669,35 @@ GGML_CALL bool ggml_backend_is_mpi(ggml_backend_t backend) {
|
|||
}
|
||||
|
||||
|
||||
|
||||
|
||||
GGML_CALL static const char * ggml_backend_mpi_buffer_type_name(ggml_backend_buffer_type_t buft) {
|
||||
auto * ctx = (ggml_backend_mpi_buffer_type_context *) buft->context;
|
||||
|
||||
return strdup(((ctx->name + ":") + std::string(ctx->wrapped_buffer->iface.get_name(ctx->wrapped_buffer))).c_str());
|
||||
int rank;
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
||||
|
||||
return strdup(((ctx->name + " Buffer Type(Rank " + std::to_string(ctx->rank) + ", local rank " + std::to_string(rank) + "):") + std::string(ctx->wrapped_buffer->iface.get_name(ctx->wrapped_buffer))).c_str());
|
||||
}
|
||||
|
||||
GGML_CALL static ggml_backend_buffer_t ggml_backend_mpi_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
||||
// fprintf(stderr, "ALLOCATING NEW BUFFER FOR BUFFER_TYPE %s\n", ggml_backend_buft_name(buft));
|
||||
|
||||
auto * ctx = (ggml_backend_mpi_buffer_type_context *) buft->context;
|
||||
ggml_backend_buffer_t buf = ctx->wrapped_buffer->iface.alloc_buffer(ctx->wrapped_buffer, size);
|
||||
buf->buft = ggml_backend_mpi_wrap_buffer(buf->buft);
|
||||
return buf;
|
||||
|
||||
// fprintf(stderr, "WRAPPED BUFFER_TYPE %s\n", ggml_backend_buft_name(ctx->wrapped_buffer));
|
||||
|
||||
|
||||
auto* buffer = ggml_backend_mpi_wrap_buffer(ctx->wrapped_buffer->iface.alloc_buffer(ctx->wrapped_buffer, size));
|
||||
|
||||
// fprintf(stderr, "NEW BUFFER: %s, BUFFER_TYPE: %s\n", ggml_backend_buffer_name(buffer), ggml_backend_buft_name(buffer->buft));
|
||||
|
||||
|
||||
ggml_backend_mpi_buffer_set_rank(buffer, ctx->rank);
|
||||
|
||||
// fprintf(stderr, "NEW BUFFER AFTER SETTING RANK: %s, BUFFER_TYPE: %s\n", ggml_backend_buffer_name(buffer), ggml_backend_buft_name(buffer->buft));
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
GGML_CALL static size_t ggml_backend_mpi_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
||||
|
@ -530,31 +711,42 @@ GGML_CALL static size_t ggml_backend_mpi_buffer_type_get_max_size(ggml_backend_b
|
|||
}
|
||||
|
||||
GGML_CALL static size_t ggml_backend_mpi_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
|
||||
// fprintf(stderr, "GETTING ALLOC SIZE FOR TENSOR %s (%s) AND BUFFER TYPE %s\n", tensor->name, ggml_backend_buffer_name(tensor->buffer), ggml_backend_buft_name(buft));
|
||||
|
||||
|
||||
auto * ctx = (ggml_backend_mpi_buffer_type_context *) buft->context;
|
||||
return ctx->wrapped_buffer->iface.get_alloc_size(ctx->wrapped_buffer, tensor);
|
||||
}
|
||||
|
||||
GGML_CALL static bool ggml_backend_mpi_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
|
||||
return backend != nullptr && ggml_backend_is_mpi(backend);
|
||||
return backend != nullptr && ggml_backend_is_mpi(backend) && ((ggml_backend_mpi_buffer_type_context*) buft->context)->rank == ((ggml_mpi_context*)backend->context)->rank;
|
||||
}
|
||||
|
||||
GGML_CALL static bool ggml_backend_mpi_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
|
||||
auto * ctx = (ggml_backend_mpi_buffer_type_context *) buft->context;
|
||||
|
||||
return ctx->wrapped_buffer->iface.is_host(ctx->wrapped_buffer);
|
||||
int rank;
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
||||
|
||||
return ctx->rank == rank && ctx->wrapped_buffer->iface.is_host(ctx->wrapped_buffer);
|
||||
}
|
||||
|
||||
|
||||
static std::map<ggml_backend_buffer_type_t, ggml_backend_buffer_type_t> cached_wrappers;
|
||||
|
||||
static std::map<ggml_backend_buffer_t, ggml_backend_buffer_t> cached_buffer_wrappers;
|
||||
|
||||
static std::map<ggml_backend_t *, ggml_backend_t> cached_backends;
|
||||
|
||||
GGML_CALL ggml_backend_buffer_type_t ggml_backend_mpi_wrap_buffer_type(ggml_backend_buffer_type_t buft) {
|
||||
|
||||
GGML_CALL ggml_backend_buffer_type_t ggml_backend_mpi_wrap_buffer(ggml_backend_buffer_type_t buft) {
|
||||
|
||||
if (cached_wrappers.find(buft) != cached_wrappers.end()) {
|
||||
return cached_wrappers[buft];
|
||||
}
|
||||
// if (cached_wrappers.find(buft) != cached_wrappers.end()) {
|
||||
// fprintf(stderr, "Returning cached buffer type with name %s\n", cached_wrappers[buft]->iface.get_name(cached_wrappers[buft]));
|
||||
//
|
||||
// auto * ret = new ggml_backend_buffer_type;
|
||||
// *ret = *cached_wrappers[buft];
|
||||
// return ret;
|
||||
// }
|
||||
|
||||
ggml_backend_buffer_type_i ggml_backend_mpi_buffer_type_interface = {
|
||||
/* .get_name = */ ggml_backend_mpi_buffer_type_name,
|
||||
|
@ -566,9 +758,11 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_mpi_wrap_buffer(ggml_backend_b
|
|||
/* .is_host = */ (buft->iface.is_host != nullptr ) ? ggml_backend_mpi_buffer_type_is_host : nullptr,
|
||||
};
|
||||
|
||||
|
||||
|
||||
auto* ggml_backend_wrapped_buffer_type = new ggml_backend_buffer_type{
|
||||
/* .iface = */ ggml_backend_mpi_buffer_type_interface,
|
||||
/* .context = */ new ggml_backend_mpi_buffer_type_context{"MPI",buft},
|
||||
/* .context = */ new ggml_backend_mpi_buffer_type_context{"MPI",buft, 0},
|
||||
};
|
||||
|
||||
cached_wrappers[buft] = ggml_backend_wrapped_buffer_type;
|
||||
|
@ -576,26 +770,169 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_mpi_wrap_buffer(ggml_backend_b
|
|||
return ggml_backend_wrapped_buffer_type;
|
||||
}
|
||||
|
||||
ggml_backend_t ggml_backend_mpi_init(ggml_backend_t * wrapped_backends, size_t num_backends) {
|
||||
|
||||
if (cached_backends.find(wrapped_backends) != cached_backends.end()) {
|
||||
return cached_backends[wrapped_backends];
|
||||
|
||||
GGML_CALL static void * ggml_backend_mpi_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||
auto * ctx = (ggml_backend_mpi_buffer_context *) buffer->context;
|
||||
return ctx->wrapped_buffer->iface.get_base(ctx->wrapped_buffer);
|
||||
}
|
||||
|
||||
GGML_CALL static void ggml_backend_mpi_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||
auto * ctx = (ggml_backend_mpi_buffer_context *) buffer->context;
|
||||
return ctx->wrapped_buffer->iface.free_buffer(ctx->wrapped_buffer);
|
||||
}
|
||||
|
||||
GGML_CALL static void ggml_backend_mpi_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||
auto * ctx = (ggml_backend_mpi_buffer_context *) buffer->context;
|
||||
// fprintf(stderr, "SETTING TENSOR WITHOUT MPI CALLS FOR %s (%s) AND TGT BUFFER %s\n", tensor->name, ggml_backend_buffer_name(tensor->buffer), ggml_backend_buffer_name(buffer));
|
||||
return ctx->wrapped_buffer->iface.set_tensor(ctx->wrapped_buffer, tensor, data, offset, size);
|
||||
}
|
||||
|
||||
GGML_CALL static void ggml_backend_mpi_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
||||
auto * ctx = (ggml_backend_mpi_buffer_context *) buffer->context;
|
||||
auto * type_ctx = (ggml_backend_mpi_buffer_type_context *) buffer->buft->context;
|
||||
|
||||
|
||||
|
||||
int rank;
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
||||
|
||||
int src_rank = ((ggml_backend_mpi_buffer_type_context*)tensor->buffer->buft->context)->rank;
|
||||
|
||||
if (rank != src_rank) {
|
||||
|
||||
ggml_mpi_tensor_recv(tensor, data, ((ggml_backend_mpi_buffer_type_context*)tensor->buffer->buft->context)->rank, MPI_COMM_WORLD);
|
||||
return;
|
||||
}
|
||||
|
||||
// fprintf(stderr, "GETTING TENSOR WITH SRC RANK=RANK (%d) FOR TENSOR %s (%s) AND TGT BUFFER %s\n", src_rank, tensor->name, ggml_backend_buffer_name(tensor->buffer), ggml_backend_buffer_name(buffer));
|
||||
ctx->wrapped_buffer->iface.get_tensor(ctx->wrapped_buffer, tensor, data, offset, size);
|
||||
}
|
||||
|
||||
GGML_CALL static bool ggml_backend_mpi_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
|
||||
auto * ctx = (ggml_backend_mpi_buffer_context *) buffer->context;
|
||||
|
||||
// fprintf(stderr, "DOING LOCAL TENSOR COPY FOR SRC %s (%s) AND DST %s (%s) WITH TGT BUFFER %s\n", src->name, ggml_backend_buffer_name(src->buffer), dst->name, ggml_backend_buffer_name(dst->buffer), ggml_backend_buffer_name(buffer));
|
||||
|
||||
return ctx->wrapped_buffer->iface.cpy_tensor(ctx->wrapped_buffer, src, dst);
|
||||
}
|
||||
|
||||
GGML_CALL static void ggml_backend_mpi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
||||
auto * ctx = (ggml_backend_mpi_buffer_context *) buffer->context;
|
||||
return ctx->wrapped_buffer->iface.clear(ctx->wrapped_buffer, value);
|
||||
}
|
||||
|
||||
static struct ggml_backend_buffer_i mpi_backend_buffer_i = {
|
||||
/* .get_name = */ ggml_backend_mpi_buffer_name,
|
||||
/* .free_buffer = */ ggml_backend_mpi_buffer_free_buffer,
|
||||
/* .get_base = */ ggml_backend_mpi_buffer_get_base,
|
||||
/* .init_tensor = */ NULL, // no initialization required
|
||||
/* .set_tensor = */ ggml_backend_mpi_buffer_set_tensor,
|
||||
/* .get_tensor = */ ggml_backend_mpi_buffer_get_tensor,
|
||||
/* .cpy_tensor = */ ggml_backend_mpi_buffer_cpy_tensor,
|
||||
/* .clear = */ ggml_backend_mpi_buffer_clear,
|
||||
/* .reset = */ NULL,
|
||||
};
|
||||
|
||||
GGML_CALL ggml_backend_buffer_t ggml_backend_mpi_wrap_buffer(ggml_backend_buffer_t buf) {
|
||||
|
||||
// if (buf->iface.get_name == mpi_backend_buffer_i.get_name) {
|
||||
// return buf;
|
||||
// }
|
||||
|
||||
// if (cached_buffer_wrappers.find(buf) != cached_buffer_wrappers.end()) {
|
||||
// fprintf(stderr, "Returning cached buffer with name %s\n", cached_buffer_wrappers[buf]->iface.get_name(cached_buffer_wrappers[buf]));
|
||||
// auto * ret = new ggml_backend_buffer;
|
||||
// *ret = *cached_buffer_wrappers[buf];
|
||||
// auto * ret_type = new ggml_backend_buffer_type;
|
||||
// *ret_type = *ret->buft;
|
||||
// ret->buft = ret_type;
|
||||
// return ret;
|
||||
// }
|
||||
|
||||
|
||||
ggml_backend_buffer_type_t t = ggml_backend_mpi_wrap_buffer_type(buf->buft);
|
||||
|
||||
auto *buffer = new ggml_backend_buffer {
|
||||
/* .interface = */ mpi_backend_buffer_i,
|
||||
/* .buft = */ t,
|
||||
/* .context = */ new ggml_backend_mpi_buffer_context{buf, ((ggml_backend_mpi_buffer_type_context*)t->context)->rank},
|
||||
/* .size = */ buf->size,
|
||||
/* .usage = */ buf->usage
|
||||
};
|
||||
|
||||
cached_buffer_wrappers[buf] = buffer;
|
||||
|
||||
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
// Asynchronous tensor copy between MPI ranks. The owning ranks of src and dst
// are read from their MPI buffer-type contexts; the local rank decides whether
// to send, receive, or do nothing.
// NOTE(review): always reports success (true), even when this process is on
// neither side of the transfer — confirm callers expect that.
bool ggml_backend_mpi_cpy_tensor_async(ggml_backend_t backend, const struct ggml_tensor * src, struct ggml_tensor * dst) {
    int src_rank = ((ggml_backend_mpi_buffer_type_context*)src->buffer->buft->context)->rank;
    int dst_rank = ((ggml_backend_mpi_buffer_type_context*)dst->buffer->buft->context)->rank;

    auto * ctx = static_cast<ggml_mpi_context *>(backend->context);

    // A "remote" wrapper stands in for a backend living on another rank;
    // no local work is performed for it.
    if (ctx->remote) {
        return true;
    }

    // Both tensors owned by the same rank: the data never crosses a process
    // boundary, so nothing needs to be transmitted.
    if (src_rank == dst_rank) {
        return true;
    }

    // Point-to-point transfer: the source owner sends, the destination owner
    // receives; any other rank falls through without participating.
    // NOTE(review): this 3-argument ggml_mpi_tensor_send call presumes an
    // overload that sends t->data — verify against the 4-argument definition.
    if (src_rank == ctx->rank) {
        ggml_mpi_tensor_send(src, dst_rank, ctx->comm);
    } else if (dst_rank == ctx->rank){
        ggml_mpi_tensor_recv(dst, src_rank, ctx->comm);
    }
    return true;
}
|
||||
|
||||
// Asynchronous set-tensor for the MPI backend: transmit `data` for tensor
// `dst` over MPI.
// NOTE(review): `offset` and `size` are ignored — the send is driven by the
// tensor's own type/shape; confirm partial writes are never requested here.
// NOTE(review): the send targets ctx->rank, which the assert below pins to
// dst's owning rank, i.e. this process sends to itself — presumably a
// matching receive is posted elsewhere. Verify the protocol.
void ggml_backend_mpi_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * dst, const void* data, size_t offset, size_t size) {
    int dst_rank = ((ggml_backend_mpi_buffer_type_context*)dst->buffer->buft->context)->rank;

    auto * ctx = static_cast<ggml_mpi_context *>(backend->context);

    // Only the rank that owns dst may set its data.
    GGML_ASSERT(ctx->rank == dst_rank);

    ggml_mpi_tensor_send(dst, data, ctx->rank, ctx->comm);
}
|
||||
|
||||
|
||||
ggml_backend_t ggml_backend_mpi_init(ggml_backend_t * wrapped_backends, size_t num_backends, int rank) {
|
||||
|
||||
|
||||
ggml_mpi_context * ctx = ggml_mpi_init();
|
||||
std::vector<ggml_backend_t> wrapped_backends_v;
|
||||
if (ctx->rank == rank) {
|
||||
for (size_t i = 0; i < num_backends; i++) {
|
||||
wrapped_backends_v.push_back(wrapped_backends[i]);
|
||||
}
|
||||
} else {
|
||||
ctx->remote = true;
|
||||
}
|
||||
ctx->backends = wrapped_backends_v;
|
||||
|
||||
ctx->rank = rank;
|
||||
struct ggml_backend_i mpi_backend_i = {
|
||||
/* .get_name = */ ggml_backend_mpi_name,
|
||||
/* .free = */ ggml_backend_mpi_free,
|
||||
/* .get_default_buffer_type = */ ggml_backend_mpi_get_default_buffer_type,
|
||||
/* .set_tensor_async = */ NULL,
|
||||
/* .set_tensor_async = */ ggml_backend_mpi_set_tensor_async,
|
||||
/* .get_tensor_async = */ NULL,
|
||||
/* .cpy_tensor_async = */ NULL,
|
||||
/* .cpy_tensor_async = */ ggml_backend_mpi_cpy_tensor_async,
|
||||
/* .synchronize = */ NULL,
|
||||
/* .graph_plan_create = */ NULL,
|
||||
/* .graph_plan_free = */ NULL,
|
||||
|
@ -617,9 +954,10 @@ ggml_backend_t ggml_backend_mpi_init(ggml_backend_t * wrapped_backends, size_t n
|
|||
// Registry entry point: build an MPI backend that wraps a freshly created
// default CPU backend, pinned to rank 0.
// TODO check what the parameters are for. Could use it to setup the MPI comms and routes?
static ggml_backend_t ggml_backend_reg_mpi_init(const char * params, void * user_data) {
    GGML_UNUSED(params);
    GGML_UNUSED(user_data);

    ggml_backend_mpi_backend_init_guard:
    ggml_mpi_backend_init();

    // ggml_backend_mpi_init copies the backend handles it is given, so a
    // stack array suffices here — the previous heap-allocated std::vector
    // was never freed (memory leak).
    ggml_backend_t wrapped[1] = { ggml_backend_cpu_init() };
    return ggml_backend_mpi_init(wrapped, 1, 0);
}
|
||||
|
||||
|
||||
|
@ -633,7 +971,7 @@ int ggml_backend_mpi_reg_devices() {
|
|||
ggml_backend_register(
|
||||
device.name,
|
||||
ggml_backend_reg_mpi_init,
|
||||
ggml_backend_mpi_wrap_buffer(ggml_backend_cpu_buffer_type()),
|
||||
ggml_backend_mpi_wrap_buffer_type(ggml_backend_cpu_buffer_type()),
|
||||
reinterpret_cast<void *>(intptr_t(device.index))
|
||||
);
|
||||
}
|
||||
|
@ -642,4 +980,19 @@ int ggml_backend_mpi_reg_devices() {
|
|||
|
||||
|
||||
|
||||
// Record the owning MPI rank inside an MPI-wrapped buffer type's context.
// Aborts (via GGML_ASSERT) if the buffer type is not an MPI wrapper.
GGML_CALL void ggml_backend_mpi_buffer_type_set_rank(ggml_backend_buffer_type_t buft, int rank) {
    if (buft->iface.get_name != ggml_backend_mpi_buffer_type_name) {
        GGML_ASSERT(!"Buffer type must be wrapped in ggml_backend_mpi_buffer_type");
        return;
    }
    auto * type_ctx = (ggml_backend_mpi_buffer_type_context *) buft->context;
    type_ctx->rank = rank;
}
|
||||
|
||||
// Record the owning MPI rank on an MPI-wrapped buffer, and propagate it to
// the buffer's type so the two stay consistent.
// Aborts (via GGML_ASSERT) if the buffer is not an MPI wrapper.
GGML_CALL void ggml_backend_mpi_buffer_set_rank(ggml_backend_buffer_t buf, int rank) {
    if (buf->iface.get_name == ggml_backend_mpi_buffer_name) {
        ((ggml_backend_mpi_buffer_context *) buf->context)->rank = rank;
        // the buffer type carries its own rank field; keep it in sync
        ggml_backend_mpi_buffer_type_set_rank(buf->buft, rank);
    } else {
        // fixed copy-pasted message: this function guards a buffer,
        // not a buffer type
        GGML_ASSERT(!"Buffer must be wrapped in ggml_backend_mpi_buffer");
    }
}
|
||||
|
|
11
ggml-mpi.h
11
ggml-mpi.h
|
@ -53,7 +53,10 @@ struct ggml_mpi_context * ggml_mpi_init(void);
|
|||
|
||||
void ggml_mpi_graph_creation_post(struct ggml_mpi_context * ctx_mpi, struct ggml_cgraph * cgraph, int n_layers);
|
||||
|
||||
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_mpi_wrap_buffer(ggml_backend_buffer_type_t buft);
|
||||
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_mpi_wrap_buffer_type(ggml_backend_buffer_type_t buft);
|
||||
|
||||
GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_mpi_wrap_buffer(ggml_backend_buffer_t buf);
|
||||
|
||||
|
||||
/**
|
||||
* Create a new context by splitting the given context's
|
||||
|
@ -210,7 +213,11 @@ struct ggml_mpi_device {
|
|||
#define MPI_BACKEND_NAME "MPI"
|
||||
GGML_CALL int ggml_backend_mpi_reg_devices();
|
||||
|
||||
GGML_CALL ggml_backend_t ggml_backend_mpi_init(ggml_backend_t * wrapped_backends, size_t num_backends);
|
||||
GGML_CALL ggml_backend_t ggml_backend_mpi_init(ggml_backend_t * wrapped_backends, size_t num_backends, int rank);
|
||||
|
||||
GGML_CALL void ggml_backend_mpi_buffer_type_set_rank(ggml_backend_buffer_type_t buft, int rank);
|
||||
|
||||
GGML_CALL void ggml_backend_mpi_buffer_set_rank(ggml_backend_buffer_t buft, int rank);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue