diff --git a/common/common.cpp b/common/common.cpp index cae17a3d2..46ec366b0 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -169,10 +169,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1}; std::vector split_arg{it, {}}; params.n_threads.resize(split_arg.size()); - for (size_t i = 0; i < split_arg.size(); ++i) { - params.n_threads[i] = std::stoi(split_arg[i]); - if (params.n_threads[i] <= 0) { - params.n_threads[i] = std::thread::hardware_concurrency(); + for (size_t node = 0; node < split_arg.size(); ++node) { + params.n_threads[node] = std::stoi(split_arg[node]); + if (params.n_threads[node] <= 0) { + params.n_threads[node] = std::thread::hardware_concurrency(); } } @@ -188,10 +188,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1}; std::vector split_arg{it, {}}; params.n_threads_batch.resize(split_arg.size()); - for (size_t i = 0; i < split_arg.size(); ++i) { - params.n_threads_batch[i] = std::stoi(split_arg[i]); - if (params.n_threads_batch[i] <= 0) { - params.n_threads_batch[i] = std::thread::hardware_concurrency(); + for (size_t node = 0; node < split_arg.size(); ++node) { + params.n_threads_batch[node] = std::stoi(split_arg[node]); + if (params.n_threads_batch[node] <= 0) { + params.n_threads_batch[node] = std::thread::hardware_concurrency(); } } } else if (arg == "-td" || arg == "--threads-draft") { @@ -199,18 +199,36 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { invalid_param = true; break; } - params.n_threads_draft = std::stoi(argv[i]); - if (params.n_threads_draft <= 0) { - params.n_threads_draft = std::thread::hardware_concurrency(); + std::string arg_next = argv[i]; + + // split string by , and / + const std::regex regex{R"([,/]+)"}; + std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1}; + std::vector split_arg{it, {}}; + params.n_threads_draft.resize(split_arg.size()); + for (size_t node = 0; node < split_arg.size(); ++node) { + params.n_threads_draft[node] = std::stoi(split_arg[node]); + if (params.n_threads_draft[node] <= 0) { + params.n_threads_draft[node] = std::thread::hardware_concurrency(); + } } } else if (arg == "-tbd" || arg == "--threads-batch-draft") { if (++i >= argc) { invalid_param = true; break; } - params.n_threads_batch_draft = std::stoi(argv[i]); - if (params.n_threads_batch_draft <= 0) { - params.n_threads_batch_draft = std::thread::hardware_concurrency(); + std::string arg_next = argv[i]; + + // split string by , and / + const std::regex regex{R"([,/]+)"}; + std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1}; + std::vector split_arg{it, {}}; + params.n_threads_batch_draft.resize(split_arg.size()); + for (size_t node = 0; node < split_arg.size(); ++node) { + params.n_threads_batch_draft[node] = std::stoi(split_arg[node]); + if (params.n_threads_batch_draft[node] <= 0) { + params.n_threads_batch_draft[node] = std::thread::hardware_concurrency(); + } } } else if (arg == "-p" || arg == "--prompt") { if (++i >= argc) { diff --git a/ggml-mpi.c b/ggml-mpi.c index fd88eab1f..c10faa252 100644 --- a/ggml-mpi.c +++ b/ggml-mpi.c @@ -261,7 +261,9 @@ void ggml_mpi_graph_compute_pre( return; } - GGML_ASSERT(inp0 == gf->nodes[0]); +// fprintf(stderr, "gf->nodes[0] == %s\n", ggml_get_name(gf->nodes[0])); +// +// GGML_ASSERT(inp0 == gf->nodes[0]); // distribute the compute graph into slices across the MPI nodes // @@ -333,7 +335,6 @@ void ggml_mpi_graph_compute_pre( // TODO: instead of rearranging the nodes, we should be able to execute a subset of the compute graph for (int i = 1; i < idx_l1 - idx_l0; i++) { gf->nodes[i] = gf->nodes[idx_l0 + i]; - gf->grads[i] = gf->grads[idx_l0 + i]; } // the first node performs the "get_rows" operation, the rest of the nodes get the data from the previous node