update shared state n_threads in parallel region
This commit is contained in:
parent
7918ed7f2c
commit
fa864af945
1 changed file with 21 additions and 11 deletions
32
ggml.c
32
ggml.c
|
@ -1751,7 +1751,7 @@ struct ggml_compute_state_shared {
|
|||
int64_t perf_node_start_cycles;
|
||||
int64_t perf_node_start_time_us;
|
||||
|
||||
const int n_threads;
|
||||
int n_threads;
|
||||
|
||||
// synchronization primitives
|
||||
atomic_int n_active; // num active threads
|
||||
|
@ -19486,12 +19486,6 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
|
|||
if (n_threads <= 0) {
|
||||
n_threads = GGML_DEFAULT_N_THREADS;
|
||||
}
|
||||
#if defined(GGML_USE_OPENMP)
|
||||
// Limit the number of threads used to avoid deadlock
|
||||
// ref: https://github.com/ggerganov/llama.cpp/pull/7606
|
||||
n_threads = MIN(n_threads, omp_get_max_threads());
|
||||
n_threads = MIN(n_threads, omp_get_thread_limit());
|
||||
#endif
|
||||
|
||||
size_t work_size = 0;
|
||||
|
||||
|
@ -19676,9 +19670,20 @@ static enum ggml_status ggml_graph_compute_parallel(struct ggml_compute_state *
|
|||
enum ggml_status compute_status = GGML_STATUS_SUCCESS;
|
||||
|
||||
#ifdef GGML_USE_OPENMP
|
||||
#pragma omp parallel num_threads(n_threads)
|
||||
{
|
||||
ggml_graph_compute_thread(&workers[omp_get_thread_num()]);
|
||||
if (n_threads > 1) {
|
||||
#pragma omp parallel num_threads(n_threads)
|
||||
{
|
||||
#pragma omp single
|
||||
{
|
||||
// update the number of threads from the actual number of threads that we got from OpenMP
|
||||
n_threads = omp_get_num_threads();
|
||||
workers[0].shared->n_threads = n_threads;
|
||||
workers[0].shared->n_active = n_threads;
|
||||
}
|
||||
ggml_graph_compute_thread(&workers[omp_get_thread_num()]);
|
||||
}
|
||||
} else {
|
||||
ggml_graph_compute_thread(&workers[0]);
|
||||
}
|
||||
#else
|
||||
// create thread pool
|
||||
|
@ -19724,7 +19729,12 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
|
|||
}
|
||||
}
|
||||
|
||||
const int n_threads = cplan->n_threads;
|
||||
int n_threads = cplan->n_threads;
|
||||
|
||||
#if defined(GGML_USE_OPENMP)
|
||||
n_threads = MIN(n_threads, omp_get_max_threads());
|
||||
n_threads = MIN(n_threads, omp_get_thread_limit());
|
||||
#endif
|
||||
|
||||
struct ggml_compute_state_shared state_shared = {
|
||||
/*.cgraph =*/ cgraph,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue