Merge tag 'pull-fixes' of git://git.kernel.org/pub/scm/linux/kernel/git/viro/vfs
[sfrench/cifs-2.6.git] / io_uring / tctx.c
1 // SPDX-License-Identifier: GPL-2.0
2 #include <linux/kernel.h>
3 #include <linux/errno.h>
4 #include <linux/file.h>
5 #include <linux/mm.h>
6 #include <linux/slab.h>
7 #include <linux/nospec.h>
8 #include <linux/io_uring.h>
9
10 #include <uapi/linux/io_uring.h>
11
12 #include "io_uring.h"
13 #include "tctx.h"
14
15 static struct io_wq *io_init_wq_offload(struct io_ring_ctx *ctx,
16                                         struct task_struct *task)
17 {
18         struct io_wq_hash *hash;
19         struct io_wq_data data;
20         unsigned int concurrency;
21
22         mutex_lock(&ctx->uring_lock);
23         hash = ctx->hash_map;
24         if (!hash) {
25                 hash = kzalloc(sizeof(*hash), GFP_KERNEL);
26                 if (!hash) {
27                         mutex_unlock(&ctx->uring_lock);
28                         return ERR_PTR(-ENOMEM);
29                 }
30                 refcount_set(&hash->refs, 1);
31                 init_waitqueue_head(&hash->wait);
32                 ctx->hash_map = hash;
33         }
34         mutex_unlock(&ctx->uring_lock);
35
36         data.hash = hash;
37         data.task = task;
38         data.free_work = io_wq_free_work;
39         data.do_work = io_wq_submit_work;
40
41         /* Do QD, or 4 * CPUS, whatever is smallest */
42         concurrency = min(ctx->sq_entries, 4 * num_online_cpus());
43
44         return io_wq_create(concurrency, &data);
45 }
46
47 void __io_uring_free(struct task_struct *tsk)
48 {
49         struct io_uring_task *tctx = tsk->io_uring;
50
51         WARN_ON_ONCE(!xa_empty(&tctx->xa));
52         WARN_ON_ONCE(tctx->io_wq);
53         WARN_ON_ONCE(tctx->cached_refs);
54
55         percpu_counter_destroy(&tctx->inflight);
56         kfree(tctx);
57         tsk->io_uring = NULL;
58 }
59
60 __cold int io_uring_alloc_task_context(struct task_struct *task,
61                                        struct io_ring_ctx *ctx)
62 {
63         struct io_uring_task *tctx;
64         int ret;
65
66         tctx = kzalloc(sizeof(*tctx), GFP_KERNEL);
67         if (unlikely(!tctx))
68                 return -ENOMEM;
69
70         ret = percpu_counter_init(&tctx->inflight, 0, GFP_KERNEL);
71         if (unlikely(ret)) {
72                 kfree(tctx);
73                 return ret;
74         }
75
76         tctx->io_wq = io_init_wq_offload(ctx, task);
77         if (IS_ERR(tctx->io_wq)) {
78                 ret = PTR_ERR(tctx->io_wq);
79                 percpu_counter_destroy(&tctx->inflight);
80                 kfree(tctx);
81                 return ret;
82         }
83
84         xa_init(&tctx->xa);
85         init_waitqueue_head(&tctx->wait);
86         atomic_set(&tctx->in_idle, 0);
87         atomic_set(&tctx->inflight_tracked, 0);
88         task->io_uring = tctx;
89         init_llist_head(&tctx->task_list);
90         init_task_work(&tctx->task_work, tctx_task_work);
91         return 0;
92 }
93
94 static int io_register_submitter(struct io_ring_ctx *ctx)
95 {
96         int ret = 0;
97
98         mutex_lock(&ctx->uring_lock);
99         if (!ctx->submitter_task)
100                 ctx->submitter_task = get_task_struct(current);
101         else if (ctx->submitter_task != current)
102                 ret = -EEXIST;
103         mutex_unlock(&ctx->uring_lock);
104
105         return ret;
106 }
107
108 int __io_uring_add_tctx_node(struct io_ring_ctx *ctx, bool submitter)
109 {
110         struct io_uring_task *tctx = current->io_uring;
111         struct io_tctx_node *node;
112         int ret;
113
114         if ((ctx->flags & IORING_SETUP_SINGLE_ISSUER) && submitter) {
115                 ret = io_register_submitter(ctx);
116                 if (ret)
117                         return ret;
118         }
119
120         if (unlikely(!tctx)) {
121                 ret = io_uring_alloc_task_context(current, ctx);
122                 if (unlikely(ret))
123                         return ret;
124
125                 tctx = current->io_uring;
126                 if (ctx->iowq_limits_set) {
127                         unsigned int limits[2] = { ctx->iowq_limits[0],
128                                                    ctx->iowq_limits[1], };
129
130                         ret = io_wq_max_workers(tctx->io_wq, limits);
131                         if (ret)
132                                 return ret;
133                 }
134         }
135         if (!xa_load(&tctx->xa, (unsigned long)ctx)) {
136                 node = kmalloc(sizeof(*node), GFP_KERNEL);
137                 if (!node)
138                         return -ENOMEM;
139                 node->ctx = ctx;
140                 node->task = current;
141
142                 ret = xa_err(xa_store(&tctx->xa, (unsigned long)ctx,
143                                         node, GFP_KERNEL));
144                 if (ret) {
145                         kfree(node);
146                         return ret;
147                 }
148
149                 mutex_lock(&ctx->uring_lock);
150                 list_add(&node->ctx_node, &ctx->tctx_list);
151                 mutex_unlock(&ctx->uring_lock);
152         }
153         if (submitter)
154                 tctx->last = ctx;
155         return 0;
156 }
157
158 /*
159  * Remove this io_uring_file -> task mapping.
160  */
161 __cold void io_uring_del_tctx_node(unsigned long index)
162 {
163         struct io_uring_task *tctx = current->io_uring;
164         struct io_tctx_node *node;
165
166         if (!tctx)
167                 return;
168         node = xa_erase(&tctx->xa, index);
169         if (!node)
170                 return;
171
172         WARN_ON_ONCE(current != node->task);
173         WARN_ON_ONCE(list_empty(&node->ctx_node));
174
175         mutex_lock(&node->ctx->uring_lock);
176         list_del(&node->ctx_node);
177         mutex_unlock(&node->ctx->uring_lock);
178
179         if (tctx->last == node->ctx)
180                 tctx->last = NULL;
181         kfree(node);
182 }
183
184 __cold void io_uring_clean_tctx(struct io_uring_task *tctx)
185 {
186         struct io_wq *wq = tctx->io_wq;
187         struct io_tctx_node *node;
188         unsigned long index;
189
190         xa_for_each(&tctx->xa, index, node) {
191                 io_uring_del_tctx_node(index);
192                 cond_resched();
193         }
194         if (wq) {
195                 /*
196                  * Must be after io_uring_del_tctx_node() (removes nodes under
197                  * uring_lock) to avoid race with io_uring_try_cancel_iowq().
198                  */
199                 io_wq_put_and_exit(wq);
200                 tctx->io_wq = NULL;
201         }
202 }
203
204 void io_uring_unreg_ringfd(void)
205 {
206         struct io_uring_task *tctx = current->io_uring;
207         int i;
208
209         for (i = 0; i < IO_RINGFD_REG_MAX; i++) {
210                 if (tctx->registered_rings[i]) {
211                         fput(tctx->registered_rings[i]);
212                         tctx->registered_rings[i] = NULL;
213                 }
214         }
215 }
216
217 static int io_ring_add_registered_fd(struct io_uring_task *tctx, int fd,
218                                      int start, int end)
219 {
220         struct file *file;
221         int offset;
222
223         for (offset = start; offset < end; offset++) {
224                 offset = array_index_nospec(offset, IO_RINGFD_REG_MAX);
225                 if (tctx->registered_rings[offset])
226                         continue;
227
228                 file = fget(fd);
229                 if (!file) {
230                         return -EBADF;
231                 } else if (!io_is_uring_fops(file)) {
232                         fput(file);
233                         return -EOPNOTSUPP;
234                 }
235                 tctx->registered_rings[offset] = file;
236                 return offset;
237         }
238
239         return -EBUSY;
240 }
241
242 /*
243  * Register a ring fd to avoid fdget/fdput for each io_uring_enter()
244  * invocation. User passes in an array of struct io_uring_rsrc_update
245  * with ->data set to the ring_fd, and ->offset given for the desired
246  * index. If no index is desired, application may set ->offset == -1U
247  * and we'll find an available index. Returns number of entries
248  * successfully processed, or < 0 on error if none were processed.
249  */
250 int io_ringfd_register(struct io_ring_ctx *ctx, void __user *__arg,
251                        unsigned nr_args)
252 {
253         struct io_uring_rsrc_update __user *arg = __arg;
254         struct io_uring_rsrc_update reg;
255         struct io_uring_task *tctx;
256         int ret, i;
257
258         if (!nr_args || nr_args > IO_RINGFD_REG_MAX)
259                 return -EINVAL;
260
261         mutex_unlock(&ctx->uring_lock);
262         ret = __io_uring_add_tctx_node(ctx, false);
263         mutex_lock(&ctx->uring_lock);
264         if (ret)
265                 return ret;
266
267         tctx = current->io_uring;
268         for (i = 0; i < nr_args; i++) {
269                 int start, end;
270
271                 if (copy_from_user(&reg, &arg[i], sizeof(reg))) {
272                         ret = -EFAULT;
273                         break;
274                 }
275
276                 if (reg.resv) {
277                         ret = -EINVAL;
278                         break;
279                 }
280
281                 if (reg.offset == -1U) {
282                         start = 0;
283                         end = IO_RINGFD_REG_MAX;
284                 } else {
285                         if (reg.offset >= IO_RINGFD_REG_MAX) {
286                                 ret = -EINVAL;
287                                 break;
288                         }
289                         start = reg.offset;
290                         end = start + 1;
291                 }
292
293                 ret = io_ring_add_registered_fd(tctx, reg.data, start, end);
294                 if (ret < 0)
295                         break;
296
297                 reg.offset = ret;
298                 if (copy_to_user(&arg[i], &reg, sizeof(reg))) {
299                         fput(tctx->registered_rings[reg.offset]);
300                         tctx->registered_rings[reg.offset] = NULL;
301                         ret = -EFAULT;
302                         break;
303                 }
304         }
305
306         return i ? i : ret;
307 }
308
309 int io_ringfd_unregister(struct io_ring_ctx *ctx, void __user *__arg,
310                          unsigned nr_args)
311 {
312         struct io_uring_rsrc_update __user *arg = __arg;
313         struct io_uring_task *tctx = current->io_uring;
314         struct io_uring_rsrc_update reg;
315         int ret = 0, i;
316
317         if (!nr_args || nr_args > IO_RINGFD_REG_MAX)
318                 return -EINVAL;
319         if (!tctx)
320                 return 0;
321
322         for (i = 0; i < nr_args; i++) {
323                 if (copy_from_user(&reg, &arg[i], sizeof(reg))) {
324                         ret = -EFAULT;
325                         break;
326                 }
327                 if (reg.resv || reg.data || reg.offset >= IO_RINGFD_REG_MAX) {
328                         ret = -EINVAL;
329                         break;
330                 }
331
332                 reg.offset = array_index_nospec(reg.offset, IO_RINGFD_REG_MAX);
333                 if (tctx->registered_rings[reg.offset]) {
334                         fput(tctx->registered_rings[reg.offset]);
335                         tctx->registered_rings[reg.offset] = NULL;
336                 }
337         }
338
339         return i ? i : ret;
340 }