threading: fix deadlock by reverting part of changes from commit 286c5b30

This commit is contained in:
mqy 2023-06-19 16:17:48 +08:00
parent 4d32b4088e
commit cc8a375bc4
2 changed files with 14 additions and 6 deletions

View file

@ -199,8 +199,6 @@ struct ggml_threading_context {
int64_t *stages_time; int64_t *stages_time;
}; };
// NOTE: ggml_spin_lock and ggml_spin_unlock may can be noop if
// feature wait_on_done is off.
static inline void ggml_spin_lock(volatile atomic_flag *obj) { static inline void ggml_spin_lock(volatile atomic_flag *obj) {
while (atomic_flag_test_and_set(obj)) { while (atomic_flag_test_and_set(obj)) {
ggml_spin_pause(); ggml_spin_pause();
@ -262,7 +260,10 @@ void ggml_threading_suspend(struct ggml_threading_context *ctx) {
PRINT_DEBUG("[main] wait_now will be set, expect %d workers wait\n", PRINT_DEBUG("[main] wait_now will be set, expect %d workers wait\n",
n_worker_threads); n_worker_threads);
ggml_spin_lock(&ctx->shared.spin);
ctx->shared.wait_now = true; ctx->shared.wait_now = true;
ggml_spin_unlock(&ctx->shared.spin);
const int n_worker_threads = ctx->n_threads - 1; const int n_worker_threads = ctx->n_threads - 1;
while (ctx->shared.n_waiting != n_worker_threads) { while (ctx->shared.n_waiting != n_worker_threads) {
@ -270,7 +271,9 @@ void ggml_threading_suspend(struct ggml_threading_context *ctx) {
} }
PRINT_DEBUG("[main] saw %d workers waiting\n", n_worker_threads); PRINT_DEBUG("[main] saw %d workers waiting\n", n_worker_threads);
ggml_spin_lock(&ctx->shared.spin);
ctx->suspending = true; ctx->suspending = true;
ggml_spin_unlock(&ctx->shared.spin);
} }
// Wakeup all workers. // Wakeup all workers.
@ -278,8 +281,6 @@ void ggml_threading_suspend(struct ggml_threading_context *ctx) {
// Workers takes some time to wakeup, and has to lock spin after wakeup. Yield // Workers takes some time to wakeup, and has to lock spin after wakeup. Yield
// is used to avoid signal frequently. Current implementation is highly // is used to avoid signal frequently. Current implementation is highly
// experimental. See tests/test-ggml-threading.c for details. // experimental. See tests/test-ggml-threading.c for details.
//
// NOTE: must be protected by shared->spin
void ggml_threading_resume(struct ggml_threading_context *ctx) { void ggml_threading_resume(struct ggml_threading_context *ctx) {
if (ctx->n_threads == 1) { if (ctx->n_threads == 1) {
return; return;
@ -298,9 +299,12 @@ void ggml_threading_resume(struct ggml_threading_context *ctx) {
int loop_counter = 0; int loop_counter = 0;
int64_t last_signal_time = 0; int64_t last_signal_time = 0;
ggml_spin_lock(&shared->spin);
shared->wait_now = false; shared->wait_now = false;
while (shared->n_waiting != 0) { while (shared->n_waiting != 0) {
ggml_spin_unlock(&shared->spin);
if (loop_counter > 0) { if (loop_counter > 0) {
ggml_spin_pause(); ggml_spin_pause();
if (loop_counter > 3) { if (loop_counter > 3) {
@ -318,6 +322,8 @@ void ggml_threading_resume(struct ggml_threading_context *ctx) {
GGML_ASSERT(pthread_cond_broadcast(&shared->cond) == 0); GGML_ASSERT(pthread_cond_broadcast(&shared->cond) == 0);
GGML_ASSERT(pthread_mutex_unlock(&shared->mutex) == 0); GGML_ASSERT(pthread_mutex_unlock(&shared->mutex) == 0);
last_signal_time = ggml_time_us(); last_signal_time = ggml_time_us();
ggml_spin_lock(&shared->spin);
} }
ctx->suspending = false; ctx->suspending = false;
@ -326,6 +332,8 @@ void ggml_threading_resume(struct ggml_threading_context *ctx) {
ggml_perf_collect(&shared->ctx->wakeup_perf, perf_cycles_0, ggml_perf_collect(&shared->ctx->wakeup_perf, perf_cycles_0,
perf_time_0); perf_time_0);
}; };
ggml_spin_unlock(&shared->spin);
} }
bool ggml_threading_is_suspending(struct ggml_threading_context *ctx) { bool ggml_threading_is_suspending(struct ggml_threading_context *ctx) {

View file

@ -550,13 +550,13 @@ int main(void) {
// lifecycle. // lifecycle.
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
bool wait_on_done = (i == 1); bool wait_on_done = (i == 1);
printf("[test-ggml-threading] test lifecycle (want_on_done = %d) ...\n", printf("[test-ggml-threading] test lifecycle (wait_on_done = %d) ...\n",
wait_on_done); wait_on_done);
++n_tests; ++n_tests;
if (test_lifecycle(wait_on_done) == 0) { if (test_lifecycle(wait_on_done) == 0) {
++n_passed; ++n_passed;
printf("[test-ggml-threading] test lifecycle (want_on_done = %d): " printf("[test-ggml-threading] test lifecycle (wait_on_done = %d): "
"ok\n\n", "ok\n\n",
wait_on_done); wait_on_done);
} }