Fix incorrect sched hash size, refactor new cmdline params to align with new style

This commit is contained in:
Branden Butler 2024-03-19 11:02:18 -05:00
parent cc551dfdfe
commit be63161d04
2 changed files with 25 additions and 15 deletions

View file

@ -184,6 +184,7 @@ static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int
params.n_threads[node] = std::thread::hardware_concurrency(); params.n_threads[node] = std::thread::hardware_concurrency();
} }
} }
return true;
} }
if (arg == "-tb" || arg == "--threads-batch") { if (arg == "-tb" || arg == "--threads-batch") {
@ -204,6 +205,7 @@ static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int
params.n_threads_batch[node] = std::thread::hardware_concurrency(); params.n_threads_batch[node] = std::thread::hardware_concurrency();
} }
} }
return true;
} }
if (arg == "-td" || arg == "--threads-draft") { if (arg == "-td" || arg == "--threads-draft") {
if (++i >= argc) { if (++i >= argc) {
@ -223,6 +225,7 @@ static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int
params.n_threads_draft[node] = std::thread::hardware_concurrency(); params.n_threads_draft[node] = std::thread::hardware_concurrency();
} }
} }
return true;
} }
if (arg == "-tbd" || arg == "--threads-batch-draft") { if (arg == "-tbd" || arg == "--threads-batch-draft") {
if (++i >= argc) { if (++i >= argc) {
@ -242,6 +245,7 @@ static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int
params.n_threads_batch_draft[node] = std::thread::hardware_concurrency(); params.n_threads_batch_draft[node] = std::thread::hardware_concurrency();
} }
} }
return true;
} }
if (arg == "-p" || arg == "--prompt") { if (arg == "-p" || arg == "--prompt") {
if (++i >= argc) { if (++i >= argc) {
@ -910,20 +914,21 @@ static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int
return true; return true;
} }
if (arg == "--mpi-layer-split") { if (arg == "--mpi-layer-split") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
return true; return true;
} }
std::string arg_next = argv[i]; std::string arg_next = argv[i];
// split string by , and / // split string by , and /
const std::regex regex{R"([,/]+)"}; const std::regex regex{R"([,/]+)"};
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1}; std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
std::vector<std::string> split_arg{it, {}}; std::vector<std::string> split_arg{it, {}};
params.mpi_layer_split.resize(split_arg.size()); params.mpi_layer_split.resize(split_arg.size());
for (size_t node = 0; node < split_arg.size(); ++node) { for (size_t node = 0; node < split_arg.size(); ++node) {
params.mpi_layer_split[node] = std::stof(split_arg[node]); params.mpi_layer_split[node] = std::stof(split_arg[node]);
} }
return true;
} }
if (arg == "--tensor-split" || arg == "-ts") { if (arg == "--tensor-split" || arg == "-ts") {

View file

@ -513,6 +513,8 @@ GGML_CALL static enum ggml_status ggml_backend_mpi_graph_compute(ggml_backend_t
} }
} }
size_t n_srcs = 0;
for (int i = 0; i < cgraph->n_nodes; i++) { for (int i = 0; i < cgraph->n_nodes; i++) {
if (cgraph->nodes[i]->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) { if (cgraph->nodes[i]->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) {
cgraph->nodes[i]->buffer = ggml_backend_mpi_buffer_unwrap(cgraph->nodes[i]->buffer); cgraph->nodes[i]->buffer = ggml_backend_mpi_buffer_unwrap(cgraph->nodes[i]->buffer);
@ -523,6 +525,7 @@ GGML_CALL static enum ggml_status ggml_backend_mpi_graph_compute(ggml_backend_t
break; break;
} }
if (src->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) { if (src->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) {
n_srcs++;
src->buffer = ggml_backend_mpi_buffer_unwrap(src->buffer); src->buffer = ggml_backend_mpi_buffer_unwrap(src->buffer);
} }
} }
@ -532,6 +535,7 @@ GGML_CALL static enum ggml_status ggml_backend_mpi_graph_compute(ggml_backend_t
if (src->buffer->buft != nullptr) { if (src->buffer->buft != nullptr) {
if (src->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) { if (src->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) {
n_srcs++;
src->buffer = ggml_backend_mpi_buffer_unwrap(src->buffer); src->buffer = ggml_backend_mpi_buffer_unwrap(src->buffer);
} }
} }
@ -546,12 +550,13 @@ GGML_CALL static enum ggml_status ggml_backend_mpi_graph_compute(ggml_backend_t
} }
if (!ctx->remote) { if (!ctx->remote) {
ggml_backend_sched_t sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ggml_backend_sched_t sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(),
(int) ctx->backends.size(), cgraph->n_nodes, false); (int) ctx->backends.size(), cgraph->n_nodes + cgraph->n_leafs + n_srcs, false);
ggml_backend_sched_reserve(sched, cgraph); ggml_backend_sched_reserve(sched, cgraph);
ggml_backend_sched_graph_compute(sched, cgraph); ggml_backend_sched_graph_compute_async(sched, cgraph);
ggml_backend_sched_free(sched); ggml_backend_sched_free(sched);
} }