From 9603f7f5bfa017bfa56b0fb4090d0caa4a5bbdc9 Mon Sep 17 00:00:00 2001 From: mqy Date: Fri, 7 Apr 2023 05:18:37 +0800 Subject: [PATCH] ggml: refactor compute thread: merge three spin variables into one --- Makefile | 2 +- ggml.c | 68 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index cb14ffdbc..711eea8da 100644 --- a/Makefile +++ b/Makefile @@ -31,7 +31,7 @@ endif # # keep standard at C11 and C++11 -CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC +CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC #-DDISABLE_GGML_COMPUTE_SPIN_V2 CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC LDFLAGS = diff --git a/ggml.c b/ggml.c index 8a60bc383..6b44ded9b 100644 --- a/ggml.c +++ b/ggml.c @@ -9239,15 +9239,28 @@ typedef pthread_t ggml_thread_t; #endif +// To rollback quickly, set `-DDISABLE_GGML_COMPUTE_SPIN_V2` to `CFLAGS` in Makefile. +// TODO(mqy): cleanup feature flag DISABLE_GGML_COMPUTE_SPIN_V2. + struct ggml_compute_state_shared { +#ifdef DISABLE_GGML_COMPUTE_SPIN_V2 ggml_lock_t spin; +#endif int n_threads; // synchronization primitives +#ifdef DISABLE_GGML_COMPUTE_SPIN_V2 atomic_int n_ready; atomic_bool has_work; atomic_bool stop; // stop all threads +#else + // The `flag` works as a work counter + stop indicator. + // > 0: main thread stores the initial value, every worker decreases it by 1. + // = 0: all done. + // < 0: stop now. 
+ atomic_int flag; +#endif }; struct ggml_compute_state { @@ -9262,9 +9275,23 @@ struct ggml_compute_state { static thread_ret_t ggml_graph_compute_thread(void * data) { struct ggml_compute_state * state = (struct ggml_compute_state *) data; +#ifdef DISABLE_GGML_COMPUTE_SPIN_V2 const int n_threads = state->shared->n_threads; +#endif while (true) { +#ifndef DISABLE_GGML_COMPUTE_SPIN_V2 + int flag = atomic_load(&state->shared->flag); + if (flag < 0) return NULL; // stop + if (flag > 0) { // pending works + if (state->node) { // my work + GGML_ASSERT (state->params.ith < state->params.nth); + ggml_compute_forward(&state->params, state->node); + state->node = NULL; + atomic_fetch_sub(&state->shared->flag, 1); // done + } + } +#else if (atomic_fetch_add(&state->shared->n_ready, 1) == n_threads - 1) { atomic_store(&state->shared->has_work, false); } else { @@ -9302,6 +9329,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { } else { break; } +#endif } return 0; @@ -9311,19 +9339,27 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) const int n_threads = cgraph->n_threads; struct ggml_compute_state_shared state_shared = { +#ifdef DISABLE_GGML_COMPUTE_SPIN_V2 /*.spin =*/ GGML_LOCK_INITIALIZER, +#endif /*.n_threads =*/ n_threads, +#ifndef DISABLE_GGML_COMPUTE_SPIN_V2 + /*.flag =*/ 0, +#else /*.n_ready =*/ 0, /*.has_work =*/ false, /*.stop =*/ false, +#endif }; struct ggml_compute_state * workers = n_threads > 1 ? 
alloca(sizeof(struct ggml_compute_state)*(n_threads - 1)) : NULL; // create thread pool if (n_threads > 1) { +#ifdef DISABLE_GGML_COMPUTE_SPIN_V2 ggml_lock_init(&state_shared.spin); atomic_store(&state_shared.has_work, true); +#endif for (int j = 0; j < n_threads - 1; j++) { workers[j] = (struct ggml_compute_state) { @@ -9579,6 +9615,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) // COMPUTE if (node->n_tasks > 1) { +#ifdef DISABLE_GGML_COMPUTE_SPIN_V2 if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { atomic_store(&state_shared.has_work, false); } @@ -9587,9 +9624,14 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) ggml_lock_lock (&state_shared.spin); ggml_lock_unlock(&state_shared.spin); } +#endif // launch thread pool +#ifndef DISABLE_GGML_COMPUTE_SPIN_V2 + for (int j = 0; j < node->n_tasks - 1; j++) { +#else for (int j = 0; j < n_threads - 1; j++) { +#endif workers[j].params = (struct ggml_compute_params) { .type = GGML_TASK_COMPUTE, .ith = j + 1, @@ -9600,6 +9642,9 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) workers[j].node = node; } +#ifndef DISABLE_GGML_COMPUTE_SPIN_V2 + atomic_store(&state_shared.flag, node->n_tasks-1); +#else atomic_fetch_sub(&state_shared.n_ready, 1); while (atomic_load(&state_shared.n_ready) > 0) { @@ -9608,6 +9653,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) } atomic_store(&state_shared.has_work, true); +#endif } params.type = GGML_TASK_COMPUTE; @@ -9615,6 +9661,9 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) // wait for thread pool if (node->n_tasks > 1) { +#ifndef DISABLE_GGML_COMPUTE_SPIN_V2 + while (atomic_load(&state_shared.flag) != 0) {} +#else if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { atomic_store(&state_shared.has_work, false); } @@ -9630,10 +9679,18 @@ void ggml_graph_compute(struct ggml_context * ctx, 
struct ggml_cgraph * cgraph) ggml_lock_lock (&state_shared.spin); ggml_lock_unlock(&state_shared.spin); } +#endif } // FINALIZE if (node->n_tasks > 1) { +#ifndef DISABLE_GGML_COMPUTE_SPIN_V2 + for (int j = 0; j < node->n_tasks-1; j++) { + workers[j].params.type = GGML_TASK_FINALIZE; + workers[j].node = node; + } + atomic_store(&state_shared.flag, node->n_tasks-1); +#else if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { atomic_store(&state_shared.has_work, false); } @@ -9663,6 +9720,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) } atomic_store(&state_shared.has_work, true); +#endif } params.type = GGML_TASK_FINALIZE; @@ -9670,6 +9728,9 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) // wait for thread pool if (node->n_tasks > 1) { +#ifndef DISABLE_GGML_COMPUTE_SPIN_V2 + while (atomic_load(&state_shared.flag) != 0) {} +#else if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { atomic_store(&state_shared.has_work, false); } @@ -9685,6 +9746,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) ggml_lock_lock (&state_shared.spin); ggml_lock_unlock(&state_shared.spin); } +#endif } // performance stats (node) @@ -9700,8 +9762,12 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) // join thread pool if (n_threads > 1) { +#ifndef DISABLE_GGML_COMPUTE_SPIN_V2 + atomic_store(&state_shared.flag, -1); +#else atomic_store(&state_shared.stop, true); atomic_store(&state_shared.has_work, true); +#endif for (int j = 0; j < n_threads - 1; j++) { int rc = ggml_thread_join(workers[j].thrd, NULL); @@ -9709,7 +9775,9 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) UNUSED(rc); } +#ifdef DISABLE_GGML_COMPUTE_SPIN_V2 ggml_lock_destroy(&state_shared.spin); +#endif } // performance stats (graph)