Clean up MPI backend a tad

This commit is contained in:
Branden Butler 2024-03-12 18:17:42 -05:00
parent 72dcd66c0f
commit 5f156f3a0c
2 changed files with 209 additions and 368 deletions

View file

@ -141,7 +141,6 @@ void ggml_mpi_eval_init(
if(ctx_mpi->comm == MPI_COMM_NULL) { if(ctx_mpi->comm == MPI_COMM_NULL) {
return; return;
} }
int32_t old_n_tokens = *n_tokens;
ggml_mpi_sync_pipelined(ctx_mpi, n_tokens, 1, MPI_INT, GGML_MPI_N_TOKENS); ggml_mpi_sync_pipelined(ctx_mpi, n_tokens, 1, MPI_INT, GGML_MPI_N_TOKENS);
@ -182,11 +181,6 @@ void ggml_mpi_eval_init(
// For now, we assume that the pos, seq_ids, tokens, etc have been // For now, we assume that the pos, seq_ids, tokens, etc have been
// pre-allocated for the largest possible sizes, even on worker nodes. // pre-allocated for the largest possible sizes, even on worker nodes.
//if (old_n_tokens != *n_tokens) {
// *pos = realloc(*pos, *n_tokens * sizeof(int32_t));
// *n_seq_ids = realloc(*n_seq_ids, *n_tokens * sizeof(int32_t ));
// *tokens = realloc(*tokens, *n_tokens * sizeof(int32_t ));
//}
GGML_ASSERT(n_seq_ids != nullptr); GGML_ASSERT(n_seq_ids != nullptr);
GGML_ASSERT(n_tokens != nullptr); GGML_ASSERT(n_tokens != nullptr);
@ -235,30 +229,13 @@ void ggml_mpi_eval_init(
} }
void ggml_mpi_synch_int( void ggml_mpi_sync_int(
struct ggml_mpi_context * ctx_mpi, struct ggml_mpi_context * ctx_mpi,
int32_t * val int32_t * val
) { ) {
MPI_Bcast(val, 1, MPI_INT32_T, 0, ctx_mpi->comm); MPI_Bcast(val, 1, MPI_INT32_T, 0, ctx_mpi->comm);
} }
// Locate the index of the graph node that holds the tensor with the given
// name. Returns -1 (after logging to stderr) when the tensor does not exist
// or is not among the graph's nodes.
static int ggml_graph_get_node_idx(struct ggml_cgraph * gf, const char * name) {
    struct ggml_tensor * target = ggml_graph_get_tensor(gf, name);
    if (target == NULL) {
        fprintf(stderr, "%s: tensor %s not found\n", __func__, name);
        return -1;
    }

    int idx = -1;
    for (int node = 0; node < gf->n_nodes && idx < 0; node++) {
        if (gf->nodes[node] == target) {
            idx = node;
        }
    }

    if (idx < 0) {
        fprintf(stderr, "%s: tensor %s not found in graph (should not happen)\n", __func__, name);
    }
    return idx;
}
static void ggml_mpi_tensor_send(const struct ggml_tensor * t, const void* data, int mpi_rank_dst, MPI_Comm comm) { static void ggml_mpi_tensor_send(const struct ggml_tensor * t, const void* data, int mpi_rank_dst, MPI_Comm comm) {
MPI_Datatype mpi_type; MPI_Datatype mpi_type;
@ -340,154 +317,172 @@ uint16_t** ggml_mpi_split_range(
} }
// Scatter per-node [start, end] layer ranges from the root to all ranks.
//
// layer_ranges is a 2-D array: first dimension has one entry per node,
// second dimension is the pair {start layer ID, end layer ID}. Only rank 0's
// layer_ranges is significant for MPI_Scatter; worker ranks may pass NULL.
void ggml_mpi_scatter_layers(
    struct ggml_mpi_context * ctx_mpi,
    uint16_t ** layer_ranges
) {
    // Flattened send buffer: [start0, end0, start1, end1, ...].
    uint16_t flattened_ranges[ctx_mpi->size * 2];

    // Fix: zero-initialize so we never hand MPI_Scatter an uninitialized VLA
    // when layer_ranges is NULL (reading indeterminate storage is UB and the
    // root would scatter garbage ranges).
    for (int i = 0; i < ctx_mpi->size * 2; i++) {
        flattened_ranges[i] = 0;
    }

    if (layer_ranges != NULL) {
        for (int i = 0; i < ctx_mpi->size * 2; i += 2) {
            flattened_ranges[i]     = layer_ranges[i/2][0];
            flattened_ranges[i + 1] = layer_ranges[i/2][1];
        }
    }

    uint16_t received_range[2];
    MPI_Scatter(flattened_ranges, 2, MPI_UINT16_T, received_range, 2, MPI_UINT16_T, 0, ctx_mpi->comm);

    ctx_mpi->layer_start = received_range[0];
    ctx_mpi->layer_end   = received_range[1];
    fprintf(stderr, "Ranges for rank %d: [%d, %d]\n", ctx_mpi->rank, ctx_mpi->layer_start, ctx_mpi->layer_end);
}
// Post-build hook: record the first layer-input tensor on the MPI context and
// mark every node in the graph as MPI-split so the scheduler routes the whole
// graph through this backend. Distribution across ranks happens at compute
// time. n_layers is currently unused but kept for interface compatibility.
void ggml_mpi_graph_creation_post(struct ggml_mpi_context * ctx_mpi, struct ggml_cgraph * gf, int n_layers) {
    (void) n_layers; // unused for now

    struct ggml_tensor * inp_tokens = ggml_graph_get_tensor(gf, "inp_tokens");
    if (inp_tokens == NULL) {
        fprintf(stderr, "%s: tensor 'inp_tokens' not found\n", __func__);
        return;
    }

    struct ggml_tensor * inp0 = ggml_graph_get_tensor(gf, "layer_inp_0");
    if (inp0 == NULL) {
        // Fix: report the name actually looked up ('layer_inp_0'); the old
        // message claimed 'inp0' was missing, which is not a tensor name.
        fprintf(stderr, "%s: tensor 'layer_inp_0' not found\n", __func__);
        return;
    }

    ctx_mpi->inp0 = inp0;

    // Every node is computed via the MPI split backend.
    for (int i = 0; i < gf->n_nodes; i++) {
        gf->nodes[i]->backend = GGML_BACKEND_TYPE_MPI_SPLIT;
    }
}
// TODO: there are many improvements that can be done to this implementation
// Hook that runs before graph compute. The old point-to-point send/recv of
// layer inputs has been superseded by the backend-v2 buffer layer, so the
// only live work left here is validating the graph's first node.
void ggml_mpi_graph_compute_pre(
        struct ggml_mpi_context * ctx_mpi,
        struct ggml_cgraph * gf) {
    const int mpi_rank = ctx_mpi->rank;
    const int mpi_size = ctx_mpi->size;

    struct ggml_tensor * first_node = gf->nodes[0];
    if (first_node == NULL) {
        fprintf(stderr, "%s: tensor 'inp0' not found\n", __func__);
        return;
    }

    if (mpi_rank > 0) {
        // workers: layer inputs now arrive via the buffer get/set path
    } else if (mpi_size > 1) {
        // root: final outputs are likewise collected via the buffer layer
    }
}
// Hook that runs after graph compute. Forwarding of the last node's output
// to the next rank now happens in the buffer layer, so this is a no-op.
void ggml_mpi_graph_compute_post(
        struct ggml_mpi_context * ctx_mpi,
        struct ggml_cgraph * gf) {
    const int mpi_rank = ctx_mpi->rank;
    const int mpi_size = ctx_mpi->size;
    (void) mpi_size;

    if (mpi_rank > 0) {
        // intentionally empty — see note above
    }
}
// BACKEND V2 // BACKEND V2
struct ggml_backend_mpi_buffer_context { struct ggml_backend_mpi_buffer_context {
ggml_backend_buffer_t wrapped_buffer; ggml_backend_buffer_t wrapped_buffer;
int rank; ggml_mpi_context * ctx_mpi;
}; };
struct ggml_backend_mpi_buffer_type_context { struct ggml_backend_mpi_buffer_type_context {
std::string name; std::string name;
ggml_backend_buffer_type_t wrapped_buffer; ggml_backend_buffer_type_t wrapped_buffer_type;
int rank; ggml_mpi_context * ctx_mpi;
}; };
GGML_CALL static const char * ggml_backend_mpi_buffer_name(ggml_backend_buffer_t buffer) { int ggml_backend_mpi_buffer_type_rank(ggml_backend_buffer_type_t buft);
auto * ctx = (ggml_backend_mpi_buffer_context *) buffer->context;
int rank; int ggml_backend_mpi_buffer_type_local_rank(ggml_backend_buffer_type_t buft);
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
return strdup((("MPI Buffer(Rank " + std::to_string(ctx->rank) + ", local rank " + std::to_string(rank) + "):") + std::string(ctx->wrapped_buffer->iface.get_name(ctx->wrapped_buffer))).c_str()); GGML_CALL static const char * ggml_backend_mpi_buffer_type_name(ggml_backend_buffer_type_t buft) {
auto * ctx = (ggml_backend_mpi_buffer_type_context *) buft->context;
return strdup(
(
ctx->name +
" Buffer Type(Rank " +
std::to_string(
ggml_backend_mpi_buffer_type_rank(buft)
) +
", local rank " +
std::to_string(ggml_backend_mpi_buffer_type_local_rank(buft)) +
"):" +
std::string(
ctx->wrapped_buffer_type->iface.get_name(ctx->wrapped_buffer_type)
)
).c_str()
);
} }
// Communicator associated with a wrapped MPI buffer type.
MPI_Comm ggml_backend_mpi_buffer_type_get_comm(ggml_backend_buffer_type_t buft) {
    return ((ggml_backend_mpi_buffer_type_context *) buft->context)->ctx_mpi->comm;
}
// Communicator of a wrapped buffer, taken from its buffer type.
MPI_Comm ggml_backend_mpi_buffer_get_comm(ggml_backend_buffer_t buffer) {
    ggml_backend_buffer_type_t buft = buffer->buft;
    return ggml_backend_mpi_buffer_type_get_comm(buft);
}
// Communicator stored on an MPI backend's context.
MPI_Comm ggml_backend_mpi_get_comm(ggml_backend_t backend) {
    return ((ggml_mpi_context *) backend->context)->comm;
}
// This process's rank within the buffer's communicator.
int ggml_backend_mpi_buffer_local_rank(ggml_backend_buffer_t buffer) {
    int local_rank = 0;
    const int status = MPI_Comm_rank(ggml_backend_mpi_buffer_get_comm(buffer), &local_rank);
    GGML_ASSERT(status == MPI_SUCCESS);
    return local_rank;
}
// This process's rank within the buffer type's communicator.
int ggml_backend_mpi_buffer_type_local_rank(ggml_backend_buffer_type_t buft) {
    int local_rank = 0;
    const int status = MPI_Comm_rank(ggml_backend_mpi_buffer_type_get_comm(buft), &local_rank);
    GGML_ASSERT(status == MPI_SUCCESS);
    return local_rank;
}
// This process's rank within the backend's communicator.
int ggml_backend_mpi_local_rank(ggml_backend_t backend) {
    int local_rank = 0;
    const int status = MPI_Comm_rank(ggml_backend_mpi_get_comm(backend), &local_rank);
    GGML_ASSERT(status == MPI_SUCCESS);
    return local_rank;
}
// Target rank recorded on a wrapped buffer (may differ from the local rank).
int ggml_backend_mpi_buffer_rank(ggml_backend_buffer_t buffer) {
    auto * buf_ctx = (ggml_backend_mpi_buffer_context *) buffer->context;
    return buf_ctx->ctx_mpi->rank;
}
// Target rank recorded on a wrapped buffer type. Asserts that buft really is
// an MPI wrapper (checked via the get_name vtable entry) and that its MPI
// context has been populated.
int ggml_backend_mpi_buffer_type_rank(ggml_backend_buffer_type_t buft) {
    GGML_ASSERT(buft->iface.get_name == ggml_backend_mpi_buffer_type_name);
    auto * buft_ctx = (ggml_backend_mpi_buffer_type_context *) buft->context;
    GGML_ASSERT(buft_ctx != nullptr);
    GGML_ASSERT(buft_ctx->ctx_mpi != nullptr);
    return buft_ctx->ctx_mpi->rank;
}
// Target rank recorded on an MPI backend's context.
int ggml_backend_mpi_rank(ggml_backend_t backend) {
    return ((ggml_mpi_context *) backend->context)->rank;
}
// Return the buffer wrapped inside an MPI buffer, first propagating the
// wrapper's usage and size onto it so the inner buffer reflects the
// wrapper's current state.
ggml_backend_buffer_t ggml_backend_mpi_buffer_unwrap(ggml_backend_buffer_t buffer) {
    auto * buf_ctx = (ggml_backend_mpi_buffer_context *) buffer->context;
    ggml_backend_buffer_t inner = buf_ctx->wrapped_buffer;
    inner->usage = buffer->usage;
    inner->size  = buffer->size;
    return inner;
}
// Return the buffer type wrapped inside an MPI buffer type.
ggml_backend_buffer_type_t ggml_backend_mpi_buffer_type_unwrap(ggml_backend_buffer_type_t buft) {
    return ((ggml_backend_mpi_buffer_type_context *) buft->context)->wrapped_buffer_type;
}
// Human-readable buffer name, e.g. "MPI Buffer(Rank 1, local rank 0):CPU".
// NOTE(review): the strdup'd string is never freed by callers, so this leaks
// a small allocation per call — confirm against the other backends' get_name
// conventions before changing.
GGML_CALL static const char * ggml_backend_mpi_buffer_name(ggml_backend_buffer_t buffer) {
    const std::string name =
        "MPI Buffer(Rank " +
        std::to_string(ggml_backend_mpi_buffer_rank(buffer)) +
        ", local rank " +
        std::to_string(ggml_backend_mpi_buffer_local_rank(buffer)) +
        "):" +
        std::string(ggml_backend_buffer_name(ggml_backend_mpi_buffer_unwrap(buffer)));
    return strdup(name.c_str());
}
GGML_CALL static const char * ggml_backend_mpi_buffer_type_name(ggml_backend_buffer_type_t buft); GGML_CALL static const char * ggml_backend_mpi_buffer_type_name(ggml_backend_buffer_type_t buft);
// Copy the MPI context (rank/comm state) from one wrapped buffer type to
// another. src must be an MPI-wrapped buffer type; otherwise this asserts.
GGML_CALL void ggml_backend_mpi_buffer_type_copy_ctx(ggml_backend_buffer_type_t src, ggml_backend_buffer_type_t dst) {
    if (src->iface.get_name == ggml_backend_mpi_buffer_type_name) {
        auto * src_ctx = (ggml_backend_mpi_buffer_type_context *) src->context;
        auto * dst_ctx = (ggml_backend_mpi_buffer_type_context *) dst->context;
        *dst_ctx->ctx_mpi = *src_ctx->ctx_mpi;
    } else {
        GGML_ASSERT(!"Buffer type must be wrapped in ggml_backend_mpi_buffer_type_t");
    }
}
// Copy the MPI context of a wrapped buffer — and of its buffer type — from
// src to dst. src must be an MPI-wrapped buffer; otherwise this asserts.
GGML_CALL void ggml_backend_mpi_buffer_copy_ctx(ggml_backend_buffer_t src, ggml_backend_buffer_t dst) {
    if (src->iface.get_name == ggml_backend_mpi_buffer_name) {
        auto * src_ctx = (ggml_backend_mpi_buffer_context *) src->context;
        auto * dst_ctx = (ggml_backend_mpi_buffer_context *) dst->context;
        *dst_ctx->ctx_mpi = *src_ctx->ctx_mpi;
        ggml_backend_mpi_buffer_type_copy_ctx(src->buft, dst->buft);
    } else {
        GGML_ASSERT(!"Buffer must be wrapped in ggml_backend_mpi_buffer_t");
    }
}
// Copy the MPI context from a wrapped buffer *type* (src) into a wrapped
// buffer (dst) and into dst's buffer type. src must be an MPI-wrapped buffer
// type; otherwise this asserts.
GGML_CALL void ggml_backend_mpi_buffer_copy_ctx_from_type(ggml_backend_buffer_type_t src, ggml_backend_buffer_t dst) {
    if (src->iface.get_name == ggml_backend_mpi_buffer_type_name) {
        *((ggml_backend_mpi_buffer_context *) dst->context)->ctx_mpi = *((ggml_backend_mpi_buffer_type_context *) src->context)->ctx_mpi;
        ggml_backend_mpi_buffer_type_copy_ctx(src, dst->buft);
    } else {
        // Fix: src is a buffer *type* here, so the diagnostic must say so
        // (the old message wrongly referred to a buffer).
        GGML_ASSERT(!"Buffer type must be wrapped in ggml_backend_mpi_buffer_type_t");
    }
}
GGML_CALL static enum ggml_status ggml_backend_mpi_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { GGML_CALL static enum ggml_status ggml_backend_mpi_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
struct ggml_mpi_context * ctx = (ggml_mpi_context *) backend->context; struct ggml_mpi_context * ctx = (ggml_mpi_context *) backend->context;
// if (ctx->remote) {
// return true;
// }
// ggml_mpi_graph_compute_pre(ctx, cgraph);
std::vector<ggml_backend_buffer_type_t> backend_buft; std::vector<ggml_backend_buffer_type_t> backend_buft;
for (auto *curr_backend: ctx->backends) { for (auto *curr_backend: ctx->backends) {
if (ggml_backend_is_cpu(curr_backend)) { if (ggml_backend_is_cpu(curr_backend)) {
@ -498,195 +493,115 @@ GGML_CALL static enum ggml_status ggml_backend_mpi_graph_compute(ggml_backend_t
} }
} }
// ggml_backend_t wrapped_backend = ctx->wrapped_backend;
// bool ret = ggml_backend_graph_compute(wrapped_backend, cgraph);
// printf("Running MPI backend\n");
std::vector<std::pair<ggml_backend_buffer_t, std::vector<ggml_backend_buffer_type_t>>> old_buffs( std::vector<std::pair<ggml_backend_buffer_t, std::vector<ggml_backend_buffer_t>>> old_buffs(
cgraph->n_nodes); cgraph->n_nodes);
std::vector<ggml_backend_buffer_type_t> old_view_buffs(cgraph->n_nodes); std::vector<ggml_backend_buffer_t> old_view_buffs(cgraph->n_nodes);
for (int i = 0; i < cgraph->n_nodes; i++) { for (int i = 0; i < cgraph->n_nodes; i++) {
old_buffs[i].first = cgraph->nodes[i]->buffer; old_buffs[i].first = cgraph->nodes[i]->buffer;
// if (cgraph->nodes[i]->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) {
// cgraph->nodes[i]->buffer = ((ggml_backend_mpi_buffer_context *) cgraph->nodes[i]->buffer->context)->wrapped_buffer;
//// printf("Unwrapped buffer: %s\n", cgraph->nodes[i]->buffer->buft->iface.get_name(cgraph->nodes[i]->buffer->buft));
// }
for (auto &src: cgraph->nodes[i]->src) { for (auto &src: cgraph->nodes[i]->src) {
if (src == nullptr) { if (src == nullptr) {
break; break;
} }
old_buffs[i].second.push_back(src->buffer->buft); old_buffs[i].second.push_back(src->buffer);
// if (src->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) {
// src->buffer = ((ggml_backend_mpi_buffer_context *) src->buffer->context)->wrapped_buffer;
//// printf("Unwrapped buffer src: %s\n", src->buffer->buft->iface.get_name(src->buffer->buft));
// }
} }
auto *src = cgraph->nodes[i]->view_src; auto *src = cgraph->nodes[i]->view_src;
if (src != nullptr) { if (src != nullptr) {
// fprintf(stderr, "View src is not null, src=%s, src buffer=%s\n", src->name, ggml_backend_buffer_name(src->buffer));
if (src->buffer->buft != nullptr) { if (src->buffer->buft != nullptr) {
// fprintf(stderr, "View src buffer type is not null, buft=%s\n", ggml_backend_buft_name(src->buffer->buft)); old_view_buffs[i] = src->buffer;
old_view_buffs[i] = src->buffer->buft;
// if (src->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) {
// src->buffer = ((ggml_backend_mpi_buffer_context *) src->buffer->context)->wrapped_buffer;
// printf("Unwrapped view buffer src: %s\n", src->buffer->buft->iface.get_name(src->buffer->buft));
// }
} else {
// old_view_buffs[i] = ggml_backend_mpi_wrap_buffer_type(ggml_backend_cpu_buffer_type());
// ggml_backend_mpi_buffer_type_set_rank(old_view_buffs[i], ((ggml_backend_mpi_buffer_context*)src->buffer->context)->rank);
} }
} else {
// fprintf(stderr, "OLD VIEW BUFF IS NULL\n");
} }
} }
for (int i = 0; i < cgraph->n_nodes; i++) { for (int i = 0; i < cgraph->n_nodes; i++) {
if (cgraph->nodes[i]->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) { if (cgraph->nodes[i]->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) {
auto usage = cgraph->nodes[i]->buffer->usage; cgraph->nodes[i]->buffer = ggml_backend_mpi_buffer_unwrap(cgraph->nodes[i]->buffer);
cgraph->nodes[i]->buffer = ((ggml_backend_mpi_buffer_context *) cgraph->nodes[i]->buffer->context)->wrapped_buffer;
cgraph->nodes[i]->buffer->usage = usage;
// printf("Unwrapped buffer: %s\n", cgraph->nodes[i]->buffer->buft->iface.get_name(cgraph->nodes[i]->buffer->buft));
} }
for (auto &src: cgraph->nodes[i]->src) { for (auto &src: cgraph->nodes[i]->src) {
if (src == nullptr) { if (src == nullptr) {
break; break;
} }
// old_buffs[i].second.push_back(src->buffer->buft);
if (src->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) { if (src->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) {
auto usage = src->buffer->usage; src->buffer = ggml_backend_mpi_buffer_unwrap(src->buffer);
src->buffer = ((ggml_backend_mpi_buffer_context *) src->buffer->context)->wrapped_buffer;
src->buffer->usage = usage;
// printf("Unwrapped buffer src: %s\n", src->buffer->buft->iface.get_name(src->buffer->buft));
} }
} }
auto *src = cgraph->nodes[i]->view_src; auto *src = cgraph->nodes[i]->view_src;
if (src != nullptr) { if (src != nullptr) {
// fprintf(stderr, "View src is not null, src=%s, src buffer=%s\n", src->name, ggml_backend_buffer_name(src->buffer));
if (src->buffer->buft != nullptr) { if (src->buffer->buft != nullptr) {
// fprintf(stderr, "View src buffer type is not null, buft=%s\n", ggml_backend_buft_name(src->buffer->buft));
// old_view_buffs[i] = src->buffer->buft;
if (src->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) {
auto usage = src->buffer->usage;
src->buffer = ((ggml_backend_mpi_buffer_context *) src->buffer->context)->wrapped_buffer; if (src->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) {
src->buffer->usage = usage; src->buffer = ggml_backend_mpi_buffer_unwrap(src->buffer);
// printf("Unwrapped view buffer src: %s\n", src->buffer->buft->iface.get_name(src->buffer->buft));
} }
} else {
// old_view_buffs[i] = ggml_backend_mpi_wrap_buffer_type(ggml_backend_cpu_buffer_type());
// ggml_backend_mpi_buffer_type_set_rank(old_view_buffs[i], ((ggml_backend_mpi_buffer_context*)src->buffer->context)->rank);
} }
} else {
// fprintf(stderr, "OLD VIEW BUFF IS NULL\n");
} }
} }
// fprintf(stderr, "Original n_leafs: %d\n", cgraph->n_leafs);
std::vector<ggml_backend_buffer_type_t> old_buffs_leaves; std::vector<ggml_backend_buffer_type_t> old_buffs_leaves;
for (int i = 0; i < cgraph->n_leafs; i++) { for (int i = 0; i < cgraph->n_leafs; i++) {
// fprintf(stderr, "Pushing leaf %s\n", cgraph->leafs[i]->name);
old_buffs_leaves.push_back(cgraph->leafs[i]->buffer->buft); old_buffs_leaves.push_back(cgraph->leafs[i]->buffer->buft);
if (cgraph->leafs[i]->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) { if (cgraph->leafs[i]->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) {
cgraph->leafs[i]->buffer->buft = ((ggml_backend_mpi_buffer_type_context *) cgraph->leafs[i]->buffer->buft->context)->wrapped_buffer; cgraph->leafs[i]->buffer = ggml_backend_mpi_buffer_unwrap(cgraph->leafs[i]->buffer);
// printf("Unwrapped buffer: %s\n", cgraph->leafs[i]->buffer->buft->iface.get_name(cgraph->leafs[i]->buffer->buft));
} }
} }
if (!ctx->remote) { if (!ctx->remote) {
ggml_backend_sched_t sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ggml_backend_sched_t sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(),
ctx->backends.size(), cgraph->n_nodes); (int) ctx->backends.size(), cgraph->n_nodes);
// printf("Created new scheduler\n");
ggml_backend_sched_reserve(sched, cgraph); ggml_backend_sched_reserve(sched, cgraph);
// printf("Beginning sched graph compute\n");
ggml_backend_sched_graph_compute(sched, cgraph); ggml_backend_sched_graph_compute(sched, cgraph);
ggml_backend_sched_free(sched); ggml_backend_sched_free(sched);
} }
for (int i = 0; i < cgraph->n_nodes; i++) { for (int i = 0; i < cgraph->n_nodes; i++) {
// fprintf(stderr, "Wrapping buffer %s for node %s\n", cgraph->nodes[i]->name, ggml_backend_buffer_name(cgraph->nodes[i]->buffer));
cgraph->nodes[i]->buffer = ggml_backend_mpi_wrap_buffer(cgraph->nodes[i]->buffer); cgraph->nodes[i]->buffer = ggml_backend_mpi_wrap_buffer(cgraph->nodes[i]->buffer);
// fprintf(stderr, "Setting buffer ranks for node %s with old buff %s\n", cgraph->nodes[i]->name, ggml_backend_buffer_name(old_buffs[i].first)); ggml_backend_mpi_buffer_set_rank(cgraph->nodes[i]->buffer, ggml_backend_mpi_buffer_rank(old_buffs[i].first));
ggml_backend_mpi_buffer_set_rank(cgraph->nodes[i]->buffer, ((ggml_backend_mpi_buffer_context*)old_buffs[i].first->context)->rank);
// fprintf(stderr, "New buffer rank for node %s: %d\n", cgraph->nodes[i]->name, ctx->rank);
for (int iter = 0; iter < GGML_MAX_SRC; iter++) { for (int iter = 0; iter < GGML_MAX_SRC; iter++) {
auto* j = cgraph->nodes[i]->src[iter]; auto* src_node = cgraph->nodes[i]->src[iter];
if (j == nullptr) { if (src_node == nullptr) {
break; break;
} }
if (j->buffer->iface.get_name == ggml_backend_mpi_buffer_name) { if (src_node->buffer->iface.get_name == ggml_backend_mpi_buffer_name) {
// fprintf(stderr, "Skipping buffer ranks for src node %s with buffer type %s\n", j->name, j->buffer->buft->iface.get_name(j->buffer->buft));
continue; continue;
} }
// fprintf(stderr, "Setting buffer ranks for src node %s\n", j->name); src_node->buffer = ggml_backend_mpi_wrap_buffer(src_node->buffer);
ggml_backend_mpi_buffer_set_rank(src_node->buffer, ggml_backend_mpi_buffer_rank(old_buffs[i].second[iter]));
// fprintf(stderr, "Source buffer name: %s, buffer type name: %s\n",
// j->buffer->iface.get_name(j->buffer),
// j->buffer->buft->iface.get_name(j->buffer->buft));
j->buffer = ggml_backend_mpi_wrap_buffer(j->buffer);
ggml_backend_mpi_buffer_set_rank(j->buffer, ((ggml_backend_mpi_buffer_type_context*)old_buffs[i].second[iter]->context)->rank);
// j->buffer->buft = old_buffs[i].second[iter];
// fprintf(stderr, "New source buffer name: %s, buffer type name: %s\n",
// j->buffer->iface.get_name(j->buffer),
// j->buffer->buft->iface.get_name(j->buffer->buft));
// ggml_backend_mpi_buffer_type_set_rank(j->buffer->buft, ctx->rank);
} }
if(cgraph->nodes[i]->view_src != nullptr && cgraph->nodes[i]->view_src->buffer->buft != nullptr) { if(cgraph->nodes[i]->view_src != nullptr && cgraph->nodes[i]->view_src->buffer->buft != nullptr) {
// fprintf(stderr, "View source %s (buffer name: %s), buffer type name: %s\n", cgraph->nodes[i]->view_src->name,
// cgraph->nodes[i]->view_src->buffer->iface.get_name(cgraph->nodes[i]->view_src->buffer),
// cgraph->nodes[i]->view_src->buffer->buft->iface.get_name(cgraph->nodes[i]->view_src->buffer->buft));
// fprintf(stderr, "Old view source buffer type name: %s\n",
// old_view_buffs[i]->iface.get_name(old_view_buffs[i]));
// cgraph->nodes[i]->view_src->buffer = ggml_backend_mpi_wrap_buffer(cgraph->nodes[i]->view_src->buffer);
// WRONG, need to keep the source ranks from before compute
// fprintf(stderr, "View buff %s null\n", (old_view_buffs[i]->context == nullptr) ? " is " : "is NOT ");
if (old_view_buffs[i] != nullptr) { if (old_view_buffs[i] != nullptr) {
if (old_view_buffs[i]->iface.get_name == ggml_backend_mpi_buffer_type_name) { if (old_view_buffs[i]->iface.get_name == ggml_backend_mpi_buffer_name) {
ggml_backend_mpi_buffer_set_rank(cgraph->nodes[i]->view_src->buffer, ggml_backend_mpi_buffer_set_rank(cgraph->nodes[i]->view_src->buffer,
((ggml_backend_mpi_buffer_type_context *) old_view_buffs[i]->context)->rank); ggml_backend_mpi_buffer_rank(old_view_buffs[i]));
} else {
// ggml_backend_mpi_buffer_set_rank(cgraph->nodes[i]->view_src->buffer,
// ctx->rank);
} }
} }
} }
} }
// FIXME check if this is correct or not (it's probably not)
for (int i = 0; i < cgraph->n_leafs; i++) { for (int i = 0; i < cgraph->n_leafs; i++) {
// fprintf(stderr, "Wrapping leaf %s...\n", cgraph->leafs[i]->name);
cgraph->leafs[i]->buffer = ggml_backend_mpi_wrap_buffer(cgraph->leafs[i]->buffer); cgraph->leafs[i]->buffer = ggml_backend_mpi_wrap_buffer(cgraph->leafs[i]->buffer);
// fprintf(stderr, "Wrapped leaf buffer: %s\n", ggml_backend_buffer_name(cgraph->leafs[i]->buffer));
ggml_backend_mpi_buffer_type_set_rank(cgraph->leafs[i]->buffer->buft, ctx->rank); ggml_backend_mpi_buffer_type_set_rank(cgraph->leafs[i]->buffer->buft, ctx->rank);
// fprintf(stderr, "Wrapped leaf after setting rank: %s\n", ggml_backend_buffer_name(cgraph->leafs[i]->buffer));
} }
// ggml_mpi_graph_compute_post(ctx, cgraph);
return GGML_STATUS_SUCCESS; return GGML_STATUS_SUCCESS;
} }
@ -695,7 +610,7 @@ static const char * ggml_backend_mpi_name(ggml_backend_t backend) {
int rank; int rank;
MPI_Comm_rank(MPI_COMM_WORLD, &rank); MPI_Comm_rank(MPI_COMM_WORLD, &rank);
return strdup(("MPI(Rank " + std::to_string(((ggml_mpi_context*)backend->context)->rank) + ", local rank " + std::to_string(rank) + ")").c_str()); return strdup(("MPI(Rank " + std::to_string(ggml_backend_mpi_rank(backend)) + ", local rank " + std::to_string(ggml_backend_mpi_local_rank(backend)) + ")").c_str());
} }
static void ggml_backend_mpi_free(ggml_backend_t backend) { static void ggml_backend_mpi_free(ggml_backend_t backend) {
@ -765,64 +680,39 @@ GGML_CALL bool ggml_backend_is_mpi(ggml_backend_t backend) {
// Human-readable name for a wrapped buffer type, e.g.
// "MPI Buffer Type(Rank 1, local rank 0):CPU".
// NOTE(review): the strdup'd string is never freed by callers — this leaks a
// small allocation per call; confirm the get_name ownership convention.
GGML_CALL static const char * ggml_backend_mpi_buffer_type_name(ggml_backend_buffer_type_t buft) {
auto * ctx = (ggml_backend_mpi_buffer_type_context *) buft->context;
// "local rank": this process's rank in MPI_COMM_WORLD, as opposed to the
// target rank stored on the wrapper context.
int rank;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
return strdup(((ctx->name + " Buffer Type(Rank " + std::to_string(ctx->rank) + ", local rank " + std::to_string(rank) + "):") + std::string(ctx->wrapped_buffer->iface.get_name(ctx->wrapped_buffer))).c_str());
}
GGML_CALL static ggml_backend_buffer_t ggml_backend_mpi_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { GGML_CALL static ggml_backend_buffer_t ggml_backend_mpi_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
// fprintf(stderr, "ALLOCATING NEW BUFFER FOR BUFFER_TYPE %s\n", ggml_backend_buft_name(buft));
auto * ctx = (ggml_backend_mpi_buffer_type_context *) buft->context; auto* buffer = ggml_backend_mpi_wrap_buffer(
ggml_backend_buft_alloc_buffer(ggml_backend_mpi_buffer_type_unwrap(buft), size)
);
// fprintf(stderr, "WRAPPED BUFFER_TYPE %s\n", ggml_backend_buft_name(ctx->wrapped_buffer)); ggml_backend_mpi_buffer_copy_ctx_from_type(buft, buffer);
auto* buffer = ggml_backend_mpi_wrap_buffer(ctx->wrapped_buffer->iface.alloc_buffer(ctx->wrapped_buffer, size));
// fprintf(stderr, "NEW BUFFER: %s, BUFFER_TYPE: %s\n", ggml_backend_buffer_name(buffer), ggml_backend_buft_name(buffer->buft));
ggml_backend_mpi_buffer_set_rank(buffer, ctx->rank);
// fprintf(stderr, "NEW BUFFER AFTER SETTING RANK: %s, BUFFER_TYPE: %s\n", ggml_backend_buffer_name(buffer), ggml_backend_buft_name(buffer->buft));
return buffer; return buffer;
} }
GGML_CALL static size_t ggml_backend_mpi_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { GGML_CALL static size_t ggml_backend_mpi_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
auto * ctx = (ggml_backend_mpi_buffer_type_context *) buft->context; return ggml_backend_buft_get_alignment(ggml_backend_mpi_buffer_type_unwrap(buft));
return ctx->wrapped_buffer->iface.get_alignment(ctx->wrapped_buffer);
} }
GGML_CALL static size_t ggml_backend_mpi_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { GGML_CALL static size_t ggml_backend_mpi_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
auto * ctx = (ggml_backend_mpi_buffer_type_context *) buft->context; return ggml_backend_buft_get_max_size(ggml_backend_mpi_buffer_type_unwrap(buft));
return ctx->wrapped_buffer->iface.get_max_size(ctx->wrapped_buffer);
} }
GGML_CALL static size_t ggml_backend_mpi_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) { GGML_CALL static size_t ggml_backend_mpi_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
// fprintf(stderr, "GETTING ALLOC SIZE FOR TENSOR %s (%s) AND BUFFER TYPE %s\n", tensor->name, ggml_backend_buffer_name(tensor->buffer), ggml_backend_buft_name(buft)); // Have to do this instead of calling ggml_backend_type_get_alloc_size because that signature doesn't have const on tensor
return ggml_backend_mpi_buffer_type_unwrap(buft)->iface.get_alloc_size(ggml_backend_mpi_buffer_type_unwrap(buft), tensor);
auto * ctx = (ggml_backend_mpi_buffer_type_context *) buft->context;
return ctx->wrapped_buffer->iface.get_alloc_size(ctx->wrapped_buffer, tensor);
} }
GGML_CALL static bool ggml_backend_mpi_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) { GGML_CALL static bool ggml_backend_mpi_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
return backend != nullptr && ggml_backend_is_mpi(backend) && ((ggml_backend_mpi_buffer_type_context*) buft->context)->rank == ((ggml_mpi_context*)backend->context)->rank; return backend != nullptr && ggml_backend_is_mpi(backend) && ggml_backend_mpi_buffer_type_rank(buft) == ggml_backend_mpi_rank(backend);
} }
GGML_CALL static bool ggml_backend_mpi_buffer_type_is_host(ggml_backend_buffer_type_t buft) { GGML_CALL static bool ggml_backend_mpi_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
auto * ctx = (ggml_backend_mpi_buffer_type_context *) buft->context;
int rank; return ggml_backend_mpi_buffer_type_rank(buft) == ggml_backend_mpi_buffer_type_local_rank(buft) && ggml_backend_buft_is_host(ggml_backend_mpi_buffer_type_unwrap(buft));
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
return ctx->rank == rank && ctx->wrapped_buffer->iface.is_host(ctx->wrapped_buffer);
} }
@ -854,11 +744,18 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_mpi_wrap_buffer_type(ggml_back
auto* ggml_backend_wrapped_buffer_type = new ggml_backend_buffer_type{ auto* ggml_backend_wrapped_buffer_type = new ggml_backend_buffer_type {
/* .iface = */ ggml_backend_mpi_buffer_type_interface, /* .iface = */ ggml_backend_mpi_buffer_type_interface,
/* .context = */ new ggml_backend_mpi_buffer_type_context{"MPI",buft, 0}, /* .context = */ new ggml_backend_mpi_buffer_type_context{
/* .name = */ "MPI",
/* .wrapped_buffer_type = */ buft,
/* .ctx_mpi = */ ggml_mpi_init()
}
}; };
// Set rank to 0 as default
ggml_backend_mpi_buffer_type_set_rank(ggml_backend_wrapped_buffer_type, 0);
cached_wrappers[buft] = ggml_backend_wrapped_buffer_type; cached_wrappers[buft] = ggml_backend_wrapped_buffer_type;
return ggml_backend_wrapped_buffer_type; return ggml_backend_wrapped_buffer_type;
@ -883,37 +780,26 @@ GGML_CALL static void ggml_backend_mpi_buffer_set_tensor(ggml_backend_buffer_t b
} }
GGML_CALL static void ggml_backend_mpi_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { GGML_CALL static void ggml_backend_mpi_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
auto * ctx = (ggml_backend_mpi_buffer_context *) buffer->context; int rank = ggml_backend_mpi_buffer_local_rank(tensor->buffer);
auto * type_ctx = (ggml_backend_mpi_buffer_type_context *) buffer->buft->context;
int src_rank = ggml_backend_mpi_buffer_rank(tensor->buffer);
int rank;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
int src_rank = ((ggml_backend_mpi_buffer_type_context*)tensor->buffer->buft->context)->rank;
if (rank != src_rank) { if (rank != src_rank) {
ggml_mpi_tensor_recv(tensor, data, ((ggml_backend_mpi_buffer_type_context*)tensor->buffer->buft->context)->rank, MPI_COMM_WORLD); ggml_mpi_tensor_recv(tensor, data, ggml_backend_mpi_buffer_rank(tensor->buffer), ggml_backend_mpi_buffer_get_comm(tensor->buffer));
return; return;
} }
// fprintf(stderr, "GETTING TENSOR WITH SRC RANK=RANK (%d) FOR TENSOR %s (%s) AND TGT BUFFER %s\n", src_rank, tensor->name, ggml_backend_buffer_name(tensor->buffer), ggml_backend_buffer_name(buffer)); ggml_backend_mpi_buffer_unwrap(buffer)->iface.get_tensor(ggml_backend_mpi_buffer_unwrap(buffer), tensor, data, offset, size);
ctx->wrapped_buffer->iface.get_tensor(ctx->wrapped_buffer, tensor, data, offset, size);
} }
GGML_CALL static bool ggml_backend_mpi_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) { GGML_CALL static bool ggml_backend_mpi_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
auto * ctx = (ggml_backend_mpi_buffer_context *) buffer->context;
// fprintf(stderr, "DOING LOCAL TENSOR COPY FOR SRC %s (%s) AND DST %s (%s) WITH TGT BUFFER %s\n", src->name, ggml_backend_buffer_name(src->buffer), dst->name, ggml_backend_buffer_name(dst->buffer), ggml_backend_buffer_name(buffer)); return ggml_backend_mpi_buffer_unwrap(buffer)->iface.cpy_tensor(ggml_backend_mpi_buffer_unwrap(buffer), src, dst);
return ctx->wrapped_buffer->iface.cpy_tensor(ctx->wrapped_buffer, src, dst);
} }
GGML_CALL static void ggml_backend_mpi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { GGML_CALL static void ggml_backend_mpi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
auto * ctx = (ggml_backend_mpi_buffer_context *) buffer->context; return ggml_backend_mpi_buffer_unwrap(buffer)->iface.clear(ggml_backend_mpi_buffer_unwrap(buffer), value);
return ctx->wrapped_buffer->iface.clear(ctx->wrapped_buffer, value);
} }
static struct ggml_backend_buffer_i mpi_backend_buffer_i = { static struct ggml_backend_buffer_i mpi_backend_buffer_i = {
@ -930,10 +816,6 @@ static struct ggml_backend_buffer_i mpi_backend_buffer_i = {
GGML_CALL ggml_backend_buffer_t ggml_backend_mpi_wrap_buffer(ggml_backend_buffer_t buf) { GGML_CALL ggml_backend_buffer_t ggml_backend_mpi_wrap_buffer(ggml_backend_buffer_t buf) {
// if (buf->iface.get_name == mpi_backend_buffer_i.get_name) {
// return buf;
// }
// if (cached_buffer_wrappers.find(buf) != cached_buffer_wrappers.end()) { // if (cached_buffer_wrappers.find(buf) != cached_buffer_wrappers.end()) {
// fprintf(stderr, "Returning cached buffer with name %s\n", cached_buffer_wrappers[buf]->iface.get_name(cached_buffer_wrappers[buf])); // fprintf(stderr, "Returning cached buffer with name %s\n", cached_buffer_wrappers[buf]->iface.get_name(cached_buffer_wrappers[buf]));
// auto * ret = new ggml_backend_buffer; // auto * ret = new ggml_backend_buffer;
@ -950,11 +832,15 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_mpi_wrap_buffer(ggml_backend_buffer
auto *buffer = new ggml_backend_buffer { auto *buffer = new ggml_backend_buffer {
/* .interface = */ mpi_backend_buffer_i, /* .interface = */ mpi_backend_buffer_i,
/* .buft = */ t, /* .buft = */ t,
/* .context = */ new ggml_backend_mpi_buffer_context{buf, ((ggml_backend_mpi_buffer_type_context*)t->context)->rank}, /* .context = */ new ggml_backend_mpi_buffer_context{
buf, ggml_mpi_init()},
/* .size = */ buf->size, /* .size = */ buf->size,
/* .usage = */ buf->usage /* .usage = */ buf->usage
}; };
// Default to node 0 when wrapping buffers
ggml_backend_mpi_buffer_set_rank(buffer, 0);
cached_buffer_wrappers[buf] = buffer; cached_buffer_wrappers[buf] = buffer;
@ -963,16 +849,12 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_mpi_wrap_buffer(ggml_backend_buffer
} }
bool ggml_backend_mpi_cpy_tensor_async(ggml_backend_t backend, const struct ggml_tensor * src, struct ggml_tensor * dst) { bool ggml_backend_mpi_cpy_tensor_async(ggml_backend_t backend, const struct ggml_tensor * src, struct ggml_tensor * dst) {
int src_rank = ((ggml_backend_mpi_buffer_type_context*)src->buffer->buft->context)->rank; int src_rank = ggml_backend_mpi_buffer_rank(src->buffer);
int dst_rank = ((ggml_backend_mpi_buffer_type_context*)dst->buffer->buft->context)->rank; int dst_rank = ggml_backend_mpi_buffer_rank(dst->buffer);
// fprintf(stderr, "Running tensor async copy for src %s (buffer %s) and dst %s (buffer %s) with backend %s\n", src->name, ggml_backend_buffer_name(src->buffer), dst->name, ggml_backend_buffer_name(dst->buffer), backend->iface.get_name(backend));
auto * ctx = static_cast<ggml_mpi_context *>(backend->context); auto * ctx = static_cast<ggml_mpi_context *>(backend->context);
if (ctx->remote) { if (ctx->remote) {
// fprintf(stderr, "Skipping tensor copy for remote backend %s.\n", backend->iface.get_name(backend));
return true; return true;
} }
@ -991,9 +873,7 @@ bool ggml_backend_mpi_cpy_tensor_async(ggml_backend_t backend, const struct ggml
} }
void ggml_backend_mpi_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * dst, const void* data, size_t offset, size_t size) { void ggml_backend_mpi_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * dst, const void* data, size_t offset, size_t size) {
int dst_rank = ((ggml_backend_mpi_buffer_type_context*)dst->buffer->buft->context)->rank; int dst_rank = ggml_backend_mpi_buffer_rank(dst->buffer);
// fprintf(stderr, "Running set tensor for dst %s (buffer %s) with backend %s\n", dst->name, ggml_backend_buffer_name(dst->buffer), backend->iface.get_name(backend));
auto * ctx = static_cast<ggml_mpi_context *>(backend->context); auto * ctx = static_cast<ggml_mpi_context *>(backend->context);
@ -1079,7 +959,7 @@ int ggml_backend_mpi_reg_devices() {
GGML_CALL void ggml_backend_mpi_buffer_type_set_rank(ggml_backend_buffer_type_t buft, int rank) { GGML_CALL void ggml_backend_mpi_buffer_type_set_rank(ggml_backend_buffer_type_t buft, int rank) {
if (buft->iface.get_name == ggml_backend_mpi_buffer_type_name) { if (buft->iface.get_name == ggml_backend_mpi_buffer_type_name) {
((ggml_backend_mpi_buffer_type_context *) buft->context)->rank = rank; ((ggml_backend_mpi_buffer_type_context *) buft->context)->ctx_mpi->rank = rank;
} else { } else {
GGML_ASSERT(!"Buffer type must be wrapped in ggml_backend_mpi_buffer_type"); GGML_ASSERT(!"Buffer type must be wrapped in ggml_backend_mpi_buffer_type");
} }
@ -1087,9 +967,11 @@ GGML_CALL void ggml_backend_mpi_buffer_type_set_rank(ggml_backend_buffer_type_t
GGML_CALL void ggml_backend_mpi_buffer_set_rank(ggml_backend_buffer_t buf, int rank) { GGML_CALL void ggml_backend_mpi_buffer_set_rank(ggml_backend_buffer_t buf, int rank) {
if (buf->iface.get_name == ggml_backend_mpi_buffer_name) { if (buf->iface.get_name == ggml_backend_mpi_buffer_name) {
((ggml_backend_mpi_buffer_context *) buf->context)->rank = rank; ((ggml_backend_mpi_buffer_context *) buf->context)->ctx_mpi->rank = rank;
ggml_backend_mpi_buffer_type_set_rank(buf->buft, rank); ggml_backend_mpi_buffer_type_set_rank(buf->buft, rank);
} else { } else {
GGML_ASSERT(!"Buffer type must be wrapped in ggml_backend_mpi_buffer_type"); GGML_ASSERT(!"Buffer type must be wrapped in ggml_backend_mpi_buffer_type");
} }
} }

View file

@ -95,8 +95,6 @@ void ggml_mpi_backend_free(void);
*/ */
struct ggml_mpi_context * ggml_mpi_init(void); struct ggml_mpi_context * ggml_mpi_init(void);
void ggml_mpi_graph_creation_post(struct ggml_mpi_context * ctx_mpi, struct ggml_cgraph * cgraph, int n_layers);
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_mpi_wrap_buffer_type(ggml_backend_buffer_type_t buft); GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_mpi_wrap_buffer_type(ggml_backend_buffer_type_t buft);
GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_mpi_wrap_buffer(ggml_backend_buffer_t buf); GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_mpi_wrap_buffer(ggml_backend_buffer_t buf);
@ -178,7 +176,7 @@ void ggml_mpi_eval_init(
int8_t ** logits, int8_t ** logits,
uint32_t n_seq_max); uint32_t n_seq_max);
void ggml_mpi_synch_int( void ggml_mpi_sync_int(
struct ggml_mpi_context * ctx_mpi, struct ggml_mpi_context * ctx_mpi,
int32_t * val int32_t * val
); );
@ -207,45 +205,6 @@ uint16_t** ggml_mpi_split_range(
float node_weights[] float node_weights[]
); );
/**
* Scatter the layer ranges across all nodes
* in the given context. This is a collective operation
* and must be called by all nodes that are within the same
* communicator. The given layer ranges must be in the same
* format as created by the ggml_mpi_split_range().
*
* @param ctx_mpi The context to scatter the layers across.
* @param layer_ranges The pre-split ranges to scatter to the nodes.
*/
void ggml_mpi_scatter_layers(
struct ggml_mpi_context * ctx_mpi,
uint16_t ** layer_ranges
);
/**
* Modify compute graph to only process allocated
* layers.
*
* @param ctx_mpi The context containing the allocated layer range.
* @param gf The compute graph to modify
* @param n_layers The number of layers in the model, used as an upper bound in the layer ranges.
*/
void ggml_mpi_graph_compute_pre(
struct ggml_mpi_context * ctx_mpi,
struct ggml_cgraph * gf);
/**
* Sends the output tensor to the next node for processing
* of later layers.
*
* @param ctx_mpi The context to use for MPI operations.
* @param gf The graph used in the computations
* @param n_layers The number of layers in the model.
*/
void ggml_mpi_graph_compute_post(
struct ggml_mpi_context * ctx_mpi,
struct ggml_cgraph * gf);
// BACKEND V2 // BACKEND V2
struct ggml_mpi_device { struct ggml_mpi_device {