io_uring: always wait for sqd exited when stopping SQPOLL thread

We have a tiny race where io_put_sq_data() calls io_sq_thead_stop()
and finds the thread gone, but the thread has indeed not fully
exited or called complete() yet. Close it up by always having
io_sq_thread_stop() wait on completion of the exit event.

Signed-off-by: Jens Axboe <axboe@kernel.dk>
This commit is contained in:
Jens Axboe 2021-03-09 16:32:13 -07:00
parent 5199328a0d
commit e8f98f2454

View file

@ -7079,12 +7079,9 @@ static void io_sq_thread_stop(struct io_sq_data *sqd)
if (test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state)) if (test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state))
return; return;
down_write(&sqd->rw_lock); down_write(&sqd->rw_lock);
if (!sqd->thread) {
up_write(&sqd->rw_lock);
return;
}
set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state); set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
wake_up_process(sqd->thread); if (sqd->thread)
wake_up_process(sqd->thread);
up_write(&sqd->rw_lock); up_write(&sqd->rw_lock);
wait_for_completion(&sqd->exited); wait_for_completion(&sqd->exited);
} }
@ -7849,9 +7846,9 @@ static int io_sq_offload_create(struct io_ring_ctx *ctx,
ret = -EINVAL; ret = -EINVAL;
if (cpu >= nr_cpu_ids) if (cpu >= nr_cpu_ids)
goto err; goto err_sqpoll;
if (!cpu_online(cpu)) if (!cpu_online(cpu))
goto err; goto err_sqpoll;
sqd->sq_cpu = cpu; sqd->sq_cpu = cpu;
} else { } else {
@ -7862,7 +7859,7 @@ static int io_sq_offload_create(struct io_ring_ctx *ctx,
tsk = create_io_thread(io_sq_thread, sqd, NUMA_NO_NODE); tsk = create_io_thread(io_sq_thread, sqd, NUMA_NO_NODE);
if (IS_ERR(tsk)) { if (IS_ERR(tsk)) {
ret = PTR_ERR(tsk); ret = PTR_ERR(tsk);
goto err; goto err_sqpoll;
} }
sqd->thread = tsk; sqd->thread = tsk;
@ -7881,6 +7878,9 @@ static int io_sq_offload_create(struct io_ring_ctx *ctx,
err: err:
io_sq_thread_finish(ctx); io_sq_thread_finish(ctx);
return ret; return ret;
err_sqpoll:
complete(&ctx->sq_data->exited);
goto err;
} }
static inline void __io_unaccount_mem(struct user_struct *user, static inline void __io_unaccount_mem(struct user_struct *user,