mpi : trying to move more MPI stuff into ggml-mpi (WIP) (#2099)

commit 3232db628c (parent ef61acfbf5)
Author: Georgi Gerganov
Date:   2023-07-09 14:08:53 +03:00
11 changed files with 134 additions and 67 deletions

--- a/ggml-mpi.c
+++ b/ggml-mpi.c

@@ -2,9 +2,11 @@
 #include "ggml.h"
 
+#include <mpi.h>
+
 #include <stdio.h>
+#include <stdlib.h>
 
-#include <mpi.h>
-
 #define UNUSED GGML_UNUSED
 
 struct ggml_mpi_tensor_info {
@@ -52,9 +54,8 @@ static void ggml_mpi_compute_forward_recv(
 
 struct ggml_tensor * ggml_mpi_send_tensor(
         struct ggml_context * ctx,
-        struct ggml_tensor *src,
-        int dst_rank) {
+        struct ggml_tensor  * src,
+        int                   dst_rank) {
     struct ggml_tensor * result = ggml_map_custom1_inplace_f32(ctx, src, ggml_mpi_compute_forward_send);
 
     // TODO how/when to free this struct?
@@ -67,9 +68,9 @@ struct ggml_tensor * ggml_mpi_send_tensor(
 
 struct ggml_tensor * ggml_mpi_recv_tensor(
         struct ggml_context * ctx,
-        struct ggml_tensor *parent,
-        struct ggml_tensor *dst,
-        int src_rank) {
+        struct ggml_tensor  * parent,
+        struct ggml_tensor  * dst,
+        int                   src_rank) {
     struct ggml_tensor * result = ggml_map_custom2_inplace_f32(ctx, dst, parent, ggml_mpi_compute_forward_recv);
 
     // TODO how/when to free this struct?
@@ -79,3 +80,58 @@ struct ggml_tensor * ggml_mpi_recv_tensor(
 
     return result;
 }
+
+struct ggml_mpi_context {
+    int mpi_rank;
+    int mpi_size;
+};
+
+void ggml_mpi_backend_init(void) {
+    MPI_Init(NULL, NULL);
+}
+
+void ggml_mpi_backend_free(void) {
+    MPI_Finalize();
+}
+
+struct ggml_mpi_context * ggml_mpi_init(void) {
+    struct ggml_mpi_context * ctx = calloc(1, sizeof(struct ggml_mpi_context));
+
+    MPI_Comm_rank(MPI_COMM_WORLD, &ctx->mpi_rank);
+    MPI_Comm_size(MPI_COMM_WORLD, &ctx->mpi_size);
+
+    return ctx;
+}
+
+void ggml_mpi_free(struct ggml_mpi_context * ctx) {
+    free(ctx);
+}
+
+int ggml_mpi_rank(struct ggml_mpi_context * ctx) {
+    return ctx->mpi_rank;
+}
+
+struct ggml_tensor * ggml_mpi_eval_init(
+        struct ggml_mpi_context * ctx_mpi,
+        struct ggml_context     * ctx,
+        int                       n_embd,
+        int                     * n_tokens,
+        int                     * n_past,
+        int                     * n_threads) {
+    struct ggml_tensor * res = NULL;
+
+    // synchronize the worker node parameters with the root node
+    MPI_Barrier(MPI_COMM_WORLD);
+
+    MPI_Bcast(n_tokens,  1, MPI_INT, 0, MPI_COMM_WORLD);
+    MPI_Bcast(n_past,    1, MPI_INT, 0, MPI_COMM_WORLD);
+    MPI_Bcast(n_threads, 1, MPI_INT, 0, MPI_COMM_WORLD);
+
+    if (ctx_mpi->mpi_rank > 0) {
+        res = ggml_mpi_recv_tensor(ctx, NULL,
+                ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, *n_tokens), ctx_mpi->mpi_rank - 1);
+        ggml_set_name(res, "mpi_recv");
+    }
+
+    return res;
+}
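
For context, a minimal sketch of how a host program might drive the API added above, in the spirit of how llama.cpp wires it in. Only the ggml_mpi_* calls come from this diff; the scaffolding (the file name, memory size, n_embd value, and the n_* defaults) is hypothetical:

// sketch.c -- hypothetical driver for the ggml-mpi API from this commit
#include <stdio.h>

#include "ggml.h"
#include "ggml-mpi.h"

int main(void) {
    ggml_mpi_backend_init();                  // calls MPI_Init

    struct ggml_mpi_context * ctx_mpi = ggml_mpi_init();

    struct ggml_init_params params = {
        /*.mem_size   =*/ 64*1024*1024,       // hypothetical scratch size
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // rank 0's values win: ggml_mpi_eval_init() broadcasts them, so it must
    // be called collectively on every rank
    int n_tokens  = 8;
    int n_past    = 0;
    int n_threads = 4;

    struct ggml_tensor * inp = ggml_mpi_eval_init(
            ctx_mpi, ctx, /*n_embd =*/ 512, &n_tokens, &n_past, &n_threads);

    if (ggml_mpi_rank(ctx_mpi) > 0) {
        // on worker ranks, inp is the "mpi_recv" tensor whose compute step
        // (ggml_mpi_compute_forward_recv) receives activations from rank - 1;
        // this rank's slice of the graph would be built on top of it
        printf("rank %d: input tensor %s\n", ggml_mpi_rank(ctx_mpi), inp->name);
    } else {
        // rank 0 gets NULL and starts from the token embeddings instead
    }

    ggml_free(ctx);
    ggml_mpi_free(ctx_mpi);
    ggml_mpi_backend_free();                  // calls MPI_Finalize

    return 0;
}

Launched with something like `mpirun -np 2 ./sketch`, rank 0 broadcasts the evaluation parameters and each worker rank sets up a receive from the rank before it, which is the pipeline split this WIP commit is building toward.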