diff --git a/common/common.cpp b/common/common.cpp index 39eb7c909..dd742ff25 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -184,6 +184,7 @@ static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int params.n_threads[node] = std::thread::hardware_concurrency(); } } + return true; } if (arg == "-tb" || arg == "--threads-batch") { @@ -204,6 +205,7 @@ static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int params.n_threads_batch[node] = std::thread::hardware_concurrency(); } } + return true; } if (arg == "-td" || arg == "--threads-draft") { if (++i >= argc) { @@ -223,6 +225,7 @@ static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int params.n_threads_draft[node] = std::thread::hardware_concurrency(); } } + return true; } if (arg == "-tbd" || arg == "--threads-batch-draft") { if (++i >= argc) { @@ -242,6 +245,7 @@ static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int params.n_threads_batch_draft[node] = std::thread::hardware_concurrency(); } } + return true; } if (arg == "-p" || arg == "--prompt") { if (++i >= argc) { @@ -910,20 +914,21 @@ static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int return true; } if (arg == "--mpi-layer-split") { - if (++i >= argc) { - invalid_param = true; - return true; - } - std::string arg_next = argv[i]; + if (++i >= argc) { + invalid_param = true; + return true; + } + std::string arg_next = argv[i]; - // split string by , and / - const std::regex regex{R"([,/]+)"}; - std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1}; - std::vector split_arg{it, {}}; - params.mpi_layer_split.resize(split_arg.size()); - for (size_t node = 0; node < split_arg.size(); ++node) { - params.mpi_layer_split[node] = std::stof(split_arg[node]); - } + // split string by , and / + const std::regex regex{R"([,/]+)"}; + std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1}; + std::vector split_arg{it, {}}; + params.mpi_layer_split.resize(split_arg.size()); + for (size_t node = 0; node < split_arg.size(); ++node) { + params.mpi_layer_split[node] = std::stof(split_arg[node]); + } + return true; } if (arg == "--tensor-split" || arg == "-ts") { diff --git a/ggml-mpi.cpp b/ggml-mpi.cpp index 6e12e93f5..7fb4e752c 100644 --- a/ggml-mpi.cpp +++ b/ggml-mpi.cpp @@ -513,6 +513,8 @@ GGML_CALL static enum ggml_status ggml_backend_mpi_graph_compute(ggml_backend_t } } + size_t n_srcs = 0; + for (int i = 0; i < cgraph->n_nodes; i++) { if (cgraph->nodes[i]->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) { cgraph->nodes[i]->buffer = ggml_backend_mpi_buffer_unwrap(cgraph->nodes[i]->buffer); @@ -523,6 +525,7 @@ GGML_CALL static enum ggml_status ggml_backend_mpi_graph_compute(ggml_backend_t break; } if (src->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) { + n_srcs++; src->buffer = ggml_backend_mpi_buffer_unwrap(src->buffer); } } @@ -532,6 +535,7 @@ GGML_CALL static enum ggml_status ggml_backend_mpi_graph_compute(ggml_backend_t if (src->buffer->buft != nullptr) { if (src->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) { + n_srcs++; src->buffer = ggml_backend_mpi_buffer_unwrap(src->buffer); } } @@ -546,12 +550,13 @@ GGML_CALL static enum ggml_status ggml_backend_mpi_graph_compute(ggml_backend_t } + if (!ctx->remote) { ggml_backend_sched_t sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), - (int) ctx->backends.size(), cgraph->n_nodes, false); + (int) ctx->backends.size(), cgraph->n_nodes + cgraph->n_leafs + n_srcs, false); ggml_backend_sched_reserve(sched, cgraph); - ggml_backend_sched_graph_compute(sched, cgraph); + ggml_backend_sched_graph_compute_async(sched, cgraph); ggml_backend_sched_free(sched); }