added oshmem backend to llama.cpp
This commit is contained in:
parent
eb0f775950
commit
71f4c96331
1 changed files with 38 additions and 0 deletions
38
llama.cpp
38
llama.cpp
|
@ -1549,6 +1549,11 @@ struct llama_context {
|
|||
#ifdef GGML_USE_MPI
|
||||
ggml_mpi_context * ctx_mpi = NULL;
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_OPENSHMEM
|
||||
ggml_openshmem_context * ctx_oshmem = NULL;
|
||||
#endif
|
||||
|
||||
};
|
||||
|
||||
//
|
||||
|
@ -6352,6 +6357,12 @@ static int llama_decode_internal(
|
|||
ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer);
|
||||
#endif
|
||||
|
||||
#if GGML_USE_OPENSHMEM
|
||||
const int64_t n_layer = hparams.n_layer;
|
||||
ggml_openshmem_graph_compute_pre(lctx.ctx_oshmem, gf, n_layer);
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef GGML_USE_METAL
|
||||
if (lctx.ctx_metal) {
|
||||
ggml_metal_set_n_cb (lctx.ctx_metal, n_threads);
|
||||
|
@ -6367,6 +6378,10 @@ static int llama_decode_internal(
|
|||
ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer);
|
||||
#endif
|
||||
|
||||
#if GGML_USE_OPENSHEM
|
||||
ggml_openshmem_graph_compute_post(lctx.ctx_oshmem, gf, n_layer);
|
||||
#endif
|
||||
|
||||
// update the kv ring buffer
|
||||
{
|
||||
if (kv_self.has_shift) {
|
||||
|
@ -9256,12 +9271,21 @@ void llama_backend_init(bool numa) {
|
|||
#ifdef GGML_USE_MPI
|
||||
ggml_mpi_backend_init();
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_OPENSHMEM
|
||||
ggml_openshmem_backend_init();
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
void llama_backend_free(void) {
|
||||
#ifdef GGML_USE_MPI
|
||||
ggml_mpi_backend_free();
|
||||
#endif
|
||||
#ifdef GGML_USE_OPENSHMEM
|
||||
ggml_openshmem_backend_free();
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
int64_t llama_time_us(void) {
|
||||
|
@ -9524,6 +9548,20 @@ struct llama_context * llama_new_context_with_model(
|
|||
}
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_OPENSHMEM
|
||||
ctx->ctx_oshmem = ggml_openshmem_init();
|
||||
|
||||
if (ggml_openshmem_pe(ctx->ctx_oshmem) > 0) {
|
||||
// Enter a blocking eval loop with dummy input, letting rank=0 drive the process
|
||||
// TODO: needs fix after #3228
|
||||
GGML_ASSERT(false && "not implemented");
|
||||
//const std::vector<llama_token> tmp(ctx->model.hparams.n_ctx, llama_token_bos(ctx));
|
||||
//while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)) {};
|
||||
llama_backend_free();
|
||||
exit(1);
|
||||
}
|
||||
#endif
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue