Fix draft thread args and remove grads from mpi eval_init

This commit is contained in:
Branden Butler 2024-02-03 13:57:00 -06:00
parent c9d18263b3
commit aa166462f1
2 changed files with 35 additions and 16 deletions

View file

@ -169,10 +169,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1}; std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
std::vector<std::string> split_arg{it, {}}; std::vector<std::string> split_arg{it, {}};
params.n_threads.resize(split_arg.size()); params.n_threads.resize(split_arg.size());
for (size_t i = 0; i < split_arg.size(); ++i) { for (size_t node = 0; node < split_arg.size(); ++node) {
params.n_threads[i] = std::stoi(split_arg[i]); params.n_threads[node] = std::stoi(split_arg[node]);
if (params.n_threads[i] <= 0) { if (params.n_threads[node] <= 0) {
params.n_threads[i] = std::thread::hardware_concurrency(); params.n_threads[node] = std::thread::hardware_concurrency();
} }
} }
@ -188,10 +188,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1}; std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
std::vector<std::string> split_arg{it, {}}; std::vector<std::string> split_arg{it, {}};
params.n_threads_batch.resize(split_arg.size()); params.n_threads_batch.resize(split_arg.size());
for (size_t i = 0; i < split_arg.size(); ++i) { for (size_t node = 0; node < split_arg.size(); ++node) {
params.n_threads_batch[i] = std::stoi(split_arg[i]); params.n_threads_batch[node] = std::stoi(split_arg[node]);
if (params.n_threads_batch[i] <= 0) { if (params.n_threads_batch[node] <= 0) {
params.n_threads_batch[i] = std::thread::hardware_concurrency(); params.n_threads_batch[node] = std::thread::hardware_concurrency();
} }
} }
} else if (arg == "-td" || arg == "--threads-draft") { } else if (arg == "-td" || arg == "--threads-draft") {
@ -199,18 +199,36 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.n_threads_draft = std::stoi(argv[i]); std::string arg_next = argv[i];
if (params.n_threads_draft <= 0) {
params.n_threads_draft = std::thread::hardware_concurrency(); // split string by , and /
const std::regex regex{R"([,/]+)"};
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
std::vector<std::string> split_arg{it, {}};
params.n_threads_draft.resize(split_arg.size());
for (size_t node = 0; node < split_arg.size(); ++node) {
params.n_threads_draft[node] = std::stoi(split_arg[node]);
if (params.n_threads_draft[node] <= 0) {
params.n_threads_draft[node] = std::thread::hardware_concurrency();
}
} }
} else if (arg == "-tbd" || arg == "--threads-batch-draft") { } else if (arg == "-tbd" || arg == "--threads-batch-draft") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.n_threads_batch_draft = std::stoi(argv[i]); std::string arg_next = argv[i];
if (params.n_threads_batch_draft <= 0) {
params.n_threads_batch_draft = std::thread::hardware_concurrency(); // split string by , and /
const std::regex regex{R"([,/]+)"};
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
std::vector<std::string> split_arg{it, {}};
params.n_threads_batch_draft.resize(split_arg.size());
for (size_t node = 0; node < split_arg.size(); ++node) {
params.n_threads_batch_draft[node] = std::stoi(split_arg[node]);
if (params.n_threads_batch_draft[node] <= 0) {
params.n_threads_batch_draft[node] = std::thread::hardware_concurrency();
}
} }
} else if (arg == "-p" || arg == "--prompt") { } else if (arg == "-p" || arg == "--prompt") {
if (++i >= argc) { if (++i >= argc) {

View file

@ -261,7 +261,9 @@ void ggml_mpi_graph_compute_pre(
return; return;
} }
GGML_ASSERT(inp0 == gf->nodes[0]); // fprintf(stderr, "gf->nodes[0] == %s\n", ggml_get_name(gf->nodes[0]));
//
// GGML_ASSERT(inp0 == gf->nodes[0]);
// distribute the compute graph into slices across the MPI nodes // distribute the compute graph into slices across the MPI nodes
// //
@ -333,7 +335,6 @@ void ggml_mpi_graph_compute_pre(
// TODO: instead of rearranging the nodes, we should be able to execute a subset of the compute graph // TODO: instead of rearranging the nodes, we should be able to execute a subset of the compute graph
for (int i = 1; i < idx_l1 - idx_l0; i++) { for (int i = 1; i < idx_l1 - idx_l0; i++) {
gf->nodes[i] = gf->nodes[idx_l0 + i]; gf->nodes[i] = gf->nodes[idx_l0 + i];
gf->grads[i] = gf->grads[idx_l0 + i];
} }
// the first node performs the "get_rows" operation, the rest of the nodes get the data from the previous node // the first node performs the "get_rows" operation, the rest of the nodes get the data from the previous node