Fix draft thread args and remove grads from mpi eval_init
parent c9d18263b3
commit aa166462f1

2 changed files with 35 additions and 16 deletions
@@ -169,10 +169,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
             std::vector<std::string> split_arg{it, {}};
             params.n_threads.resize(split_arg.size());
-            for (size_t i = 0; i < split_arg.size(); ++i) {
-                params.n_threads[i] = std::stoi(split_arg[i]);
-                if (params.n_threads[i] <= 0) {
-                    params.n_threads[i] = std::thread::hardware_concurrency();
+            for (size_t node = 0; node < split_arg.size(); ++node) {
+                params.n_threads[node] = std::stoi(split_arg[node]);
+                if (params.n_threads[node] <= 0) {
+                    params.n_threads[node] = std::thread::hardware_concurrency();
                 }
             }

@@ -188,10 +188,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
             std::vector<std::string> split_arg{it, {}};
             params.n_threads_batch.resize(split_arg.size());
-            for (size_t i = 0; i < split_arg.size(); ++i) {
-                params.n_threads_batch[i] = std::stoi(split_arg[i]);
-                if (params.n_threads_batch[i] <= 0) {
-                    params.n_threads_batch[i] = std::thread::hardware_concurrency();
+            for (size_t node = 0; node < split_arg.size(); ++node) {
+                params.n_threads_batch[node] = std::stoi(split_arg[node]);
+                if (params.n_threads_batch[node] <= 0) {
+                    params.n_threads_batch[node] = std::thread::hardware_concurrency();
                 }
             }
         } else if (arg == "-td" || arg == "--threads-draft") {

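A note on the rename above: each parsed entry is the thread count for one MPI node, so `node` is the clearer loop index than `i`. For reference, a minimal standalone sketch of the split that `-t`/`-tb` (and, below, `-td`/`-tbd`) rely on, using a hypothetical flag value:

#include <cstdio>
#include <regex>
#include <string>
#include <vector>

int main() {
    // hypothetical flag value, e.g. -t "16,8/4": thread counts for three MPI nodes
    std::string arg_next = "16,8/4";

    // same split as in gpt_params_parse_ex: runs of ',' or '/' separate the entries
    const std::regex regex{R"([,/]+)"};
    std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
    std::vector<std::string> split_arg{it, {}};

    for (size_t node = 0; node < split_arg.size(); ++node) {
        std::printf("node %zu -> %d threads\n", node, std::stoi(split_arg[node]));
    }
    // prints: node 0 -> 16, node 1 -> 8, node 2 -> 4
    return 0;
}

Since the regex matches runs of separators, "16,8/4" and "16/8/4" parse the same way.
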
@@ -199,18 +199,36 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 invalid_param = true;
                 break;
             }
-            params.n_threads_draft = std::stoi(argv[i]);
-            if (params.n_threads_draft <= 0) {
-                params.n_threads_draft = std::thread::hardware_concurrency();
+            std::string arg_next = argv[i];
+
+            // split string by , and /
+            const std::regex regex{R"([,/]+)"};
+            std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
+            std::vector<std::string> split_arg{it, {}};
+            params.n_threads_draft.resize(split_arg.size());
+            for (size_t node = 0; node < split_arg.size(); ++node) {
+                params.n_threads_draft[node] = std::stoi(split_arg[node]);
+                if (params.n_threads_draft[node] <= 0) {
+                    params.n_threads_draft[node] = std::thread::hardware_concurrency();
+                }
             }
         } else if (arg == "-tbd" || arg == "--threads-batch-draft") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            params.n_threads_batch_draft = std::stoi(argv[i]);
-            if (params.n_threads_batch_draft <= 0) {
-                params.n_threads_batch_draft = std::thread::hardware_concurrency();
+            std::string arg_next = argv[i];
+
+            // split string by , and /
+            const std::regex regex{R"([,/]+)"};
+            std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
+            std::vector<std::string> split_arg{it, {}};
+            params.n_threads_batch_draft.resize(split_arg.size());
+            for (size_t node = 0; node < split_arg.size(); ++node) {
+                params.n_threads_batch_draft[node] = std::stoi(split_arg[node]);
+                if (params.n_threads_batch_draft[node] <= 0) {
+                    params.n_threads_batch_draft[node] = std::thread::hardware_concurrency();
+                }
             }
         } else if (arg == "-p" || arg == "--prompt") {
             if (++i >= argc) {

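With this hunk, the same split-and-validate block now appears for -t, -tb, -td, and -tbd. Not part of this commit, but the duplication could later be folded into a small helper along these lines. This is a sketch only; the helper name parse_per_node_threads and the int32_t element type of the n_threads* vectors are assumptions:

#include <cstdint>
#include <regex>
#include <string>
#include <thread>
#include <vector>

// hypothetical helper (not in this commit): parse a per-node thread-count list
// such as "16,8/4"; entries <= 0 fall back to the local hardware thread count
static void parse_per_node_threads(const std::string & arg_next, std::vector<int32_t> & out) {
    const std::regex regex{R"([,/]+)"};
    std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
    std::vector<std::string> split_arg{it, {}};

    out.resize(split_arg.size());
    for (size_t node = 0; node < split_arg.size(); ++node) {
        out[node] = std::stoi(split_arg[node]);
        if (out[node] <= 0) {
            out[node] = std::thread::hardware_concurrency();
        }
    }
}

int main() {
    std::vector<int32_t> n_threads_draft;
    parse_per_node_threads("12,6/0", n_threads_draft); // last entry falls back to hardware_concurrency()
    return 0;
}

Each of the four branches would then reduce to a single call such as parse_per_node_threads(argv[i], params.n_threads_draft).
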
@@ -261,7 +261,9 @@ void ggml_mpi_graph_compute_pre(
         return;
     }

-    GGML_ASSERT(inp0 == gf->nodes[0]);
+    // fprintf(stderr, "gf->nodes[0] == %s\n", ggml_get_name(gf->nodes[0]));
+    //
+    // GGML_ASSERT(inp0 == gf->nodes[0]);

     // distribute the compute graph into slices across the MPI nodes
     //

@@ -333,7 +335,6 @@ void ggml_mpi_graph_compute_pre(
     // TODO: instead of rearranging the nodes, we should be able to execute a subset of the compute graph
     for (int i = 1; i < idx_l1 - idx_l0; i++) {
         gf->nodes[i] = gf->nodes[idx_l0 + i];
-        gf->grads[i] = gf->grads[idx_l0 + i];
     }

     // the first node performs the "get_rows" operation, the rest of the nodes get the data from the previous node

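On the dropped grads copy: this path presumably only rearranges the graph for inference, so mirroring gf->nodes into gf->grads is unnecessary, and a graph built without gradients may not have usable grads to copy in the first place. If the copy ever has to return, a guarded variant could look like the sketch below. The cgraph_view struct and move_slice_to_front name are simplified stand-ins of my own, and the assumption that grads can be nullptr for inference-only graphs depends on the ggml version in use:

// simplified stand-ins for the ggml types touched by this loop (sketch only)
struct ggml_tensor;

struct cgraph_view {
    ggml_tensor ** nodes;
    ggml_tensor ** grads; // assumed nullptr when the graph carries no gradients
};

// move the slice [idx_l0, idx_l1) to the front of the graph, copying grads
// only when the graph actually has them
static void move_slice_to_front(cgraph_view * gf, int idx_l0, int idx_l1) {
    for (int i = 1; i < idx_l1 - idx_l0; i++) {
        gf->nodes[i] = gf->nodes[idx_l0 + i];
        if (gf->grads != nullptr) {
            gf->grads[i] = gf->grads[idx_l0 + i];
        }
    }
}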