From cc8a375bc411c153ec0771faf852316aa5e42f83 Mon Sep 17 00:00:00 2001 From: mqy Date: Mon, 19 Jun 2023 16:17:48 +0800 Subject: [PATCH] threading: fix deadlock by reverting part of changes from commit 286c5b30 --- ggml-threading.c | 16 ++++++++++++---- tests/test-ggml-threading.c | 4 ++-- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/ggml-threading.c b/ggml-threading.c index 2a5cfa096..4c1bc0714 100644 --- a/ggml-threading.c +++ b/ggml-threading.c @@ -199,8 +199,6 @@ struct ggml_threading_context { int64_t *stages_time; }; -// NOTE: ggml_spin_lock and ggml_spin_unlock may can be noop if -// feature wait_on_done is off. static inline void ggml_spin_lock(volatile atomic_flag *obj) { while (atomic_flag_test_and_set(obj)) { ggml_spin_pause(); @@ -262,7 +260,10 @@ void ggml_threading_suspend(struct ggml_threading_context *ctx) { PRINT_DEBUG("[main] wait_now will be set, expect %d workers wait\n", n_worker_threads); + + ggml_spin_lock(&ctx->shared.spin); ctx->shared.wait_now = true; + ggml_spin_unlock(&ctx->shared.spin); const int n_worker_threads = ctx->n_threads - 1; while (ctx->shared.n_waiting != n_worker_threads) { @@ -270,7 +271,9 @@ void ggml_threading_suspend(struct ggml_threading_context *ctx) { } PRINT_DEBUG("[main] saw %d workers waiting\n", n_worker_threads); + ggml_spin_lock(&ctx->shared.spin); ctx->suspending = true; + ggml_spin_unlock(&ctx->shared.spin); } // Wakeup all workers. @@ -278,8 +281,6 @@ void ggml_threading_suspend(struct ggml_threading_context *ctx) { // Workers takes some time to wakeup, and has to lock spin after wakeup. Yield // is used to avoid signal frequently. Current implementation is highly // experimental. See tests/test-ggml-threading.c for details. -// -// NOTE: must be protected by shared->spin void ggml_threading_resume(struct ggml_threading_context *ctx) { if (ctx->n_threads == 1) { return; @@ -298,9 +299,12 @@ void ggml_threading_resume(struct ggml_threading_context *ctx) { int loop_counter = 0; int64_t last_signal_time = 0; + ggml_spin_lock(&shared->spin); shared->wait_now = false; while (shared->n_waiting != 0) { + ggml_spin_unlock(&shared->spin); + if (loop_counter > 0) { ggml_spin_pause(); if (loop_counter > 3) { @@ -318,6 +322,8 @@ void ggml_threading_resume(struct ggml_threading_context *ctx) { GGML_ASSERT(pthread_cond_broadcast(&shared->cond) == 0); GGML_ASSERT(pthread_mutex_unlock(&shared->mutex) == 0); last_signal_time = ggml_time_us(); + + ggml_spin_lock(&shared->spin); } ctx->suspending = false; @@ -326,6 +332,8 @@ void ggml_threading_resume(struct ggml_threading_context *ctx) { ggml_perf_collect(&shared->ctx->wakeup_perf, perf_cycles_0, perf_time_0); }; + + ggml_spin_unlock(&shared->spin); } bool ggml_threading_is_suspending(struct ggml_threading_context *ctx) { diff --git a/tests/test-ggml-threading.c b/tests/test-ggml-threading.c index 30dddbeab..886c5ee67 100644 --- a/tests/test-ggml-threading.c +++ b/tests/test-ggml-threading.c @@ -550,13 +550,13 @@ int main(void) { // lifecycle. for (int i = 0; i < 2; i++) { bool wait_on_done = (i == 1); - printf("[test-ggml-threading] test lifecycle (want_on_done = %d) ...\n", + printf("[test-ggml-threading] test lifecycle (wait_on_done = %d) ...\n", wait_on_done); ++n_tests; if (test_lifecycle(wait_on_done) == 0) { ++n_passed; - printf("[test-ggml-threading] test lifecycle (want_on_done = %d): " + printf("[test-ggml-threading] test lifecycle (wait_on_done = %d): " "ok\n\n", wait_on_done); }