Fix incorrect sched hash size, refactor new cmdline params to align with new style
parent cc551dfdfe
commit be63161d04
2 changed files with 25 additions and 15 deletions
@@ -184,6 +184,7 @@ static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int
                 params.n_threads[node] = std::thread::hardware_concurrency();
             }
         }
+        return true;
 
     }
     if (arg == "-tb" || arg == "--threads-batch") {
@@ -204,6 +205,7 @@ static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int
                 params.n_threads_batch[node] = std::thread::hardware_concurrency();
             }
         }
+        return true;
     }
     if (arg == "-td" || arg == "--threads-draft") {
         if (++i >= argc) {
@@ -223,6 +225,7 @@ static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int
                 params.n_threads_draft[node] = std::thread::hardware_concurrency();
             }
         }
+        return true;
     }
     if (arg == "-tbd" || arg == "--threads-batch-draft") {
         if (++i >= argc) {
@@ -242,6 +245,7 @@ static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int
                 params.n_threads_batch_draft[node] = std::thread::hardware_concurrency();
             }
         }
+        return true;
     }
     if (arg == "-p" || arg == "--prompt") {
         if (++i >= argc) {
@@ -924,6 +928,7 @@ static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int
         for (size_t node = 0; node < split_arg.size(); ++node) {
             params.mpi_layer_split[node] = std::stof(split_arg[node]);
         }
+        return true;
     }
 
     if (arg == "--tensor-split" || arg == "-ts") {
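
The parameter hunks above all follow the same convention: once a flag is recognized, its value is parsed for every node, a non-positive value falls back to std::thread::hardware_concurrency(), and the handler returns true immediately so no later flag checks run. Below is a minimal standalone sketch of that pattern; toy_params, toy_find_arg, the fixed node count, and the single -t flag are illustrative stand-ins, not the project's real gpt_params or option set.

```cpp
// Sketch of the "handle the flag, then return true" style the refactor aligns with.
// toy_params and the flag set are simplified stand-ins for the real gpt_params.
#include <cstdlib>
#include <string>
#include <thread>
#include <vector>

struct toy_params {
    std::vector<int> n_threads = std::vector<int>(4, -1);   // one entry per node (4 nodes assumed)
};

// Returns true if argv[i] was a recognized flag, even when its value is missing or invalid.
static bool toy_find_arg(int argc, char ** argv, toy_params & params, int & i, bool & invalid_param) {
    std::string arg = argv[i];
    if (arg == "-t" || arg == "--threads") {
        if (++i >= argc) {
            invalid_param = true;
            return true;                        // flag recognized, value missing
        }
        for (size_t node = 0; node < params.n_threads.size(); ++node) {
            params.n_threads[node] = std::atoi(argv[i]);
            if (params.n_threads[node] <= 0) {
                // fall back to the hardware thread count, as in the diff above
                params.n_threads[node] = (int) std::thread::hardware_concurrency();
            }
        }
        return true;                            // "new style": return as soon as the flag is handled
    }
    return false;                               // not a flag this sketch knows about
}

int main(int argc, char ** argv) {
    toy_params params;
    bool invalid_param = false;
    for (int i = 1; i < argc; i++) {
        if (!toy_find_arg(argc, argv, params, i, invalid_param) || invalid_param) {
            return 1;                           // unknown flag or missing value
        }
    }
    return 0;
}
```
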
@@ -513,6 +513,8 @@ GGML_CALL static enum ggml_status ggml_backend_mpi_graph_compute(ggml_backend_t
         }
     }
 
+    size_t n_srcs = 0;
+
     for (int i = 0; i < cgraph->n_nodes; i++) {
         if (cgraph->nodes[i]->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) {
             cgraph->nodes[i]->buffer = ggml_backend_mpi_buffer_unwrap(cgraph->nodes[i]->buffer);
@@ -523,6 +525,7 @@ GGML_CALL static enum ggml_status ggml_backend_mpi_graph_compute(ggml_backend_t
                 break;
             }
             if (src->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) {
+                n_srcs++;
                 src->buffer = ggml_backend_mpi_buffer_unwrap(src->buffer);
             }
         }
@@ -532,6 +535,7 @@ GGML_CALL static enum ggml_status ggml_backend_mpi_graph_compute(ggml_backend_t
         if (src->buffer->buft != nullptr) {
 
             if (src->buffer->buft->iface.get_name == ggml_backend_mpi_buffer_type_name) {
+                n_srcs++;
                 src->buffer = ggml_backend_mpi_buffer_unwrap(src->buffer);
             }
         }
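
Before the scheduler is created, the hunks above walk every node in the graph, unwrap any buffer still wrapped by the MPI buffer type, and count in n_srcs each source tensor that gets unwrapped this way. The sketch below shows only that counting pattern with toy stand-in types (toy_tensor, toy_buffer, toy_unwrap are invented for illustration); the real code operates on ggml tensors and buffers and uses ggml_backend_mpi_buffer_unwrap() as shown in the diff.

```cpp
// Sketch of the unwrap-and-count pass with toy stand-ins for ggml's types.
// The real code checks buffer->buft->iface.get_name against
// ggml_backend_mpi_buffer_type_name and calls ggml_backend_mpi_buffer_unwrap().
#include <array>
#include <cstddef>
#include <vector>

struct toy_buffer {
    bool mpi_wrapped = false;
    toy_buffer * inner = nullptr;                // buffer hidden behind the MPI wrapper
};

struct toy_tensor {
    toy_buffer * buffer = nullptr;
    std::array<toy_tensor *, 10> src{};          // null-terminated source list, like a tensor's src array
};

static toy_buffer * toy_unwrap(toy_buffer * buf) {
    return buf->inner != nullptr ? buf->inner : buf;
}

// Unwraps MPI-wrapped buffers on nodes and their sources; returns how many
// source tensors were unwrapped, which later pads the scheduler's graph size.
static std::size_t unwrap_graph_buffers(std::vector<toy_tensor *> & nodes) {
    std::size_t n_srcs = 0;
    for (toy_tensor * node : nodes) {
        if (node->buffer->mpi_wrapped) {
            node->buffer = toy_unwrap(node->buffer);
        }
        for (toy_tensor * src : node->src) {
            if (src == nullptr) {
                break;                           // end of this node's sources
            }
            if (src->buffer->mpi_wrapped) {
                n_srcs++;                        // one more tensor the scheduler will have to hash
                src->buffer = toy_unwrap(src->buffer);
            }
        }
    }
    return n_srcs;
}

int main() {
    toy_buffer inner;                            // plain backing buffer
    toy_buffer wrapped;  wrapped.mpi_wrapped = true;  wrapped.inner = &inner;
    toy_buffer plain;

    toy_tensor src_a;  src_a.buffer = &wrapped;  // will be unwrapped and counted
    toy_tensor src_b;  src_b.buffer = &plain;    // left alone
    toy_tensor node;   node.buffer  = &plain;
    node.src[0] = &src_a;
    node.src[1] = &src_b;

    std::vector<toy_tensor *> nodes = { &node };
    return unwrap_graph_buffers(nodes) == 1 ? 0 : 1;   // expect exactly one counted source
}
```
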
@@ -546,12 +550,13 @@ GGML_CALL static enum ggml_status ggml_backend_mpi_graph_compute(ggml_backend_t
     }
 
 
 
     if (!ctx->remote) {
         ggml_backend_sched_t sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(),
-                                                            (int) ctx->backends.size(), cgraph->n_nodes, false);
+                                                            (int) ctx->backends.size(), cgraph->n_nodes + cgraph->n_leafs + n_srcs, false);
+
         ggml_backend_sched_reserve(sched, cgraph);
-        ggml_backend_sched_graph_compute(sched, cgraph);
+        ggml_backend_sched_graph_compute_async(sched, cgraph);
         ggml_backend_sched_free(sched);
 
     }
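
The last hunk is the "incorrect sched hash size" part of the commit: ggml_backend_sched_new() takes a graph_size argument that dictates how many tensors its internal hash table can track, and sizing it to cgraph->n_nodes alone ignores the graph's leafs and the extra source tensors counted in n_srcs above. The sketch below is only the capacity arithmetic; sched_graph_capacity() and the example numbers are illustrative, while the real call is the ggml_backend_sched_new(...) line shown in the diff.

```cpp
// Sketch of the sizing fix: every tensor the scheduler may have to hash needs a
// slot, so the capacity covers nodes, leafs, and the extra unwrapped sources.
// sched_graph_capacity() is an illustrative helper, not a ggml API.
#include <cassert>
#include <cstddef>

static std::size_t sched_graph_capacity(std::size_t n_nodes, std::size_t n_leafs, std::size_t n_srcs) {
    return n_nodes + n_leafs + n_srcs;
}

int main() {
    // e.g. a graph with 512 compute nodes, 64 leafs (weights/inputs), 32 unwrapped sources
    assert(sched_graph_capacity(512, 64, 32) == 608);   // the old size would have been just 512
    return 0;
}
```

As for the compute call: in upstream ggml, ggml_backend_sched_graph_compute() is the synchronous wrapper around the async variant, so switching to ggml_backend_sched_graph_compute_async() performs the same work but leaves synchronization to the caller; the commit itself does not state the rationale for the change.
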