Fix draft thread args and remove grads from mpi eval_init
This commit is contained in:
parent
c9d18263b3
commit
aa166462f1
2 changed files with 35 additions and 16 deletions
|
@ -169,10 +169,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
|||
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
|
||||
std::vector<std::string> split_arg{it, {}};
|
||||
params.n_threads.resize(split_arg.size());
|
||||
for (size_t i = 0; i < split_arg.size(); ++i) {
|
||||
params.n_threads[i] = std::stoi(split_arg[i]);
|
||||
if (params.n_threads[i] <= 0) {
|
||||
params.n_threads[i] = std::thread::hardware_concurrency();
|
||||
for (size_t node = 0; node < split_arg.size(); ++node) {
|
||||
params.n_threads[node] = std::stoi(split_arg[node]);
|
||||
if (params.n_threads[node] <= 0) {
|
||||
params.n_threads[node] = std::thread::hardware_concurrency();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -188,10 +188,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
|||
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
|
||||
std::vector<std::string> split_arg{it, {}};
|
||||
params.n_threads_batch.resize(split_arg.size());
|
||||
for (size_t i = 0; i < split_arg.size(); ++i) {
|
||||
params.n_threads_batch[i] = std::stoi(split_arg[i]);
|
||||
if (params.n_threads_batch[i] <= 0) {
|
||||
params.n_threads_batch[i] = std::thread::hardware_concurrency();
|
||||
for (size_t node = 0; node < split_arg.size(); ++node) {
|
||||
params.n_threads_batch[node] = std::stoi(split_arg[node]);
|
||||
if (params.n_threads_batch[node] <= 0) {
|
||||
params.n_threads_batch[node] = std::thread::hardware_concurrency();
|
||||
}
|
||||
}
|
||||
} else if (arg == "-td" || arg == "--threads-draft") {
|
||||
|
@ -199,18 +199,36 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
|||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.n_threads_draft = std::stoi(argv[i]);
|
||||
if (params.n_threads_draft <= 0) {
|
||||
params.n_threads_draft = std::thread::hardware_concurrency();
|
||||
std::string arg_next = argv[i];
|
||||
|
||||
// split string by , and /
|
||||
const std::regex regex{R"([,/]+)"};
|
||||
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
|
||||
std::vector<std::string> split_arg{it, {}};
|
||||
params.n_threads_draft.resize(split_arg.size());
|
||||
for (size_t node = 0; node < split_arg.size(); ++node) {
|
||||
params.n_threads_draft[node] = std::stoi(split_arg[node]);
|
||||
if (params.n_threads_draft[node] <= 0) {
|
||||
params.n_threads_draft[node] = std::thread::hardware_concurrency();
|
||||
}
|
||||
}
|
||||
} else if (arg == "-tbd" || arg == "--threads-batch-draft") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.n_threads_batch_draft = std::stoi(argv[i]);
|
||||
if (params.n_threads_batch_draft <= 0) {
|
||||
params.n_threads_batch_draft = std::thread::hardware_concurrency();
|
||||
std::string arg_next = argv[i];
|
||||
|
||||
// split string by , and /
|
||||
const std::regex regex{R"([,/]+)"};
|
||||
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
|
||||
std::vector<std::string> split_arg{it, {}};
|
||||
params.n_threads_batch_draft.resize(split_arg.size());
|
||||
for (size_t node = 0; node < split_arg.size(); ++node) {
|
||||
params.n_threads_batch_draft[node] = std::stoi(split_arg[node]);
|
||||
if (params.n_threads_batch_draft[node] <= 0) {
|
||||
params.n_threads_batch_draft[node] = std::thread::hardware_concurrency();
|
||||
}
|
||||
}
|
||||
} else if (arg == "-p" || arg == "--prompt") {
|
||||
if (++i >= argc) {
|
||||
|
|
|
@ -261,7 +261,9 @@ void ggml_mpi_graph_compute_pre(
|
|||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(inp0 == gf->nodes[0]);
|
||||
// fprintf(stderr, "gf->nodes[0] == %s\n", ggml_get_name(gf->nodes[0]));
|
||||
//
|
||||
// GGML_ASSERT(inp0 == gf->nodes[0]);
|
||||
|
||||
// distribute the compute graph into slices across the MPI nodes
|
||||
//
|
||||
|
@ -333,7 +335,6 @@ void ggml_mpi_graph_compute_pre(
|
|||
// TODO: instead of rearranging the nodes, we should be able to execute a subset of the compute graph
|
||||
for (int i = 1; i < idx_l1 - idx_l0; i++) {
|
||||
gf->nodes[i] = gf->nodes[idx_l0 + i];
|
||||
gf->grads[i] = gf->grads[idx_l0 + i];
|
||||
}
|
||||
|
||||
// the first node performs the "get_rows" operation, the rest of the nodes get the data from the previous node
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue