diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 9bea4e0af..39b9b27dc 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1959,13 +1959,13 @@ struct ggml_compute_threadpool {
     struct ggml_cplan  * cplan;
 
     // synchronization primitives
+    atomic_int n_graph;       // incremented when there is work to be done (i.e. each graph)
     atomic_int n_barrier;
     atomic_int n_barrier_passed;
     atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
 
     volatile bool stop;       // Used for stopping the threadpool altogether
     volatile bool pause;      // Used for pausing the threadpool or individual threads
-    volatile bool new_work;   // Set when there is work to be done, unset after it's done
 
     struct ggml_compute_state * workers;   // per thread state
     int32_t n_threads_max;    // number of threads in the pool
@@ -1987,6 +1987,8 @@ struct ggml_compute_state {
     ggml_thread_t thrd;
     bool cpumask[GGML_MAX_N_THREADS];
     bool mask_specified;
+    int  last_graph;
+    bool pending;
 #endif
     struct ggml_compute_threadpool * threadpool;
     int ith;
@@ -19118,55 +19120,39 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
         /*.threadpool=*/ state->threadpool,
     };
 
-    struct ggml_tensor * node = cgraph->nodes[0];
+    for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
+        struct ggml_tensor * node = cgraph->nodes[node_n];
 
-    ggml_compute_forward(&params, node);
-    if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
-        state->threadpool->ec = GGML_STATUS_ABORTED;
-    }
-
-    for (int node_n = 1; node_n < cgraph->n_nodes; node_n++) {
-        ggml_barrier(state->threadpool);
-
-        if (state->threadpool->ec != GGML_STATUS_SUCCESS) {
-            break;
-        }
-
-        node = cgraph->nodes[node_n];
         ggml_compute_forward(&params, node);
         if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
            state->threadpool->ec = GGML_STATUS_ABORTED;
         }
-    }
 
-    if (cgraph->n_nodes == 1) {
-        // We need a barrier before disabling new_work in case we have a trivial graph
         ggml_barrier(state->threadpool);
-    }
 
-    if (!state->threadpool->disposable && state->ith == 0) {
-        // Don't need a lock, because there is a barrier after this, and only after that
-        // do the secondary threads go into standby
-        state->threadpool->new_work = false;
+        if (state->threadpool->ec != GGML_STATUS_SUCCESS) {
+            break;
+        }
     }
-    ggml_barrier(state->threadpool);
-
 
     return 0;
 }
 
 #ifndef GGML_USE_OPENMP
 
-static inline bool ggml_graph_compute_got_work(struct ggml_compute_state *state) {
-    struct ggml_compute_threadpool * threadpool = state->threadpool;
-    return (threadpool->new_work && state->ith < threadpool->n_threads_cur);
-}
-
 static inline bool ggml_graph_compute_ready(struct ggml_compute_state * state) {
     struct ggml_compute_threadpool * threadpool = state->threadpool;
 
-    if (threadpool->stop || threadpool->pause) return true;
-    return ggml_graph_compute_got_work(state);
+    if (threadpool->stop || threadpool->pause || state->pending) { return true; }
+
+    // check for new graph/work
+    int new_graph = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed);
+    if (new_graph != state->last_graph) {
+        state->pending    = (state->ith < threadpool->n_threads_cur);
+        state->last_graph = new_graph;
+    }
+
+    return state->pending;
 }
 
 static inline bool ggml_graph_compute_poll_for_work(struct ggml_compute_state * state) {
@@ -19181,14 +19167,14 @@ static inline bool ggml_graph_compute_poll_for_work(struct ggml_compute_state * state) {
         __cpu_relax();
     }
 
-    return ggml_graph_compute_got_work(state);
+    return state->pending;
 }
 
-static bool ggml_graph_compute_check_for_work(struct ggml_compute_state * state) {
+static inline bool ggml_graph_compute_check_for_work(struct ggml_compute_state * state) {
     struct ggml_compute_threadpool * threadpool = state->threadpool;
 
     if (ggml_graph_compute_poll_for_work(state)) {
-        return ggml_graph_compute_got_work(state);
+        return state->pending;
     }
 
     ggml_mutex_lock_shared(&threadpool->mutex);
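The hunks above replace the level-triggered `new_work` flag with an edge-triggered generation counter: the dispatcher bumps `n_graph` once per graph, and each worker compares the counter against its private `last_graph` snapshot, latching the result into its thread-local `pending` field (see `ggml_graph_compute_ready`). Since clearing `pending` never touches shared state, the race on who resets `new_work` disappears. A minimal standalone sketch of that handshake follows; the `*_demo` names are hypothetical illustrations, not ggml API:

```c
#include <stdatomic.h>
#include <stdbool.h>

struct pool_demo {
    atomic_int n_graph;   // bumped by the dispatcher once per graph
    int        n_active;  // number of workers the current graph uses
};

struct worker_demo {
    struct pool_demo * pool;
    int  ith;         // this worker's index
    int  last_graph;  // local snapshot of pool->n_graph
    bool pending;     // latched "this worker owes work for the current graph"
};

// Mirrors the shape of ggml_graph_compute_ready(): returns true once this
// worker has observed a graph generation it has not yet executed.
static bool worker_ready_demo(struct worker_demo * w) {
    if (w->pending) {
        return true;
    }
    const int g = atomic_load_explicit(&w->pool->n_graph, memory_order_relaxed);
    if (g != w->last_graph) {
        // latch at most once per counter bump; only threads below the
        // current thread count participate in this graph
        w->pending    = (w->ith < w->pool->n_active);
        w->last_graph = g;
    }
    return w->pending;
}
```

Because the counter only grows, a worker can latch a given dispatch at most once and cannot miss one: a stale read while polling is corrected by the next read or by the condition-variable wake-up.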
@@ -19199,7 +19185,7 @@ static bool ggml_graph_compute_check_for_work(struct ggml_compute_state * state) {
     }
     ggml_mutex_unlock_shared(&threadpool->mutex);
 
-    return ggml_graph_compute_got_work(state);
+    return state->pending;
 }
 
 static thread_ret_t ggml_graph_compute_secondary_thread(void* data) {
@@ -19229,8 +19215,10 @@ static thread_ret_t ggml_graph_compute_secondary_thread(void* data) {
 
         // Check if there is new work
         // The main thread is the only one that can dispatch new work
-        bool new_work = ggml_graph_compute_check_for_work(state);
-        if (new_work) {
+        ggml_graph_compute_check_for_work(state);
+        if (state->pending) {
+            state->pending = false;
+
             int64_t ret = (int64_t) ggml_graph_compute_thread(state);
             if (ret == GGML_EXIT_ABORTED) return (thread_ret_t) ret;
 
@@ -19271,12 +19259,12 @@ static struct ggml_compute_threadpool * ggml_create_threadpool_impl(
 {
     threadpool->cgraph           = cgraph;
     threadpool->cplan            = cplan;
+    threadpool->n_graph          = 0;
     threadpool->n_barrier        = 0;
     threadpool->n_barrier_passed = 0;
     threadpool->current_chunk    = 0;
     threadpool->stop             = false;
     threadpool->pause            = disposable ? false : tpp->paused;
-    threadpool->new_work         = false;
     threadpool->workers          = NULL;
     threadpool->n_threads_max    = tpp->n_threads;
     threadpool->n_threads_cur    = disposable ? tpp->n_threads : 0;
@@ -19319,7 +19307,9 @@ static struct ggml_compute_threadpool * ggml_create_threadpool_impl(
             .thrd           = 0,
             .mask_specified = tpp->mask_specified,
             .threadpool     = threadpool,
-            .ith            = j
+            .ith            = j,
+            .last_graph     = 0,
+            .pending        = false
         };
 
         if (tpp->mask_specified) {
@@ -19422,12 +19412,12 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
 
         // always take the mutex here because the worker threads are doing hybrid poll/wait
         ggml_mutex_lock(&threadpool->mutex);
-        threadpool->new_work = true;
-        if (!threadpool->pause) {
-            ggml_cond_broadcast(&threadpool->cond);
-        } else {
+        atomic_fetch_add_explicit(&threadpool->n_graph, 1, memory_order_relaxed);
+        if (threadpool->pause) {
            // resume does cond broadcast
            __ggml_resume_threadpool(threadpool);
+        } else {
+            ggml_cond_broadcast(&threadpool->cond);
         }
         ggml_mutex_unlock(&threadpool->mutex);
     }
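On the dispatch side, the final hunk increments `n_graph` while holding the same mutex the sleeping workers wait on, then broadcasts (or resumes the pool via `__ggml_resume_threadpool`, which broadcasts internally). A relaxed increment suffices because the mutex/condvar hand-off supplies the ordering sleeping workers need, while polling workers only have to eventually notice the counter change. Below is a self-contained sketch of that publish/wait pairing, using plain pthreads in place of ggml's thread wrappers; the `tp_*_demo` names are hypothetical:

```c
#include <pthread.h>
#include <stdatomic.h>
#include <stdbool.h>

struct tp_demo {
    pthread_mutex_t mutex;
    pthread_cond_t  cond;
    atomic_int      n_graph;   // generation counter, bumped once per graph
};

struct tp_worker_demo {
    struct tp_demo * pool;
    int  last_graph;           // last generation this worker has executed
    bool pending;              // set when a new generation has been latched
};

// Dispatcher side: the analogue of the tail of ggml_graph_compute() above.
static void tp_dispatch_demo(struct tp_demo * p) {
    pthread_mutex_lock(&p->mutex);
    // relaxed is enough: the mutex/cond hand-off provides the ordering
    atomic_fetch_add_explicit(&p->n_graph, 1, memory_order_relaxed);
    pthread_cond_broadcast(&p->cond);
    pthread_mutex_unlock(&p->mutex);
}

// Worker side: the analogue of what ggml_graph_compute_check_for_work()
// falls back to once polling gives up; sleeps until the counter moves,
// then latches the work.
static void tp_wait_for_graph_demo(struct tp_worker_demo * w) {
    struct tp_demo * p = w->pool;
    pthread_mutex_lock(&p->mutex);
    while (atomic_load_explicit(&p->n_graph, memory_order_relaxed) == w->last_graph) {
        pthread_cond_wait(&p->cond, &p->mutex);
    }
    w->last_graph = atomic_load_explicit(&p->n_graph, memory_order_relaxed);
    w->pending    = true;   // caller runs the graph, then clears pending
    pthread_mutex_unlock(&p->mutex);
}
```

Broadcasting while the mutex is held is what closes the lost-wakeup window: a worker only examines `n_graph` with the mutex held, so the increment cannot slip between its check and the `pthread_cond_wait` that atomically releases the lock.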