diff --git a/io_uring/sqpoll.c b/io_uring/sqpoll.c index c6bb938ec5ea..46c12afec73e 100644 --- a/io_uring/sqpoll.c +++ b/io_uring/sqpoll.c @@ -458,6 +458,7 @@ __cold int io_sq_offload_create(struct io_ring_ctx *ctx, return -EINVAL; } if (ctx->flags & IORING_SETUP_SQPOLL) { + struct io_uring_task *tctx; struct task_struct *tsk; struct io_sq_data *sqd; bool attached; @@ -524,8 +525,13 @@ __cold int io_sq_offload_create(struct io_ring_ctx *ctx, rcu_assign_pointer(sqd->thread, tsk); mutex_unlock(&sqd->lock); + ret = 0; get_task_struct(tsk); - ret = io_uring_alloc_task_context(tsk, ctx); + tctx = io_uring_alloc_task_context(tsk, ctx); + if (!IS_ERR(tctx)) + tsk->io_uring = tctx; + else + ret = PTR_ERR(tctx); wake_up_new_task(tsk); if (ret) goto err; diff --git a/io_uring/tctx.c b/io_uring/tctx.c index 143de8e990eb..e5cef6a8dde0 100644 --- a/io_uring/tctx.c +++ b/io_uring/tctx.c @@ -74,20 +74,20 @@ void __io_uring_free(struct task_struct *tsk) } } -__cold int io_uring_alloc_task_context(struct task_struct *task, - struct io_ring_ctx *ctx) +__cold struct io_uring_task *io_uring_alloc_task_context(struct task_struct *task, + struct io_ring_ctx *ctx) { struct io_uring_task *tctx; int ret; tctx = kzalloc_obj(*tctx); if (unlikely(!tctx)) - return -ENOMEM; + return ERR_PTR(-ENOMEM); ret = percpu_counter_init(&tctx->inflight, 0, GFP_KERNEL); if (unlikely(ret)) { kfree(tctx); - return ret; + return ERR_PTR(ret); } tctx->io_wq = io_init_wq_offload(ctx, task); @@ -95,7 +95,7 @@ __cold int io_uring_alloc_task_context(struct task_struct *task, ret = PTR_ERR(tctx->io_wq); percpu_counter_destroy(&tctx->inflight); kfree(tctx); - return ret; + return ERR_PTR(ret); } tctx->task = task; @@ -103,10 +103,9 @@ __cold int io_uring_alloc_task_context(struct task_struct *task, init_waitqueue_head(&tctx->wait); atomic_set(&tctx->in_cancel, 0); atomic_set(&tctx->inflight_tracked, 0); - task->io_uring = tctx; init_llist_head(&tctx->task_list); init_task_work(&tctx->task_work, tctx_task_work); - return 0; + return tctx; } int __io_uring_add_tctx_node(struct io_ring_ctx *ctx) @@ -116,11 +115,11 @@ int __io_uring_add_tctx_node(struct io_ring_ctx *ctx) int ret; if (unlikely(!tctx)) { - ret = io_uring_alloc_task_context(current, ctx); - if (unlikely(ret)) - return ret; + tctx = io_uring_alloc_task_context(current, ctx); + if (IS_ERR(tctx)) + return PTR_ERR(tctx); - tctx = current->io_uring; + current->io_uring = tctx; if (ctx->int_flags & IO_RING_F_IOWQ_LIMITS_SET) { unsigned int limits[2] = { ctx->iowq_limits[0], ctx->iowq_limits[1], }; diff --git a/io_uring/tctx.h b/io_uring/tctx.h index 608e96de70a2..2310d2a0c46d 100644 --- a/io_uring/tctx.h +++ b/io_uring/tctx.h @@ -6,8 +6,8 @@ struct io_tctx_node { struct io_ring_ctx *ctx; }; -int io_uring_alloc_task_context(struct task_struct *task, - struct io_ring_ctx *ctx); +struct io_uring_task *io_uring_alloc_task_context(struct task_struct *task, + struct io_ring_ctx *ctx); void io_uring_del_tctx_node(unsigned long index); int __io_uring_add_tctx_node(struct io_ring_ctx *ctx); int __io_uring_add_tctx_node_from_submit(struct io_ring_ctx *ctx);