io_uring: protect remaining lockless ctx->rings accesses with RCU

Commit 9618908026 addressed one case of ctx->rings being potentially
accessed while a resize is happening on the ring, but there are still
a few others that need handling. Add a helper for retrieving the
rings associated with an io_uring context, and add some sanity checking
to that to catch bad uses. ->rings_rcu is always valid, as long as it's
used within RCU read lock. Any use of ->rings_rcu or ->rings inside
either ->uring_lock or ->completion_lock is sane as well.

Do the minimum fix for the current kernel, but set it up such that this
basic infra can be extended for later kernels to make this harder to
mess up in the future.

Thanks to Junxi Qian for finding and debugging this issue.

Cc: stable@vger.kernel.org
Fixes: 79cfe9e59c ("io_uring/register: add IORING_REGISTER_RESIZE_RINGS")
Reviewed-by: Junxi Qian <qjx1298677004@gmail.com>
Tested-by: Junxi Qian <qjx1298677004@gmail.com>
Link: https://lore.kernel.org/io-uring/20260330172348.89416-1-qjx1298677004@gmail.com/
Signed-off-by: Jens Axboe <axboe@kernel.dk>
This commit is contained in:
Jens Axboe
2026-03-31 07:07:47 -06:00
parent 111a12b422
commit 61a11cf481
4 changed files with 70 additions and 28 deletions

View File

