diff --git a/ggml-threading.c b/ggml-threading.c index 4c1bc0714..4a9cf622f 100644 --- a/ggml-threading.c +++ b/ggml-threading.c @@ -261,19 +261,22 @@ void ggml_threading_suspend(struct ggml_threading_context *ctx) { PRINT_DEBUG("[main] wait_now will be set, expect %d workers wait\n", n_worker_threads); - ggml_spin_lock(&ctx->shared.spin); - ctx->shared.wait_now = true; + struct ggml_compute_state_shared *shared = &ctx->shared; + + ggml_spin_lock(&shared->spin); + shared->wait_now = true; ggml_spin_unlock(&ctx->shared.spin); const int n_worker_threads = ctx->n_threads - 1; - while (ctx->shared.n_waiting != n_worker_threads) { + while (shared->n_waiting != n_worker_threads) { ggml_spin_pause(); } PRINT_DEBUG("[main] saw %d workers waiting\n", n_worker_threads); - ggml_spin_lock(&ctx->shared.spin); + + ggml_spin_lock(&shared->spin); ctx->suspending = true; - ggml_spin_unlock(&ctx->shared.spin); + ggml_spin_unlock(&shared->spin); } // Wakeup all workers. @@ -296,44 +299,52 @@ void ggml_threading_resume(struct ggml_threading_context *ctx) { perf_time_0 = ggml_time_us(); } - int loop_counter = 0; - int64_t last_signal_time = 0; + // Dead lock detection. + int counter = 0; + int64_t last_notify_ms = 0; + const int max_notify_count = ctx->n_threads - 1; + const int max_duration_ms = 100 * max_notify_count; ggml_spin_lock(&shared->spin); shared->wait_now = false; while (shared->n_waiting != 0) { - ggml_spin_unlock(&shared->spin); - - if (loop_counter > 0) { - ggml_spin_pause(); - if (loop_counter > 3) { - sched_yield(); - } - } - ++loop_counter; - - // TODO: should bench actual average wait/wakeup time. - if (last_signal_time > 0 && (ggml_time_us() - last_signal_time) < 10) { - continue; - } GGML_ASSERT(pthread_mutex_lock(&shared->mutex) == 0); + if (shared->n_waiting == 0) { + GGML_ASSERT(pthread_mutex_unlock(&shared->mutex) == 0); + ggml_spin_unlock(&shared->spin); + break; + } + + ggml_spin_unlock(&shared->spin); + GGML_ASSERT(pthread_cond_broadcast(&shared->cond) == 0); GGML_ASSERT(pthread_mutex_unlock(&shared->mutex) == 0); - last_signal_time = ggml_time_us(); + last_notify_ms = ggml_time_ms(); + + sched_yield(); + + int elapsed = last_notify_ms > 0 ? ggml_time_ms() - last_notify_ms : 0; + + if ((counter > max_notify_count) || elapsed > max_duration_ms) { + fprintf(stderr, + "[ggml-threading] potential dead lock detected, notified " + "for %d times, elapsed time: %d ms, abort\n", + counter, elapsed); + abort(); + } ggml_spin_lock(&shared->spin); } ctx->suspending = false; + ggml_spin_unlock(&shared->spin); if (shared->ctx->features & GGML_THREADING_FEATURE_PERF) { ggml_perf_collect(&shared->ctx->wakeup_perf, perf_cycles_0, perf_time_0); }; - - ggml_spin_unlock(&shared->spin); } bool ggml_threading_is_suspending(struct ggml_threading_context *ctx) { diff --git a/tests/test-ggml-threading.c b/tests/test-ggml-threading.c index 886c5ee67..8fc705a6b 100644 --- a/tests/test-ggml-threading.c +++ b/tests/test-ggml-threading.c @@ -33,7 +33,7 @@ #define UNUSED(x) (void)(x) -#define MAX_N_THREADS 16 +#define MAX_N_THREADS 64 static const int n_repeat = 10; @@ -353,7 +353,7 @@ int main(void) { // average time, thus greatly punishes those small workloads. // - wait_on_done is general faster than wait_now, can be 10x faster. - int threads_arr[] = {1, 2, 4, 6, 8, 16}; + int threads_arr[] = {1, 2, 4, 6, 8, 16, 32, 64}; int threads_arr_len = sizeof(threads_arr) / sizeof(threads_arr[0]); // millions of loops.