Fix incorrect sched hash size, refactor new cmdline params to align with new style

This commit is contained in:
Branden Butler 2024-03-19 11:02:18 -05:00
parent cc551dfdfe
commit be63161d04
2 changed files with 25 additions and 15 deletions

View file

@ -184,6 +184,7 @@ static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int
params.n_threads[node] = std::thread::hardware_concurrency();
}
}
return true;
}
if (arg == "-tb" || arg == "--threads-batch") {
@ -204,6 +205,7 @@ static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int
params.n_threads_batch[node] = std::thread::hardware_concurrency();
}
}
return true;
}
if (arg == "-td" || arg == "--threads-draft") {
if (++i >= argc) {
@ -223,6 +225,7 @@ static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int
params.n_threads_draft[node] = std::thread::hardware_concurrency();
}
}
return true;
}
if (arg == "-tbd" || arg == "--threads-batch-draft") {
if (++i >= argc) {
@ -242,6 +245,7 @@ static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int
params.n_threads_batch_draft[node] = std::thread::hardware_concurrency();
}
}
return true;
}
if (arg == "-p" || arg == "--prompt") {
if (++i >= argc) {
@ -924,6 +928,7 @@ static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int
for (size_t node = 0; node < split_arg.size(); ++node) {
params.mpi_layer_split[node] = std::stof(split_arg[node]);
}
return true;
}
if (arg == "--tensor-split" || arg == "-ts") {

View file

@ -513,6 +513,8 @@ GGML_CALL static enum ggml_status ggml_backend_mpi_graph_compute(ggml_backend_t
}
}
size_t n_srcs = 0;
for (int i = 0; i < cgraph->n_nodes; i++) {
if (cgraph->nodes[i]->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) {
cgraph->nodes[i]->buffer = ggml_backend_mpi_buffer_unwrap(cgraph->nodes[i]->buffer);
@ -523,6 +525,7 @@ GGML_CALL static enum ggml_status ggml_backend_mpi_graph_compute(ggml_backend_t
break;
}
if (src->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) {
n_srcs++;
src->buffer = ggml_backend_mpi_buffer_unwrap(src->buffer);
}
}
@ -532,6 +535,7 @@ GGML_CALL static enum ggml_status ggml_backend_mpi_graph_compute(ggml_backend_t
if (src->buffer->buft != nullptr) {
if (src->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) {
n_srcs++;
src->buffer = ggml_backend_mpi_buffer_unwrap(src->buffer);
}
}
@ -546,12 +550,13 @@ GGML_CALL static enum ggml_status ggml_backend_mpi_graph_compute(ggml_backend_t
}
if (!ctx->remote) {
ggml_backend_sched_t sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(),
(int) ctx->backends.size(), cgraph->n_nodes, false);
(int) ctx->backends.size(), cgraph->n_nodes + cgraph->n_leafs + n_srcs, false);
ggml_backend_sched_reserve(sched, cgraph);
ggml_backend_sched_graph_compute(sched, cgraph);
ggml_backend_sched_graph_compute_async(sched, cgraph);
ggml_backend_sched_free(sched);
}