@@ -2015,7 +2015,7 @@ int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr)
if (ctx->flags & IORING_SETUP_SQ_REWIND)
entries = ctx->sq_entries;
else
entries = io_sqring_entries(ctx);
entries = __io_sqring_entries(ctx);
entries = min(nr, entries);
if (unlikely(!entries))
@@ -2250,7 +2250,9 @@ static __poll_t io_uring_poll(struct file *file, poll_table *wait)
*/
poll_wait(file, &ctx->poll_wq, wait);
if (!io_sqring_full(ctx))
rcu_read_lock();
if (!__io_sqring_full(ctx))
mask |= EPOLLOUT | EPOLLWRNORM;
/*
@@ -2270,6 +2272,7 @@ static __poll_t io_uring_poll(struct file *file, poll_table *wait)
if (__io_cqring_events_user(ctx) || io_has_work(ctx))
mask |= EPOLLIN | EPOLLRDNORM;
rcu_read_unlock();
return mask;
}

View File

@@ -142,16 +142,28 @@ struct io_wait_queue {
#endif
};
static inline struct io_rings *io_get_rings(struct io_ring_ctx *ctx)
{
return rcu_dereference_check(ctx->rings_rcu,
lockdep_is_held(&ctx->uring_lock) ||
lockdep_is_held(&ctx->completion_lock));
}
static inline bool io_should_wake(struct io_wait_queue *iowq)
{
struct io_ring_ctx *ctx = iowq->ctx;
int dist = READ_ONCE(ctx->rings->cq.tail) - (int) iowq->cq_tail;
struct io_rings *rings;
int dist;
guard(rcu)();
rings = io_get_rings(ctx);
/*
* Wake up if we have enough events, or if a timeout occurred since we
* started waiting. For timeouts, we always want to return to userspace,
* regardless of event count.
*/
dist = READ_ONCE(rings->cq.tail) - (int) iowq->cq_tail;
return dist >= 0 || atomic_read(&ctx->cq_timeouts) != iowq->nr_timeouts;
}
@@ -431,9 +443,9 @@ static inline void io_cqring_wake(struct io_ring_ctx *ctx)
__io_wq_wake(&ctx->cq_wait);
}
static inline bool io_sqring_full(struct io_ring_ctx *ctx)
static inline bool __io_sqring_full(struct io_ring_ctx *ctx)
{
struct io_rings *r = ctx->rings;
struct io_rings *r = io_get_rings(ctx);
/*
* SQPOLL must use the actual sqring head, as using the cached_sq_head
@@ -445,9 +457,15 @@ static inline bool io_sqring_full(struct io_ring_ctx *ctx)
return READ_ONCE(r->sq.tail) - READ_ONCE(r->sq.head) == ctx->sq_entries;
}
static inline unsigned int io_sqring_entries(struct io_ring_ctx *ctx)
static inline bool io_sqring_full(struct io_ring_ctx *ctx)
{
struct io_rings *rings = ctx->rings;
guard(rcu)();
return __io_sqring_full(ctx);
}
static inline unsigned int __io_sqring_entries(struct io_ring_ctx *ctx)
{
struct io_rings *rings = io_get_rings(ctx);
unsigned int entries;
/* make sure SQ entry isn't read before tail */
@@ -455,6 +473,12 @@ static inline unsigned int io_sqring_entries(struct io_ring_ctx *ctx)
return min(entries, ctx->sq_entries);
}
static inline unsigned int io_sqring_entries(struct io_ring_ctx *ctx)
{
guard(rcu)();
return __io_sqring_entries(ctx);
}
/*
* Don't complete immediately but use deferred completion infrastructure.
* Protected by ->uring_lock and can only be used either with

View File

@@ -79,12 +79,15 @@ static enum hrtimer_restart io_cqring_min_timer_wakeup(struct hrtimer *timer)
if (io_has_work(ctx))
goto out_wake;
/* got events since we started waiting, min timeout is done */
if (iowq->cq_min_tail != READ_ONCE(ctx->rings->cq.tail))
goto out_wake;
/* if we have any events and min timeout expired, we're done */
if (io_cqring_events(ctx))
goto out_wake;
scoped_guard(rcu) {
struct io_rings *rings = io_get_rings(ctx);
if (iowq->cq_min_tail != READ_ONCE(rings->cq.tail))
goto out_wake;
/* if we have any events and min timeout expired, we're done */
if (io_cqring_events(ctx))
goto out_wake;
}
/*
* If using deferred task_work running and application is waiting on
* more than one request, ensure we reset it now where we are switching
@@ -186,9 +189,9 @@ int io_cqring_wait(struct io_ring_ctx *ctx, int min_events, u32 flags,
struct ext_arg *ext_arg)
{
struct io_wait_queue iowq;
struct io_rings *rings = ctx->rings;
struct io_rings *rings;
ktime_t start_time;
int ret;
int ret, nr_wait;
min_events = min_t(int, min_events, ctx->cq_entries);
@@ -201,15 +204,23 @@ int io_cqring_wait(struct io_ring_ctx *ctx, int min_events, u32 flags,
if (unlikely(test_bit(IO_CHECK_CQ_OVERFLOW_BIT, &ctx->check_cq)))
io_cqring_do_overflow_flush(ctx);
if (__io_cqring_events_user(ctx) >= min_events)
rcu_read_lock();
rings = io_get_rings(ctx);
if (__io_cqring_events_user(ctx) >= min_events) {
rcu_read_unlock();
return 0;
}
init_waitqueue_func_entry(&iowq.wq, io_wake_function);
iowq.wq.private = current;
INIT_LIST_HEAD(&iowq.wq.entry);
iowq.ctx = ctx;
iowq.cq_tail = READ_ONCE(ctx->rings->cq.head) + min_events;
iowq.cq_min_tail = READ_ONCE(ctx->rings->cq.tail);
iowq.cq_tail = READ_ONCE(rings->cq.head) + min_events;
iowq.cq_min_tail = READ_ONCE(rings->cq.tail);
nr_wait = (int) iowq.cq_tail - READ_ONCE(rings->cq.tail);
rcu_read_unlock();
rings = NULL;
iowq.nr_timeouts = atomic_read(&ctx->cq_timeouts);
iowq.hit_timeout = 0;
iowq.min_timeout = ext_arg->min_time;
@@ -240,14 +251,6 @@ int io_cqring_wait(struct io_ring_ctx *ctx, int min_events, u32 flags,
trace_io_uring_cqring_wait(ctx, min_events);
do {
unsigned long check_cq;
int nr_wait;
/* if min timeout has been hit, don't reset wait count */
if (!iowq.hit_timeout)
nr_wait = (int) iowq.cq_tail -
READ_ONCE(ctx->rings->cq.tail);
else
nr_wait = 1;
if (ctx->flags & IORING_SETUP_DEFER_TASKRUN) {
atomic_set(&ctx->cq_wait_nr, nr_wait);
@@ -298,11 +301,20 @@ int io_cqring_wait(struct io_ring_ctx *ctx, int min_events, u32 flags,
break;
}
cond_resched();
/* if min timeout has been hit, don't reset wait count */
if (!iowq.hit_timeout)
scoped_guard(rcu)
nr_wait = (int) iowq.cq_tail -
READ_ONCE(io_get_rings(ctx)->cq.tail);
else
nr_wait = 1;
} while (1);
if (!(ctx->flags & IORING_SETUP_DEFER_TASKRUN))
finish_wait(&ctx->cq_wait, &iowq.wq);
restore_saved_sigmask_unless(ret == -EINTR);
return READ_ONCE(rings->cq.head) == READ_ONCE(rings->cq.tail) ? ret : 0;
guard(rcu)();
return READ_ONCE(io_get_rings(ctx)->cq.head) == READ_ONCE(io_get_rings(ctx)->cq.tail) ? ret : 0;
}

View File

@@ -28,12 +28,15 @@ void io_cqring_do_overflow_flush(struct io_ring_ctx *ctx);
static inline unsigned int __io_cqring_events(struct io_ring_ctx *ctx)
{
return ctx->cached_cq_tail - READ_ONCE(ctx->rings->cq.head);
struct io_rings *rings = io_get_rings(ctx);
return ctx->cached_cq_tail - READ_ONCE(rings->cq.head);
}
static inline unsigned int __io_cqring_events_user(struct io_ring_ctx *ctx)
{
return READ_ONCE(ctx->rings->cq.tail) - READ_ONCE(ctx->rings->cq.head);
struct io_rings *rings = io_get_rings(ctx);
return READ_ONCE(rings->cq.tail) - READ_ONCE(rings->cq.head);
}
/*