kernel/bpf/task_iter.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /* Copyright (c) 2020 Facebook */
3
4 #include <linux/init.h>
5 #include <linux/namei.h>
6 #include <linux/pid_namespace.h>
7 #include <linux/fs.h>
8 #include <linux/fdtable.h>
9 #include <linux/filter.h>
10 #include <linux/bpf_mem_alloc.h>
11 #include <linux/btf_ids.h>
12 #include <linux/mm_types.h>
13 #include "mmap_unlock_work.h"
14
15 static const char * const iter_task_type_names[] = {
16         "ALL",
17         "TID",
18         "PID",
19 };
20
21 struct bpf_iter_seq_task_common {
22         struct pid_namespace *ns;
23         enum bpf_iter_task_type type;
24         u32 pid;
25         u32 pid_visiting;
26 };
27
28 struct bpf_iter_seq_task_info {
29         /* The first field must be struct bpf_iter_seq_task_common.
30          * This is assumed by the {init, fini}_seq_pidns() callback functions.
31          */
32         struct bpf_iter_seq_task_common common;
33         u32 tid;
34 };
35
36 static struct task_struct *task_group_seq_get_next(struct bpf_iter_seq_task_common *common,
37                                                    u32 *tid,
38                                                    bool skip_if_dup_files)
39 {
40         struct task_struct *task;
41         struct pid *pid;
42         u32 next_tid;
43
44         if (!*tid) {
45                 /* This is the first time the iterator calls this function. */
46                 pid = find_pid_ns(common->pid, common->ns);
47                 task = get_pid_task(pid, PIDTYPE_TGID);
48                 if (!task)
49                         return NULL;
50
51                 *tid = common->pid;
52                 common->pid_visiting = common->pid;
53
54                 return task;
55         }
56
57         /* If control returns to user space and comes back to the
58          * kernel again, *tid and common->pid_visiting should be the
59          * same for task_seq_start() to pick up the correct task.
60          */
61         if (*tid == common->pid_visiting) {
62                 pid = find_pid_ns(common->pid_visiting, common->ns);
63                 task = get_pid_task(pid, PIDTYPE_PID);
64
65                 return task;
66         }
67
68         task = find_task_by_pid_ns(common->pid_visiting, common->ns);
69         if (!task)
70                 return NULL;
71
72 retry:
73         task = __next_thread(task);
74         if (!task)
75                 return NULL;
76
77         next_tid = __task_pid_nr_ns(task, PIDTYPE_PID, common->ns);
78         if (!next_tid)
79                 goto retry;
80
81         if (skip_if_dup_files && task->files == task->group_leader->files)
82                 goto retry;
83
84         *tid = common->pid_visiting = next_tid;
85         get_task_struct(task);
86         return task;
87 }
88
89 static struct task_struct *task_seq_get_next(struct bpf_iter_seq_task_common *common,
90                                              u32 *tid,
91                                              bool skip_if_dup_files)
92 {
93         struct task_struct *task = NULL;
94         struct pid *pid;
95
96         if (common->type == BPF_TASK_ITER_TID) {
97                 if (*tid && *tid != common->pid)
98                         return NULL;
99                 rcu_read_lock();
100                 pid = find_pid_ns(common->pid, common->ns);
101                 if (pid) {
102                         task = get_pid_task(pid, PIDTYPE_TGID);
103                         *tid = common->pid;
104                 }
105                 rcu_read_unlock();
106
107                 return task;
108         }
109
110         if (common->type == BPF_TASK_ITER_TGID) {
111                 rcu_read_lock();
112                 task = task_group_seq_get_next(common, tid, skip_if_dup_files);
113                 rcu_read_unlock();
114
115                 return task;
116         }
117
118         rcu_read_lock();
119 retry:
120         pid = find_ge_pid(*tid, common->ns);
121         if (pid) {
122                 *tid = pid_nr_ns(pid, common->ns);
123                 task = get_pid_task(pid, PIDTYPE_PID);
124                 if (!task) {
125                         ++*tid;
126                         goto retry;
127                 } else if (skip_if_dup_files && !thread_group_leader(task) &&
128                            task->files == task->group_leader->files) {
129                         put_task_struct(task);
130                         task = NULL;
131                         ++*tid;
132                         goto retry;
133                 }
134         }
135         rcu_read_unlock();
136
137         return task;
138 }
139
140 static void *task_seq_start(struct seq_file *seq, loff_t *pos)
141 {
142         struct bpf_iter_seq_task_info *info = seq->private;
143         struct task_struct *task;
144
145         task = task_seq_get_next(&info->common, &info->tid, false);
146         if (!task)
147                 return NULL;
148
149         if (*pos == 0)
150                 ++*pos;
151         return task;
152 }
153
154 static void *task_seq_next(struct seq_file *seq, void *v, loff_t *pos)
155 {
156         struct bpf_iter_seq_task_info *info = seq->private;
157         struct task_struct *task;
158
159         ++*pos;
160         ++info->tid;
161         put_task_struct((struct task_struct *)v);
162         task = task_seq_get_next(&info->common, &info->tid, false);
163         if (!task)
164                 return NULL;
165
166         return task;
167 }
168
169 struct bpf_iter__task {
170         __bpf_md_ptr(struct bpf_iter_meta *, meta);
171         __bpf_md_ptr(struct task_struct *, task);
172 };
173
174 DEFINE_BPF_ITER_FUNC(task, struct bpf_iter_meta *meta, struct task_struct *task)
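
/* Illustrative sketch (not part of this file): the BPF-program side of the
 * task iterator defined just above (bpf_iter__task / DEFINE_BPF_ITER_FUNC),
 * roughly following the pattern used in the BPF selftests. BPF_SEQ_PRINTF is
 * libbpf's helper macro; the program and variable names are hypothetical.
 * ctx->task can be NULL on the final (stop) invocation.
 *
 *        SEC("iter/task")
 *        int dump_task(struct bpf_iter__task *ctx)
 *        {
 *                struct seq_file *seq = ctx->meta->seq;
 *                struct task_struct *task = ctx->task;
 *
 *                if (!task)
 *                        return 0;
 *
 *                BPF_SEQ_PRINTF(seq, "%8d %s\n", task->pid, task->comm);
 *                return 0;
 *        }
 */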
175
176 static int __task_seq_show(struct seq_file *seq, struct task_struct *task,
177                            bool in_stop)
178 {
179         struct bpf_iter_meta meta;
180         struct bpf_iter__task ctx;
181         struct bpf_prog *prog;
182
183         meta.seq = seq;
184         prog = bpf_iter_get_info(&meta, in_stop);
185         if (!prog)
186                 return 0;
187
188         ctx.meta = &meta;
189         ctx.task = task;
190         return bpf_iter_run_prog(prog, &ctx);
191 }
192
193 static int task_seq_show(struct seq_file *seq, void *v)
194 {
195         return __task_seq_show(seq, v, false);
196 }
197
198 static void task_seq_stop(struct seq_file *seq, void *v)
199 {
200         if (!v)
201                 (void)__task_seq_show(seq, v, true);
202         else
203                 put_task_struct((struct task_struct *)v);
204 }
205
206 static int bpf_iter_attach_task(struct bpf_prog *prog,
207                                 union bpf_iter_link_info *linfo,
208                                 struct bpf_iter_aux_info *aux)
209 {
210         unsigned int flags;
211         struct pid *pid;
212         pid_t tgid;
213
214         if ((!!linfo->task.tid + !!linfo->task.pid + !!linfo->task.pid_fd) > 1)
215                 return -EINVAL;
216
217         aux->task.type = BPF_TASK_ITER_ALL;
218         if (linfo->task.tid != 0) {
219                 aux->task.type = BPF_TASK_ITER_TID;
220                 aux->task.pid = linfo->task.tid;
221         }
222         if (linfo->task.pid != 0) {
223                 aux->task.type = BPF_TASK_ITER_TGID;
224                 aux->task.pid = linfo->task.pid;
225         }
226         if (linfo->task.pid_fd != 0) {
227                 aux->task.type = BPF_TASK_ITER_TGID;
228
229                 pid = pidfd_get_pid(linfo->task.pid_fd, &flags);
230                 if (IS_ERR(pid))
231                         return PTR_ERR(pid);
232
233                 tgid = pid_nr_ns(pid, task_active_pid_ns(current));
234                 aux->task.pid = tgid;
235                 put_pid(pid);
236         }
237
238         return 0;
239 }
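
/* Illustrative sketch (not part of this file): how user space might exercise
 * this attach path with libbpf to pin the iterator to one process.
 * LIBBPF_OPTS, union bpf_iter_link_info and bpf_program__attach_iter() are
 * existing libbpf/uapi interfaces; the skeleton and variable names are
 * hypothetical.
 *
 *        LIBBPF_OPTS(bpf_iter_attach_opts, opts);
 *        union bpf_iter_link_info linfo = {};
 *        struct bpf_link *link;
 *
 *        linfo.task.pid = target_pid;    // selects BPF_TASK_ITER_TGID above
 *        opts.link_info = &linfo;
 *        opts.link_info_len = sizeof(linfo);
 *        link = bpf_program__attach_iter(skel->progs.dump_task, &opts);
 */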
240
241 static const struct seq_operations task_seq_ops = {
242         .start  = task_seq_start,
243         .next   = task_seq_next,
244         .stop   = task_seq_stop,
245         .show   = task_seq_show,
246 };
247
248 struct bpf_iter_seq_task_file_info {
249         /* The first field must be struct bpf_iter_seq_task_common.
250          * This is assumed by the {init, fini}_seq_pidns() callback functions.
251          */
252         struct bpf_iter_seq_task_common common;
253         struct task_struct *task;
254         u32 tid;
255         u32 fd;
256 };
257
258 static struct file *
259 task_file_seq_get_next(struct bpf_iter_seq_task_file_info *info)
260 {
261         u32 saved_tid = info->tid;
262         struct task_struct *curr_task;
263         unsigned int curr_fd = info->fd;
264
265         /* If this function returns a non-NULL file object,
266          * it holds a reference to the task and the file.
267          * Otherwise, it does not hold any references.
268          */
269 again:
270         if (info->task) {
271                 curr_task = info->task;
272                 curr_fd = info->fd;
273         } else {
274                 curr_task = task_seq_get_next(&info->common, &info->tid, true);
275                 if (!curr_task) {
276                         info->task = NULL;
277                         return NULL;
278                 }
279
280                 /* set info->task */
281                 info->task = curr_task;
282                 if (saved_tid == info->tid)
283                         curr_fd = info->fd;
284                 else
285                         curr_fd = 0;
286         }
287
288         rcu_read_lock();
289         for (;; curr_fd++) {
290                 struct file *f;
291                 f = task_lookup_next_fdget_rcu(curr_task, &curr_fd);
292                 if (!f)
293                         break;
294
295                 /* set info->fd */
296                 info->fd = curr_fd;
297                 rcu_read_unlock();
298                 return f;
299         }
300
301         /* the current task is done, go to the next task */
302         rcu_read_unlock();
303         put_task_struct(curr_task);
304
305         if (info->common.type == BPF_TASK_ITER_TID) {
306                 info->task = NULL;
307                 return NULL;
308         }
309
310         info->task = NULL;
311         info->fd = 0;
312         saved_tid = ++(info->tid);
313         goto again;
314 }
315
316 static void *task_file_seq_start(struct seq_file *seq, loff_t *pos)
317 {
318         struct bpf_iter_seq_task_file_info *info = seq->private;
319         struct file *file;
320
321         info->task = NULL;
322         file = task_file_seq_get_next(info);
323         if (file && *pos == 0)
324                 ++*pos;
325
326         return file;
327 }
328
329 static void *task_file_seq_next(struct seq_file *seq, void *v, loff_t *pos)
330 {
331         struct bpf_iter_seq_task_file_info *info = seq->private;
332
333         ++*pos;
334         ++info->fd;
335         fput((struct file *)v);
336         return task_file_seq_get_next(info);
337 }
338
339 struct bpf_iter__task_file {
340         __bpf_md_ptr(struct bpf_iter_meta *, meta);
341         __bpf_md_ptr(struct task_struct *, task);
342         u32 fd __aligned(8);
343         __bpf_md_ptr(struct file *, file);
344 };
345
346 DEFINE_BPF_ITER_FUNC(task_file, struct bpf_iter_meta *meta,
347                      struct task_struct *task, u32 fd,
348                      struct file *file)
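
/* Illustrative sketch (not part of this file): a BPF-program consumer of the
 * task_file iterator defined just above, following the selftests pattern.
 * Names are hypothetical; ctx->task and ctx->file can be NULL on the stop
 * invocation.
 *
 *        SEC("iter/task_file")
 *        int dump_task_file(struct bpf_iter__task_file *ctx)
 *        {
 *                struct seq_file *seq = ctx->meta->seq;
 *                struct task_struct *task = ctx->task;
 *                struct file *file = ctx->file;
 *
 *                if (!task || !file)
 *                        return 0;
 *
 *                BPF_SEQ_PRINTF(seq, "%8d %8d %lx\n", task->tgid, ctx->fd,
 *                               (long)file->f_op);
 *                return 0;
 *        }
 */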
349
350 static int __task_file_seq_show(struct seq_file *seq, struct file *file,
351                                 bool in_stop)
352 {
353         struct bpf_iter_seq_task_file_info *info = seq->private;
354         struct bpf_iter__task_file ctx;
355         struct bpf_iter_meta meta;
356         struct bpf_prog *prog;
357
358         meta.seq = seq;
359         prog = bpf_iter_get_info(&meta, in_stop);
360         if (!prog)
361                 return 0;
362
363         ctx.meta = &meta;
364         ctx.task = info->task;
365         ctx.fd = info->fd;
366         ctx.file = file;
367         return bpf_iter_run_prog(prog, &ctx);
368 }
369
370 static int task_file_seq_show(struct seq_file *seq, void *v)
371 {
372         return __task_file_seq_show(seq, v, false);
373 }
374
375 static void task_file_seq_stop(struct seq_file *seq, void *v)
376 {
377         struct bpf_iter_seq_task_file_info *info = seq->private;
378
379         if (!v) {
380                 (void)__task_file_seq_show(seq, v, true);
381         } else {
382                 fput((struct file *)v);
383                 put_task_struct(info->task);
384                 info->task = NULL;
385         }
386 }
387
388 static int init_seq_pidns(void *priv_data, struct bpf_iter_aux_info *aux)
389 {
390         struct bpf_iter_seq_task_common *common = priv_data;
391
392         common->ns = get_pid_ns(task_active_pid_ns(current));
393         common->type = aux->task.type;
394         common->pid = aux->task.pid;
395
396         return 0;
397 }
398
399 static void fini_seq_pidns(void *priv_data)
400 {
401         struct bpf_iter_seq_task_common *common = priv_data;
402
403         put_pid_ns(common->ns);
404 }
405
406 static const struct seq_operations task_file_seq_ops = {
407         .start  = task_file_seq_start,
408         .next   = task_file_seq_next,
409         .stop   = task_file_seq_stop,
410         .show   = task_file_seq_show,
411 };
412
413 struct bpf_iter_seq_task_vma_info {
414         /* The first field must be struct bpf_iter_seq_task_common.
415          * This is assumed by the {init, fini}_seq_pidns() callback functions.
416          */
417         struct bpf_iter_seq_task_common common;
418         struct task_struct *task;
419         struct mm_struct *mm;
420         struct vm_area_struct *vma;
421         u32 tid;
422         unsigned long prev_vm_start;
423         unsigned long prev_vm_end;
424 };
425
426 enum bpf_task_vma_iter_find_op {
427         task_vma_iter_first_vma,   /* use find_vma() with addr 0 */
428         task_vma_iter_next_vma,    /* use find_vma() with curr_vma->vm_end */
429         task_vma_iter_find_vma,    /* use find_vma() to find next vma */
430 };
431
432 static struct vm_area_struct *
433 task_vma_seq_get_next(struct bpf_iter_seq_task_vma_info *info)
434 {
435         enum bpf_task_vma_iter_find_op op;
436         struct vm_area_struct *curr_vma;
437         struct task_struct *curr_task;
438         struct mm_struct *curr_mm;
439         u32 saved_tid = info->tid;
440
441         /* If this function returns a non-NULL vma, it holds a reference to
442          * the task_struct, holds a refcount on mm->mm_users, and holds
443          * read lock on vma->mm->mmap_lock.
444          * If this function returns NULL, it does not hold any reference or
445          * lock.
446          */
447         if (info->task) {
448                 curr_task = info->task;
449                 curr_vma = info->vma;
450                 curr_mm = info->mm;
451                 /* In case of lock contention, drop mmap_lock to unblock
452                  * the writer.
453                  *
454                  * After relock, call find_vma(mm, prev_vm_end - 1) to find
455                  * new vma to process.
456                  *
457                  *   +------+------+-----------+
458                  *   | VMA1 | VMA2 | VMA3      |
459                  *   +------+------+-----------+
460                  *   |      |      |           |
461                  *  4k     8k     16k         400k
462                  *
463                  * For example, curr_vma == VMA2. Before unlock, we set
464                  *
465                  *    prev_vm_start = 8k
466                  *    prev_vm_end   = 16k
467                  *
468                  * There are a few cases:
469                  *
470                  * 1) VMA2 is freed, but VMA3 exists.
471                  *
472                  *    find_vma() will return VMA3, just process VMA3.
473                  *
474                  * 2) VMA2 still exists.
475                  *
476                  *    find_vma() will return VMA2, process VMA2->next.
477                  *
478                  * 3) no more vma in this mm.
479                  *
480                  *    Process the next task.
481                  *
482                  * 4) find_vma() returns a different vma, VMA2'.
483                  *
484                  *    4.1) If VMA2' covers the same range as VMA2, skip VMA2'
485                  *         because we already covered that range;
486                  *    4.2) If VMA2 and VMA2' cover different ranges, process
487                  *         VMA2'.
488                  */
489                 if (mmap_lock_is_contended(curr_mm)) {
490                         info->prev_vm_start = curr_vma->vm_start;
491                         info->prev_vm_end = curr_vma->vm_end;
492                         op = task_vma_iter_find_vma;
493                         mmap_read_unlock(curr_mm);
494                         if (mmap_read_lock_killable(curr_mm)) {
495                                 mmput(curr_mm);
496                                 goto finish;
497                         }
498                 } else {
499                         op = task_vma_iter_next_vma;
500                 }
501         } else {
502 again:
503                 curr_task = task_seq_get_next(&info->common, &info->tid, true);
504                 if (!curr_task) {
505                         info->tid++;
506                         goto finish;
507                 }
508
509                 if (saved_tid != info->tid) {
510                         /* new task, process the first vma */
511                         op = task_vma_iter_first_vma;
512                 } else {
513                         /* Found the same tid, which means user space has
514                          * consumed the previous buffer and is reading more.
515                          * We dropped mmap_lock before returning to user
516                          * space, so it is necessary to use find_vma() to
517                          * find the next vma to process.
518                          */
519                         op = task_vma_iter_find_vma;
520                 }
521
522                 curr_mm = get_task_mm(curr_task);
523                 if (!curr_mm)
524                         goto next_task;
525
526                 if (mmap_read_lock_killable(curr_mm)) {
527                         mmput(curr_mm);
528                         goto finish;
529                 }
530         }
531
532         switch (op) {
533         case task_vma_iter_first_vma:
534                 curr_vma = find_vma(curr_mm, 0);
535                 break;
536         case task_vma_iter_next_vma:
537                 curr_vma = find_vma(curr_mm, curr_vma->vm_end);
538                 break;
539         case task_vma_iter_find_vma:
540                 /* We dropped mmap_lock so it is necessary to use find_vma
541                  * to find the next vma. This is similar to the mechanism
542                  * in show_smaps_rollup().
543                  */
544                 curr_vma = find_vma(curr_mm, info->prev_vm_end - 1);
545                 /* case 1) and 4.2) above just use curr_vma */
546
547                 /* check for case 2) or case 4.1) above */
548                 if (curr_vma &&
549                     curr_vma->vm_start == info->prev_vm_start &&
550                     curr_vma->vm_end == info->prev_vm_end)
551                         curr_vma = find_vma(curr_mm, curr_vma->vm_end);
552                 break;
553         }
554         if (!curr_vma) {
555                 /* case 3) above, or case 2) 4.1) with vma->next == NULL */
556                 mmap_read_unlock(curr_mm);
557                 mmput(curr_mm);
558                 goto next_task;
559         }
560         info->task = curr_task;
561         info->vma = curr_vma;
562         info->mm = curr_mm;
563         return curr_vma;
564
565 next_task:
566         if (info->common.type == BPF_TASK_ITER_TID)
567                 goto finish;
568
569         put_task_struct(curr_task);
570         info->task = NULL;
571         info->mm = NULL;
572         info->tid++;
573         goto again;
574
575 finish:
576         if (curr_task)
577                 put_task_struct(curr_task);
578         info->task = NULL;
579         info->vma = NULL;
580         info->mm = NULL;
581         return NULL;
582 }
583
584 static void *task_vma_seq_start(struct seq_file *seq, loff_t *pos)
585 {
586         struct bpf_iter_seq_task_vma_info *info = seq->private;
587         struct vm_area_struct *vma;
588
589         vma = task_vma_seq_get_next(info);
590         if (vma && *pos == 0)
591                 ++*pos;
592
593         return vma;
594 }
595
596 static void *task_vma_seq_next(struct seq_file *seq, void *v, loff_t *pos)
597 {
598         struct bpf_iter_seq_task_vma_info *info = seq->private;
599
600         ++*pos;
601         return task_vma_seq_get_next(info);
602 }
603
604 struct bpf_iter__task_vma {
605         __bpf_md_ptr(struct bpf_iter_meta *, meta);
606         __bpf_md_ptr(struct task_struct *, task);
607         __bpf_md_ptr(struct vm_area_struct *, vma);
608 };
609
610 DEFINE_BPF_ITER_FUNC(task_vma, struct bpf_iter_meta *meta,
611                      struct task_struct *task, struct vm_area_struct *vma)
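
/* Illustrative sketch (not part of this file): a BPF-program consumer of the
 * task_vma iterator defined just above. Names are hypothetical; ctx->task
 * and ctx->vma can be NULL on the stop invocation.
 *
 *        SEC("iter/task_vma")
 *        int dump_task_vma(struct bpf_iter__task_vma *ctx)
 *        {
 *                struct task_struct *task = ctx->task;
 *                struct vm_area_struct *vma = ctx->vma;
 *
 *                if (!task || !vma)
 *                        return 0;
 *
 *                BPF_SEQ_PRINTF(ctx->meta->seq, "%8d %lx-%lx\n", task->pid,
 *                               vma->vm_start, vma->vm_end);
 *                return 0;
 *        }
 */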
612
613 static int __task_vma_seq_show(struct seq_file *seq, bool in_stop)
614 {
615         struct bpf_iter_seq_task_vma_info *info = seq->private;
616         struct bpf_iter__task_vma ctx;
617         struct bpf_iter_meta meta;
618         struct bpf_prog *prog;
619
620         meta.seq = seq;
621         prog = bpf_iter_get_info(&meta, in_stop);
622         if (!prog)
623                 return 0;
624
625         ctx.meta = &meta;
626         ctx.task = info->task;
627         ctx.vma = info->vma;
628         return bpf_iter_run_prog(prog, &ctx);
629 }
630
631 static int task_vma_seq_show(struct seq_file *seq, void *v)
632 {
633         return __task_vma_seq_show(seq, false);
634 }
635
636 static void task_vma_seq_stop(struct seq_file *seq, void *v)
637 {
638         struct bpf_iter_seq_task_vma_info *info = seq->private;
639
640         if (!v) {
641                 (void)__task_vma_seq_show(seq, true);
642         } else {
643                 /* info->vma has not been seen by the BPF program. If the
644                  * user space reads more, task_vma_seq_get_next should
645                  * return this vma again. Set prev_vm_start to ~0UL,
646                  * so that we don't skip the vma returned by the next
647                  * find_vma() (case task_vma_iter_find_vma in
648                  * task_vma_seq_get_next()).
649                  */
650                 info->prev_vm_start = ~0UL;
651                 info->prev_vm_end = info->vma->vm_end;
652                 mmap_read_unlock(info->mm);
653                 mmput(info->mm);
654                 info->mm = NULL;
655                 put_task_struct(info->task);
656                 info->task = NULL;
657         }
658 }
659
660 static const struct seq_operations task_vma_seq_ops = {
661         .start  = task_vma_seq_start,
662         .next   = task_vma_seq_next,
663         .stop   = task_vma_seq_stop,
664         .show   = task_vma_seq_show,
665 };
666
667 static const struct bpf_iter_seq_info task_seq_info = {
668         .seq_ops                = &task_seq_ops,
669         .init_seq_private       = init_seq_pidns,
670         .fini_seq_private       = fini_seq_pidns,
671         .seq_priv_size          = sizeof(struct bpf_iter_seq_task_info),
672 };
673
674 static int bpf_iter_fill_link_info(const struct bpf_iter_aux_info *aux, struct bpf_link_info *info)
675 {
676         switch (aux->task.type) {
677         case BPF_TASK_ITER_TID:
678                 info->iter.task.tid = aux->task.pid;
679                 break;
680         case BPF_TASK_ITER_TGID:
681                 info->iter.task.pid = aux->task.pid;
682                 break;
683         default:
684                 break;
685         }
686         return 0;
687 }
688
689 static void bpf_iter_task_show_fdinfo(const struct bpf_iter_aux_info *aux, struct seq_file *seq)
690 {
691         seq_printf(seq, "task_type:\t%s\n", iter_task_type_names[aux->task.type]);
692         if (aux->task.type == BPF_TASK_ITER_TID)
693                 seq_printf(seq, "tid:\t%u\n", aux->task.pid);
694         else if (aux->task.type == BPF_TASK_ITER_TGID)
695                 seq_printf(seq, "pid:\t%u\n", aux->task.pid);
696 }
697
698 static struct bpf_iter_reg task_reg_info = {
699         .target                 = "task",
700         .attach_target          = bpf_iter_attach_task,
701         .feature                = BPF_ITER_RESCHED,
702         .ctx_arg_info_size      = 1,
703         .ctx_arg_info           = {
704                 { offsetof(struct bpf_iter__task, task),
705                   PTR_TO_BTF_ID_OR_NULL | PTR_TRUSTED },
706         },
707         .seq_info               = &task_seq_info,
708         .fill_link_info         = bpf_iter_fill_link_info,
709         .show_fdinfo            = bpf_iter_task_show_fdinfo,
710 };
711
712 static const struct bpf_iter_seq_info task_file_seq_info = {
713         .seq_ops                = &task_file_seq_ops,
714         .init_seq_private       = init_seq_pidns,
715         .fini_seq_private       = fini_seq_pidns,
716         .seq_priv_size          = sizeof(struct bpf_iter_seq_task_file_info),
717 };
718
719 static struct bpf_iter_reg task_file_reg_info = {
720         .target                 = "task_file",
721         .attach_target          = bpf_iter_attach_task,
722         .feature                = BPF_ITER_RESCHED,
723         .ctx_arg_info_size      = 2,
724         .ctx_arg_info           = {
725                 { offsetof(struct bpf_iter__task_file, task),
726                   PTR_TO_BTF_ID_OR_NULL },
727                 { offsetof(struct bpf_iter__task_file, file),
728                   PTR_TO_BTF_ID_OR_NULL },
729         },
730         .seq_info               = &task_file_seq_info,
731         .fill_link_info         = bpf_iter_fill_link_info,
732         .show_fdinfo            = bpf_iter_task_show_fdinfo,
733 };
734
735 static const struct bpf_iter_seq_info task_vma_seq_info = {
736         .seq_ops                = &task_vma_seq_ops,
737         .init_seq_private       = init_seq_pidns,
738         .fini_seq_private       = fini_seq_pidns,
739         .seq_priv_size          = sizeof(struct bpf_iter_seq_task_vma_info),
740 };
741
742 static struct bpf_iter_reg task_vma_reg_info = {
743         .target                 = "task_vma",
744         .attach_target          = bpf_iter_attach_task,
745         .feature                = BPF_ITER_RESCHED,
746         .ctx_arg_info_size      = 2,
747         .ctx_arg_info           = {
748                 { offsetof(struct bpf_iter__task_vma, task),
749                   PTR_TO_BTF_ID_OR_NULL },
750                 { offsetof(struct bpf_iter__task_vma, vma),
751                   PTR_TO_BTF_ID_OR_NULL },
752         },
753         .seq_info               = &task_vma_seq_info,
754         .fill_link_info         = bpf_iter_fill_link_info,
755         .show_fdinfo            = bpf_iter_task_show_fdinfo,
756 };
757
758 BPF_CALL_5(bpf_find_vma, struct task_struct *, task, u64, start,
759            bpf_callback_t, callback_fn, void *, callback_ctx, u64, flags)
760 {
761         struct mmap_unlock_irq_work *work = NULL;
762         struct vm_area_struct *vma;
763         bool irq_work_busy = false;
764         struct mm_struct *mm;
765         int ret = -ENOENT;
766
767         if (flags)
768                 return -EINVAL;
769
770         if (!task)
771                 return -ENOENT;
772
773         mm = task->mm;
774         if (!mm)
775                 return -ENOENT;
776
777         irq_work_busy = bpf_mmap_unlock_get_irq_work(&work);
778
779         if (irq_work_busy || !mmap_read_trylock(mm))
780                 return -EBUSY;
781
782         vma = find_vma(mm, start);
783
784         if (vma && vma->vm_start <= start && vma->vm_end > start) {
785                 callback_fn((u64)(long)task, (u64)(long)vma,
786                             (u64)(long)callback_ctx, 0, 0);
787                 ret = 0;
788         }
789         bpf_mmap_unlock_mm(work, mm);
790         return ret;
791 }
792
793 const struct bpf_func_proto bpf_find_vma_proto = {
794         .func           = bpf_find_vma,
795         .ret_type       = RET_INTEGER,
796         .arg1_type      = ARG_PTR_TO_BTF_ID,
797         .arg1_btf_id    = &btf_tracing_ids[BTF_TRACING_TYPE_TASK],
798         .arg2_type      = ARG_ANYTHING,
799         .arg3_type      = ARG_PTR_TO_FUNC,
800         .arg4_type      = ARG_PTR_TO_STACK_OR_NULL,
801         .arg5_type      = ARG_ANYTHING,
802 };
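
/* Illustrative sketch (not part of this file): calling bpf_find_vma() from a
 * BPF program. The callback prototype matches what the verifier expects for
 * the ARG_PTR_TO_FUNC argument above; the section name, program name and
 * addr_of_interest are hypothetical.
 *
 *        static long check_vma(struct task_struct *task,
 *                              struct vm_area_struct *vma, void *data)
 *        {
 *                // inspect vma; optionally report through *data
 *                return 0;
 *        }
 *
 *        SEC("fentry/...")
 *        int probe(void *ctx)
 *        {
 *                struct task_struct *task = bpf_get_current_task_btf();
 *
 *                bpf_find_vma(task, addr_of_interest, check_vma, NULL, 0);
 *                return 0;
 *        }
 */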
803
804 struct bpf_iter_task_vma_kern_data {
805         struct task_struct *task;
806         struct mm_struct *mm;
807         struct mmap_unlock_irq_work *work;
808         struct vma_iterator vmi;
809 };
810
811 struct bpf_iter_task_vma {
812         /* opaque iterator state; having __u64 here allows us to preserve correct
813          * alignment requirements in vmlinux.h, which is generated from BTF
814          */
815         __u64 __opaque[1];
816 } __attribute__((aligned(8)));
817
818 /* Non-opaque version of bpf_iter_task_vma */
819 struct bpf_iter_task_vma_kern {
820         struct bpf_iter_task_vma_kern_data *data;
821 } __attribute__((aligned(8)));
822
823 __bpf_kfunc_start_defs();
824
825 __bpf_kfunc int bpf_iter_task_vma_new(struct bpf_iter_task_vma *it,
826                                       struct task_struct *task, u64 addr)
827 {
828         struct bpf_iter_task_vma_kern *kit = (void *)it;
829         bool irq_work_busy = false;
830         int err;
831
832         BUILD_BUG_ON(sizeof(struct bpf_iter_task_vma_kern) != sizeof(struct bpf_iter_task_vma));
833         BUILD_BUG_ON(__alignof__(struct bpf_iter_task_vma_kern) != __alignof__(struct bpf_iter_task_vma));
834
835         /* is_iter_reg_valid_uninit guarantees that kit hasn't been initialized
836          * before, so non-NULL kit->data doesn't point to previously
837          * bpf_mem_alloc'd bpf_iter_task_vma_kern_data
838          */
839         kit->data = bpf_mem_alloc(&bpf_global_ma, sizeof(struct bpf_iter_task_vma_kern_data));
840         if (!kit->data)
841                 return -ENOMEM;
842
843         kit->data->task = get_task_struct(task);
844         kit->data->mm = task->mm;
845         if (!kit->data->mm) {
846                 err = -ENOENT;
847                 goto err_cleanup_iter;
848         }
849
850         /* kit->data->work == NULL is valid after bpf_mmap_unlock_get_irq_work */
851         irq_work_busy = bpf_mmap_unlock_get_irq_work(&kit->data->work);
852         if (irq_work_busy || !mmap_read_trylock(kit->data->mm)) {
853                 err = -EBUSY;
854                 goto err_cleanup_iter;
855         }
856
857         vma_iter_init(&kit->data->vmi, kit->data->mm, addr);
858         return 0;
859
860 err_cleanup_iter:
861         if (kit->data->task)
862                 put_task_struct(kit->data->task);
863         bpf_mem_free(&bpf_global_ma, kit->data);
864         /* NULL kit->data signals failed bpf_iter_task_vma initialization */
865         kit->data = NULL;
866         return err;
867 }
868
869 __bpf_kfunc struct vm_area_struct *bpf_iter_task_vma_next(struct bpf_iter_task_vma *it)
870 {
871         struct bpf_iter_task_vma_kern *kit = (void *)it;
872
873         if (!kit->data) /* bpf_iter_task_vma_new failed */
874                 return NULL;
875         return vma_next(&kit->data->vmi);
876 }
877
878 __bpf_kfunc void bpf_iter_task_vma_destroy(struct bpf_iter_task_vma *it)
879 {
880         struct bpf_iter_task_vma_kern *kit = (void *)it;
881
882         if (kit->data) {
883                 bpf_mmap_unlock_mm(kit->data->work, kit->data->mm);
884                 put_task_struct(kit->data->task);
885                 bpf_mem_free(&bpf_global_ma, kit->data);
886         }
887 }
888
889 __bpf_kfunc_end_defs();
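
/* Illustrative sketch (not part of this file): open-coded use of the
 * bpf_iter_task_vma kfuncs above from a BPF program. The task pointer must
 * be a trusted pointer (e.g. from bpf_get_current_task_btf()); names are
 * hypothetical.
 *
 *        struct bpf_iter_task_vma vma_it;
 *        struct vm_area_struct *vma;
 *
 *        bpf_iter_task_vma_new(&vma_it, task, 0);
 *        while ((vma = bpf_iter_task_vma_next(&vma_it))) {
 *                // look at vma->vm_start .. vma->vm_end
 *        }
 *        bpf_iter_task_vma_destroy(&vma_it);
 */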
890
891 #ifdef CONFIG_CGROUPS
892
893 struct bpf_iter_css_task {
894         __u64 __opaque[1];
895 } __attribute__((aligned(8)));
896
897 struct bpf_iter_css_task_kern {
898         struct css_task_iter *css_it;
899 } __attribute__((aligned(8)));
900
901 __bpf_kfunc_start_defs();
902
903 __bpf_kfunc int bpf_iter_css_task_new(struct bpf_iter_css_task *it,
904                 struct cgroup_subsys_state *css, unsigned int flags)
905 {
906         struct bpf_iter_css_task_kern *kit = (void *)it;
907
908         BUILD_BUG_ON(sizeof(struct bpf_iter_css_task_kern) != sizeof(struct bpf_iter_css_task));
909         BUILD_BUG_ON(__alignof__(struct bpf_iter_css_task_kern) !=
910                                         __alignof__(struct bpf_iter_css_task));
911         kit->css_it = NULL;
912         switch (flags) {
913         case CSS_TASK_ITER_PROCS | CSS_TASK_ITER_THREADED:
914         case CSS_TASK_ITER_PROCS:
915         case 0:
916                 break;
917         default:
918                 return -EINVAL;
919         }
920
921         kit->css_it = bpf_mem_alloc(&bpf_global_ma, sizeof(struct css_task_iter));
922         if (!kit->css_it)
923                 return -ENOMEM;
924         css_task_iter_start(css, flags, kit->css_it);
925         return 0;
926 }
927
928 __bpf_kfunc struct task_struct *bpf_iter_css_task_next(struct bpf_iter_css_task *it)
929 {
930         struct bpf_iter_css_task_kern *kit = (void *)it;
931
932         if (!kit->css_it)
933                 return NULL;
934         return css_task_iter_next(kit->css_it);
935 }
936
937 __bpf_kfunc void bpf_iter_css_task_destroy(struct bpf_iter_css_task *it)
938 {
939         struct bpf_iter_css_task_kern *kit = (void *)it;
940
941         if (!kit->css_it)
942                 return;
943         css_task_iter_end(kit->css_it);
944         bpf_mem_free(&bpf_global_ma, kit->css_it);
945 }
946
947 __bpf_kfunc_end_defs();
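
/* Illustrative sketch (not part of this file): open-coded use of the
 * bpf_iter_css_task kfuncs above. The css pointer must be trusted, and the
 * verifier additionally restricts which program contexts may use this
 * iterator; names are hypothetical.
 *
 *        struct bpf_iter_css_task css_it;
 *        struct task_struct *task;
 *
 *        bpf_iter_css_task_new(&css_it, css, CSS_TASK_ITER_PROCS);
 *        while ((task = bpf_iter_css_task_next(&css_it))) {
 *                // visits one task per process of the cgroup
 *        }
 *        bpf_iter_css_task_destroy(&css_it);
 */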
948
949 #endif /* CONFIG_CGROUPS */
950
951 struct bpf_iter_task {
952         __u64 __opaque[3];
953 } __attribute__((aligned(8)));
954
955 struct bpf_iter_task_kern {
956         struct task_struct *task;
957         struct task_struct *pos;
958         unsigned int flags;
959 } __attribute__((aligned(8)));
960
961 enum {
962         /* all processes in the system */
963         BPF_TASK_ITER_ALL_PROCS,
964         /* all threads in the system */
965         BPF_TASK_ITER_ALL_THREADS,
966         /* all threads of a specific process */
967         BPF_TASK_ITER_PROC_THREADS
968 };
969
970 __bpf_kfunc_start_defs();
971
972 __bpf_kfunc int bpf_iter_task_new(struct bpf_iter_task *it,
973                 struct task_struct *task__nullable, unsigned int flags)
974 {
975         struct bpf_iter_task_kern *kit = (void *)it;
976
977         BUILD_BUG_ON(sizeof(struct bpf_iter_task_kern) > sizeof(struct bpf_iter_task));
978         BUILD_BUG_ON(__alignof__(struct bpf_iter_task_kern) !=
979                                         __alignof__(struct bpf_iter_task));
980
981         switch (flags) {
982         case BPF_TASK_ITER_ALL_THREADS:
983         case BPF_TASK_ITER_ALL_PROCS:
984                 break;
985         case BPF_TASK_ITER_PROC_THREADS:
986                 if (!task__nullable)
987                         return -EINVAL;
988                 break;
989         default:
990                 return -EINVAL;
991         }
992
993         if (flags == BPF_TASK_ITER_PROC_THREADS)
994                 kit->task = task__nullable;
995         else
996                 kit->task = &init_task;
997         kit->pos = kit->task;
998         kit->flags = flags;
999         return 0;
1000 }
1001
1002 __bpf_kfunc struct task_struct *bpf_iter_task_next(struct bpf_iter_task *it)
1003 {
1004         struct bpf_iter_task_kern *kit = (void *)it;
1005         struct task_struct *pos;
1006         unsigned int flags;
1007
1008         flags = kit->flags;
1009         pos = kit->pos;
1010
1011         if (!pos)
1012                 return pos;
1013
1014         if (flags == BPF_TASK_ITER_ALL_PROCS)
1015                 goto get_next_task;
1016
1017         kit->pos = __next_thread(kit->pos);
1018         if (kit->pos || flags == BPF_TASK_ITER_PROC_THREADS)
1019                 return pos;
1020
1021 get_next_task:
1022         kit->task = next_task(kit->task);
1023         if (kit->task == &init_task)
1024                 kit->pos = NULL;
1025         else
1026                 kit->pos = kit->task;
1027
1028         return pos;
1029 }
1030
1031 __bpf_kfunc void bpf_iter_task_destroy(struct bpf_iter_task *it)
1032 {
1033 }
1034
1035 __bpf_kfunc_end_defs();
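
/* Illustrative sketch (not part of this file): open-coded use of the
 * bpf_iter_task kfuncs above from a BPF program. The walk relies on RCU, so
 * the loop is expected to run inside bpf_rcu_read_lock()/bpf_rcu_read_unlock();
 * names are hypothetical.
 *
 *        struct bpf_iter_task task_it;
 *        struct task_struct *t;
 *
 *        bpf_rcu_read_lock();
 *        bpf_iter_task_new(&task_it, NULL, BPF_TASK_ITER_ALL_PROCS);
 *        while ((t = bpf_iter_task_next(&task_it))) {
 *                // one group leader per process in the system
 *        }
 *        bpf_iter_task_destroy(&task_it);
 *        bpf_rcu_read_unlock();
 */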
1036
1037 DEFINE_PER_CPU(struct mmap_unlock_irq_work, mmap_unlock_work);
1038
1039 static void do_mmap_read_unlock(struct irq_work *entry)
1040 {
1041         struct mmap_unlock_irq_work *work;
1042
1043         if (WARN_ON_ONCE(IS_ENABLED(CONFIG_PREEMPT_RT)))
1044                 return;
1045
1046         work = container_of(entry, struct mmap_unlock_irq_work, irq_work);
1047         mmap_read_unlock_non_owner(work->mm);
1048 }
1049
1050 static int __init task_iter_init(void)
1051 {
1052         struct mmap_unlock_irq_work *work;
1053         int ret, cpu;
1054
1055         for_each_possible_cpu(cpu) {
1056                 work = per_cpu_ptr(&mmap_unlock_work, cpu);
1057                 init_irq_work(&work->irq_work, do_mmap_read_unlock);
1058         }
1059
1060         task_reg_info.ctx_arg_info[0].btf_id = btf_tracing_ids[BTF_TRACING_TYPE_TASK];
1061         ret = bpf_iter_reg_target(&task_reg_info);
1062         if (ret)
1063                 return ret;
1064
1065         task_file_reg_info.ctx_arg_info[0].btf_id = btf_tracing_ids[BTF_TRACING_TYPE_TASK];
1066         task_file_reg_info.ctx_arg_info[1].btf_id = btf_tracing_ids[BTF_TRACING_TYPE_FILE];
1067         ret = bpf_iter_reg_target(&task_file_reg_info);
1068         if (ret)
1069                 return ret;
1070
1071         task_vma_reg_info.ctx_arg_info[0].btf_id = btf_tracing_ids[BTF_TRACING_TYPE_TASK];
1072         task_vma_reg_info.ctx_arg_info[1].btf_id = btf_tracing_ids[BTF_TRACING_TYPE_VMA];
1073         return bpf_iter_reg_target(&task_vma_reg_info);
1074 }
1075 late_initcall(task_iter_init);