diff --git a/llama.cpp b/llama.cpp index 46318bed3..8e268fa18 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1549,6 +1549,11 @@ struct llama_context { #ifdef GGML_USE_MPI ggml_mpi_context * ctx_mpi = NULL; #endif + +#ifdef GGML_USE_OPENSHMEM + ggml_openshmem_context * ctx_oshmem = NULL; +#endif + }; // @@ -6352,6 +6357,12 @@ static int llama_decode_internal( ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer); #endif +#if GGML_USE_OPENSHMEM + const int64_t n_layer = hparams.n_layer; + ggml_openshmem_graph_compute_pre(lctx.ctx_oshmem, gf, n_layer); +#endif + + #ifdef GGML_USE_METAL if (lctx.ctx_metal) { ggml_metal_set_n_cb (lctx.ctx_metal, n_threads); @@ -6367,6 +6378,10 @@ static int llama_decode_internal( ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer); #endif +#if GGML_USE_OPENSHEM + ggml_openshmem_graph_compute_post(lctx.ctx_oshmem, gf, n_layer); +#endif + // update the kv ring buffer { if (kv_self.has_shift) { @@ -9256,12 +9271,21 @@ void llama_backend_init(bool numa) { #ifdef GGML_USE_MPI ggml_mpi_backend_init(); #endif + +#ifdef GGML_USE_OPENSHMEM + ggml_openshmem_backend_init(); +#endif + } void llama_backend_free(void) { #ifdef GGML_USE_MPI ggml_mpi_backend_free(); #endif +#ifdef GGML_USE_OPENSHMEM + ggml_openshmem_backend_free(); +#endif + } int64_t llama_time_us(void) { @@ -9524,6 +9548,20 @@ struct llama_context * llama_new_context_with_model( } #endif +#ifdef GGML_USE_OPENSHMEM + ctx->ctx_oshmem = ggml_openshmem_init(); + + if (ggml_openshmem_pe(ctx->ctx_oshmem) > 0) { + // Enter a blocking eval loop with dummy input, letting rank=0 drive the process + // TODO: needs fix after #3228 + GGML_ASSERT(false && "not implemented"); + //const std::vector tmp(ctx->model.hparams.n_ctx, llama_token_bos(ctx)); + //while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)) {}; + llama_backend_free(); + exit(1); + } +#endif + return ctx; }