From c6280bc3f4938e09d059c41bc881f2db8393abf1 Mon Sep 17 00:00:00 2001 From: Branden Butler Date: Tue, 12 Mar 2024 12:40:23 -0500 Subject: [PATCH] Update to use backend GUID and changed signatures --- ggml-mpi.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/ggml-mpi.cpp b/ggml-mpi.cpp index a16cc48b7..6ee5168ba 100644 --- a/ggml-mpi.cpp +++ b/ggml-mpi.cpp @@ -304,7 +304,7 @@ void ggml_mpi_graph_creation_post(struct ggml_mpi_context * ctx_mpi, struct ggml for (int i = 0; i < gf->n_nodes; i++) { - gf->nodes[i]->backend = GGML_BACKEND_MPI_SPLIT; + gf->nodes[i]->backend = GGML_BACKEND_TYPE_MPI_SPLIT; } @@ -382,7 +382,7 @@ GGML_CALL static const char * ggml_backend_mpi_buffer_name(ggml_backend_buffer_t GGML_CALL static const char * ggml_backend_mpi_buffer_type_name(ggml_backend_buffer_type_t buft); -GGML_CALL static bool ggml_backend_mpi_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { +GGML_CALL static enum ggml_status ggml_backend_mpi_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { struct ggml_mpi_context * ctx = (ggml_mpi_context *) backend->context; @@ -511,7 +511,7 @@ GGML_CALL static bool ggml_backend_mpi_graph_compute(ggml_backend_t backend, ggm ctx->backends.size(), cgraph->n_nodes); // printf("Created new scheduler\n"); - ggml_backend_sched_init_measure(sched, cgraph); + ggml_backend_sched_reserve(sched, cgraph); // printf("Beginning sched graph compute\n"); ggml_backend_sched_graph_compute(sched, cgraph); @@ -593,7 +593,7 @@ GGML_CALL static bool ggml_backend_mpi_graph_compute(ggml_backend_t backend, ggm // ggml_mpi_graph_compute_post(ctx, cgraph); - return true; + return GGML_STATUS_SUCCESS; } @@ -914,6 +914,8 @@ void ggml_backend_mpi_set_tensor_async(ggml_backend_t backend, struct ggml_tenso ggml_backend_t ggml_backend_mpi_init(ggml_backend_t * wrapped_backends, size_t num_backends, int rank) { + static ggml_guid backend_mpi_guid = {0xec, 0x39, 0xce, 0x40, 0xc3, 0x43, 0x49, 0x36, 0x96, 0x03, 0x55, 0x77, 0x5c, 0x1f, 0x44, 0xd3}; + ggml_mpi_context * ctx = ggml_mpi_init(); std::vector wrapped_backends_v; @@ -942,6 +944,7 @@ ggml_backend_t ggml_backend_mpi_init(ggml_backend_t * wrapped_backends, size_t n }; auto *mpi_backend = new ggml_backend { + /* .guid = */ &backend_mpi_guid, /* .interface = */ mpi_backend_i, /* .context = */ ctx, };