mpi : trying to move more MPI stuff into ggml-mpi (WIP) (#2099)
This commit is contained in:
parent
ef61acfbf5
commit
3232db628c
11 changed files with 134 additions and 67 deletions
70
ggml-mpi.c
70
ggml-mpi.c
|
@ -2,9 +2,11 @@
|
|||
|
||||
#include "ggml.h"
|
||||
|
||||
#include <mpi.h>
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <mpi.h>
|
||||
|
||||
#define UNUSED GGML_UNUSED
|
||||
|
||||
struct ggml_mpi_tensor_info {
|
||||
|
@ -52,9 +54,8 @@ static void ggml_mpi_compute_forward_recv(
|
|||
|
||||
struct ggml_tensor * ggml_mpi_send_tensor(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor *src,
|
||||
int dst_rank) {
|
||||
|
||||
struct ggml_tensor * src,
|
||||
int dst_rank) {
|
||||
struct ggml_tensor * result = ggml_map_custom1_inplace_f32(ctx, src, ggml_mpi_compute_forward_send);
|
||||
|
||||
// TODO how/when to free this struct?
|
||||
|
@ -67,9 +68,9 @@ struct ggml_tensor * ggml_mpi_send_tensor(
|
|||
|
||||
struct ggml_tensor * ggml_mpi_recv_tensor(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor *parent,
|
||||
struct ggml_tensor *dst,
|
||||
int src_rank) {
|
||||
struct ggml_tensor * parent,
|
||||
struct ggml_tensor * dst,
|
||||
int src_rank) {
|
||||
struct ggml_tensor * result = ggml_map_custom2_inplace_f32(ctx, dst, parent, ggml_mpi_compute_forward_recv);
|
||||
|
||||
// TODO how/when to free this struct?
|
||||
|
@ -79,3 +80,58 @@ struct ggml_tensor * ggml_mpi_recv_tensor(
|
|||
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_mpi_context {
|
||||
int mpi_rank;
|
||||
int mpi_size;
|
||||
};
|
||||
|
||||
void ggml_mpi_backend_init(void) {
|
||||
MPI_Init(NULL, NULL);
|
||||
}
|
||||
|
||||
void ggml_mpi_backend_free(void) {
|
||||
MPI_Finalize();
|
||||
}
|
||||
|
||||
struct ggml_mpi_context * ggml_mpi_init(void) {
|
||||
struct ggml_mpi_context * ctx = calloc(1, sizeof(struct ggml_mpi_context));
|
||||
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &ctx->mpi_rank);
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &ctx->mpi_size);
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
void ggml_mpi_free(struct ggml_mpi_context * ctx) {
|
||||
free(ctx);
|
||||
}
|
||||
|
||||
int ggml_mpi_rank(struct ggml_mpi_context * ctx) {
|
||||
return ctx->mpi_rank;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_mpi_eval_init(
|
||||
struct ggml_mpi_context * ctx_mpi,
|
||||
struct ggml_context * ctx,
|
||||
int n_embd,
|
||||
int * n_tokens,
|
||||
int * n_past,
|
||||
int * n_threads) {
|
||||
struct ggml_tensor * res = NULL;
|
||||
|
||||
// synchronize the worker node parameters with the root node
|
||||
MPI_Barrier(MPI_COMM_WORLD);
|
||||
|
||||
MPI_Bcast(n_tokens, 1, MPI_INT, 0, MPI_COMM_WORLD);
|
||||
MPI_Bcast(n_past, 1, MPI_INT, 0, MPI_COMM_WORLD);
|
||||
MPI_Bcast(n_threads, 1, MPI_INT, 0, MPI_COMM_WORLD);
|
||||
|
||||
if (ctx_mpi->mpi_rank > 0) {
|
||||
res = ggml_mpi_recv_tensor(ctx, NULL,
|
||||
ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, *n_tokens), ctx_mpi->mpi_rank - 1);
|
||||
ggml_set_name(res, "mpi_recv");
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue