threadpool: reduce the number of barriers required

New work is now indicated with an atomic counter that is incremented for
each new graph that needs to be computed.
This removes the extra barrier that was needed to clear the "new_work" flag and
removes the special case for trivial graphs.
Author: Max Krasnyansky, 2024-08-12 19:04:01 -07:00 (committed by fmz)
parent b630acdb73
commit 9d3e78c6b8
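
For illustration, here is a minimal, self-contained sketch of the counter-based signalling pattern the diff below adopts. This is a hypothetical, simplified example, not the actual ggml code; it only borrows the names n_graph, last_graph and pending from the diff, and the sleep-based pacing in main() is an assumption that stands in for the graph-ending barrier of the real threadpool.

// Minimal sketch (hypothetical, standalone; not the actual ggml code).
// Build: cc -pthread new_work_counter.c
//
// The dispatcher bumps an atomic "generation" counter for each new graph;
// each worker remembers the last generation it handled, so nothing shared
// has to be cleared afterwards and no extra barrier is needed for that.
#include <pthread.h>
#include <stdatomic.h>
#include <stdbool.h>
#include <stdio.h>
#include <time.h>

static atomic_int      n_graph;                          // one increment per dispatched graph
static pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER;
static pthread_cond_t  cond  = PTHREAD_COND_INITIALIZER;
static volatile bool   stop;                             // shutdown flag, written under the mutex

static void * worker(void * arg) {
    (void) arg;
    int last_graph = 0;                                   // per-thread, mirrors state->last_graph
    for (;;) {
        pthread_mutex_lock(&mutex);
        while (!stop && atomic_load_explicit(&n_graph, memory_order_relaxed) == last_graph) {
            pthread_cond_wait(&cond, &mutex);             // sleep until new work or shutdown
        }
        bool should_stop = stop;                          // read the shared flag under the mutex
        pthread_mutex_unlock(&mutex);
        if (should_stop) {
            break;
        }
        // Work is pending: record the generation so it is not redone, then "compute".
        last_graph = atomic_load_explicit(&n_graph, memory_order_relaxed);
        printf("worker: computing graph generation %d\n", last_graph);
    }
    return NULL;
}

int main(void) {
    pthread_t th;
    pthread_create(&th, NULL, worker, NULL);

    // Dispatch three graphs. In the real threadpool the dispatching thread joins the
    // computation and a barrier ends each graph, so dispatches never overlap; a short
    // sleep stands in for that here.
    for (int i = 0; i < 3; i++) {
        pthread_mutex_lock(&mutex);
        atomic_fetch_add_explicit(&n_graph, 1, memory_order_relaxed);
        pthread_cond_broadcast(&cond);                    // wake any sleeping workers
        pthread_mutex_unlock(&mutex);

        struct timespec ts = { 0, 10 * 1000 * 1000 };     // 10 ms
        nanosleep(&ts, NULL);
    }

    pthread_mutex_lock(&mutex);
    stop = true;
    pthread_cond_broadcast(&cond);
    pthread_mutex_unlock(&mutex);
    pthread_join(th, NULL);
    return 0;
}

Because each worker only compares the shared counter against its own last_graph copy, there is nothing shared to reset after a graph completes, which is what lets the trailing "clear new_work" barrier in the old code go away.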


@@ -1959,13 +1959,13 @@ struct ggml_compute_threadpool {
     struct ggml_cplan * cplan;

     // synchronization primitives
+    atomic_int n_graph;       // incremented when there is work to be done (i.e each graph)
     atomic_int n_barrier;
     atomic_int n_barrier_passed;
     atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.

     volatile bool stop;       // Used for stopping the threadpool altogether
     volatile bool pause;      // Used for pausing the threadpool or individual threads
-    volatile bool new_work;   // Set when there is work to be done, unset after it's done

     struct ggml_compute_state * workers; // per thread state
     int32_t n_threads_max;    // number of threads in the pool
@@ -1987,6 +1987,8 @@ struct ggml_compute_state {
     ggml_thread_t thrd;
     bool cpumask[GGML_MAX_N_THREADS];
     bool mask_specified;
+    int  last_graph;
+    bool pending;
 #endif
     struct ggml_compute_threadpool * threadpool;
     int ith;
@@ -19118,55 +19120,39 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
         /*.threadpool=*/ state->threadpool,
     };

-    struct ggml_tensor * node = cgraph->nodes[0];
-    ggml_compute_forward(&params, node);
-    if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
-        state->threadpool->ec = GGML_STATUS_ABORTED;
-    }
-
-    for (int node_n = 1; node_n < cgraph->n_nodes; node_n++) {
-        ggml_barrier(state->threadpool);
-        if (state->threadpool->ec != GGML_STATUS_SUCCESS) {
-            break;
-        }
-        node = cgraph->nodes[node_n];
-        ggml_compute_forward(&params, node);
-        if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
-            state->threadpool->ec = GGML_STATUS_ABORTED;
-        }
-    }
-
-    if (cgraph->n_nodes == 1) {
-        // We need a barrier before disabling new_work in case we have a trivial graph
-        ggml_barrier(state->threadpool);
-    }
-    if (!state->threadpool->disposable && state->ith == 0) {
-        // Don't need a lock, because there is a barrier after this, and only after that
-        // do the secondary threads go into standby
-        state->threadpool->new_work = false;
-    }
-    ggml_barrier(state->threadpool);
+    for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
+        struct ggml_tensor * node = cgraph->nodes[node_n];
+
+        ggml_compute_forward(&params, node);
+
+        if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
+            state->threadpool->ec = GGML_STATUS_ABORTED;
+        }
+
+        ggml_barrier(state->threadpool);
+
+        if (state->threadpool->ec != GGML_STATUS_SUCCESS) {
+            break;
+        }
+    }

     return 0;
 }

 #ifndef GGML_USE_OPENMP

-static inline bool ggml_graph_compute_got_work(struct ggml_compute_state *state) {
-    struct ggml_compute_threadpool * threadpool = state->threadpool;
-
-    return (threadpool->new_work && state->ith < threadpool->n_threads_cur);
-}
-
 static inline bool ggml_graph_compute_ready(struct ggml_compute_state * state) {
     struct ggml_compute_threadpool * threadpool = state->threadpool;

-    if (threadpool->stop || threadpool->pause) return true;
-
-    return ggml_graph_compute_got_work(state);
+    if (threadpool->stop || threadpool->pause || state->pending) { return true; }
+
+    // check for new graph/work
+    int new_graph = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed);
+    if (new_graph != state->last_graph) {
+        state->pending = (state->ith < threadpool->n_threads_cur);
+        state->last_graph = new_graph;
+    }
+
+    return state->pending;
 }

 static inline bool ggml_graph_compute_poll_for_work(struct ggml_compute_state * state) {
@@ -19181,14 +19167,14 @@ static inline bool ggml_graph_compute_poll_for_work(struct ggml_compute_state * state) {
         __cpu_relax();
     }

-    return ggml_graph_compute_got_work(state);
+    return state->pending;
 }

-static bool ggml_graph_compute_check_for_work(struct ggml_compute_state * state) {
+static inline bool ggml_graph_compute_check_for_work(struct ggml_compute_state * state) {
     struct ggml_compute_threadpool * threadpool = state->threadpool;

     if (ggml_graph_compute_poll_for_work(state)) {
-        return ggml_graph_compute_got_work(state);
+        return state->pending;
     }

     ggml_mutex_lock_shared(&threadpool->mutex);
@@ -19199,7 +19185,7 @@ static bool ggml_graph_compute_check_for_work(struct ggml_compute_state * state)
     }
     ggml_mutex_unlock_shared(&threadpool->mutex);

-    return ggml_graph_compute_got_work(state);
+    return state->pending;
 }

 static thread_ret_t ggml_graph_compute_secondary_thread(void* data) {
@@ -19229,8 +19215,10 @@ static thread_ret_t ggml_graph_compute_secondary_thread(void* data) {
        // Check if there is new work
        // The main thread is the only one that can dispatch new work

-        bool new_work = ggml_graph_compute_check_for_work(state);
-        if (new_work) {
+        ggml_graph_compute_check_for_work(state);
+        if (state->pending) {
+            state->pending = false;
+
             int64_t ret = (int64_t) ggml_graph_compute_thread(state);
             if (ret == GGML_EXIT_ABORTED)
                 return (thread_ret_t) ret;
@@ -19271,12 +19259,12 @@ static struct ggml_compute_threadpool * ggml_create_threadpool_impl(
    {
        threadpool->cgraph = cgraph;
        threadpool->cplan = cplan;
+        threadpool->n_graph = 0;
        threadpool->n_barrier = 0;
        threadpool->n_barrier_passed = 0;
        threadpool->current_chunk = 0;
        threadpool->stop = false;
        threadpool->pause = disposable ? false : tpp->paused;
-        threadpool->new_work = false;
        threadpool->workers = NULL;
        threadpool->n_threads_max = tpp->n_threads;
        threadpool->n_threads_cur = disposable ? tpp->n_threads : 0;
@@ -19319,7 +19307,9 @@ static struct ggml_compute_threadpool * ggml_create_threadpool_impl(
            .thrd           = 0,
            .mask_specified = tpp->mask_specified,
            .threadpool     = threadpool,
-            .ith            = j
+            .ith            = j,
+            .last_graph     = 0,
+            .pending        = false
        };

        if (tpp->mask_specified) {
@@ -19422,12 +19412,12 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
        // always take the mutex here because the worker threads are doing hybrid poll/wait
        ggml_mutex_lock(&threadpool->mutex);

-        threadpool->new_work = true;
-        if (!threadpool->pause) {
-            ggml_cond_broadcast(&threadpool->cond);
-        } else {
+        atomic_fetch_add_explicit(&threadpool->n_graph, 1, memory_order_relaxed);
+        if (threadpool->pause) {
            // resume does cond broadcast
            __ggml_resume_threadpool(threadpool);
+        } else {
+            ggml_cond_broadcast(&threadpool->cond);
        }

        ggml_mutex_unlock(&threadpool->mutex);
    }