update shared state n_threads in parallel region

This commit is contained in:
slaren 2024-05-30 09:47:29 +02:00
parent 7918ed7f2c
commit fa864af945

28
ggml.c
View file

@ -1751,7 +1751,7 @@ struct ggml_compute_state_shared {
int64_t perf_node_start_cycles;
int64_t perf_node_start_time_us;
const int n_threads;
int n_threads;
// synchronization primitives
atomic_int n_active; // num active threads
@ -19486,12 +19486,6 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
if (n_threads <= 0) {
n_threads = GGML_DEFAULT_N_THREADS;
}
#if defined(GGML_USE_OPENMP)
// Limit the number of threads used to avoid deadlock
// ref: https://github.com/ggerganov/llama.cpp/pull/7606
n_threads = MIN(n_threads, omp_get_max_threads());
n_threads = MIN(n_threads, omp_get_thread_limit());
#endif
size_t work_size = 0;
@ -19676,10 +19670,21 @@ static enum ggml_status ggml_graph_compute_parallel(struct ggml_compute_state *
enum ggml_status compute_status = GGML_STATUS_SUCCESS;
#ifdef GGML_USE_OPENMP
#pragma omp parallel num_threads(n_threads)
if (n_threads > 1) {
#pragma omp parallel num_threads(n_threads)
{
#pragma omp single
{
// update the number of threads from the actual number of threads that we got from OpenMP
n_threads = omp_get_num_threads();
workers[0].shared->n_threads = n_threads;
workers[0].shared->n_active = n_threads;
}
ggml_graph_compute_thread(&workers[omp_get_thread_num()]);
}
} else {
ggml_graph_compute_thread(&workers[0]);
}
#else
// create thread pool
if (n_threads > 1) {
@ -19724,7 +19729,12 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
}
}
const int n_threads = cplan->n_threads;
int n_threads = cplan->n_threads;
#if defined(GGML_USE_OPENMP)
n_threads = MIN(n_threads, omp_get_max_threads());
n_threads = MIN(n_threads, omp_get_thread_limit());
#endif
struct ggml_compute_state_shared state_shared = {
/*.cgraph =*/ cgraph,