Merge tag 'for-linus-iommufd' of git://git.kernel.org/pub/scm/linux/kernel/git/jgg...
[sfrench/cifs-2.6.git] / drivers / iommu / iommufd / selftest.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES.
3  *
4  * Kernel side components to support tools/testing/selftests/iommu
5  */
6 #include <linux/slab.h>
7 #include <linux/iommu.h>
8 #include <linux/xarray.h>
9 #include <linux/file.h>
10 #include <linux/anon_inodes.h>
11 #include <linux/fault-inject.h>
12 #include <linux/platform_device.h>
13 #include <uapi/linux/iommufd.h>
14
15 #include "../iommu-priv.h"
16 #include "io_pagetable.h"
17 #include "iommufd_private.h"
18 #include "iommufd_test.h"
19
/*
 * Fault-injection attribute for this selftest. NOTE(review): presumably the
 * backing for iommufd_should_fail(), which is defined outside this view.
 */
20 static DECLARE_FAULT_ATTR(fail_iommufd);
/* NOTE(review): presumably the selftest debugfs root; set outside this view. */
21 static struct dentry *dbgfs_root;
/* Platform device for the mock IOMMU; created and used outside this view. */
22 static struct platform_device *selftest_iommu_dev;
/* Forward declarations; both ops tables are defined later in this file. */
23 static const struct iommu_ops mock_ops;
24 static struct iommu_domain_ops domain_nested_ops;
25
/*
 * Memory limit knob for the selftest, 65536 by default. NOTE(review): the unit
 * and the consumers of this value are outside this view — confirm before use.
 */
26 size_t iommufd_test_memory_limit = 65536;
27
/*
 * Bus type for mock devices, bundled with a notifier block. NOTE(review): nb
 * is not used in this part of the file; presumably registered elsewhere.
 */
28 struct mock_bus_type {
29         struct bus_type bus;
30         struct notifier_block nb;
31 };
32
/*
 * The "iommufd_mock" bus. mock_dev_create() places devices on it and
 * mock_probe_device() only accepts devices that sit on it.
 */
33 static struct mock_bus_type iommufd_mock_bus_type = {
34         .bus = {
35                 .name = "iommufd_mock",
36         },
37 };
38
/*
 * Mock device counter: incremented to number each new device in
 * mock_dev_create() and decremented in mock_dev_release().
 */
39 static atomic_t mock_dev_num;
40
/* Mock IOMMU constants and the metadata bits stored in pfns xarray values. */
41 enum {
        /* Bit in mock_iommu_domain::flags: dirty tracking is enabled. */
42         MOCK_DIRTY_TRACK = 1,
        /* Granule of the mock IO page table: half of a CPU page. */
43         MOCK_IO_PAGE_SIZE = PAGE_SIZE / 2,
44
45         /*
46          * Like a real page table alignment requires the low bits of the address
47          * to be zero. xarray also requires the high bit to be zero, so we store
48          * the pfns shifted. The upper bits are used for metadata.
49          */
50         MOCK_PFN_MASK = ULONG_MAX / MOCK_IO_PAGE_SIZE,
51
        /*
         * Metadata bits above MOCK_PFN_MASK. NOTE: START_IOVA and LAST_IOVA
         * are the same bit. mock_domain_map_pages() assigns (does not OR)
         * LAST_IOVA on the final granule, so a one-granule mapping carries
         * only this single bit; rework map/unmap before splitting them.
         */
52         _MOCK_PFN_START = MOCK_PFN_MASK + 1,
53         MOCK_PFN_START_IOVA = _MOCK_PFN_START,
54         MOCK_PFN_LAST_IOVA = _MOCK_PFN_START,
        /*
         * Entry is dirty. Cleared by mock_domain_read_and_clear_dirty()
         * unless IOMMU_DIRTY_NO_CLEAR is passed; the setter is outside this view.
         */
55         MOCK_PFN_DIRTY_IOVA = _MOCK_PFN_START << 1,
56 };
57
58 /*
59  * Syzkaller has trouble randomizing the correct iova to use since it is linked
60  * to the map ioctl's output, and it has no ide about that. So, simplify things.
61  * In syzkaller mode the 64 bit IOVA is converted into an nth area and offset
62  * value. This has a much smaller randomization space and syzkaller can hit it.
63  */
64 static unsigned long iommufd_test_syz_conv_iova(struct io_pagetable *iopt,
65                                                 u64 *iova)
66 {
67         struct syz_layout {
68                 __u32 nth_area;
69                 __u32 offset;
70         };
71         struct syz_layout *syz = (void *)iova;
72         unsigned int nth = syz->nth_area;
73         struct iopt_area *area;
74
75         down_read(&iopt->iova_rwsem);
76         for (area = iopt_area_iter_first(iopt, 0, ULONG_MAX); area;
77              area = iopt_area_iter_next(area, 0, ULONG_MAX)) {
78                 if (nth == 0) {
79                         up_read(&iopt->iova_rwsem);
80                         return iopt_area_iova(area) + syz->offset;
81                 }
82                 nth--;
83         }
84         up_read(&iopt->iova_rwsem);
85
86         return 0;
87 }
88
89 void iommufd_test_syz_conv_iova_id(struct iommufd_ucmd *ucmd,
90                                    unsigned int ioas_id, u64 *iova, u32 *flags)
91 {
92         struct iommufd_ioas *ioas;
93
94         if (!(*flags & MOCK_FLAGS_ACCESS_SYZ))
95                 return;
96         *flags &= ~(u32)MOCK_FLAGS_ACCESS_SYZ;
97
98         ioas = iommufd_get_ioas(ucmd->ictx, ioas_id);
99         if (IS_ERR(ioas))
100                 return;
101         *iova = iommufd_test_syz_conv_iova(&ioas->iopt, iova);
102         iommufd_put_object(ucmd->ictx, &ioas->obj);
103 }
104
/*
 * Mock paging domain. flags holds MOCK_DIRTY_TRACK. pfns is indexed by
 * iova / MOCK_IO_PAGE_SIZE; each value is (paddr / MOCK_IO_PAGE_SIZE) OR'd
 * with MOCK_PFN_* metadata bits (see mock_domain_map_pages()).
 */
105 struct mock_iommu_domain {
106         unsigned long flags;
107         struct iommu_domain domain;
108         struct xarray pfns;
109 };
110
/*
 * Mock nested domain. parent is the mock paging domain it was allocated on.
 * iotlb[] is a fake IOTLB seeded from the user's iommu_hwpt_selftest data and
 * zeroed by mock_domain_cache_invalidate_user().
 */
111 struct mock_iommu_domain_nested {
112         struct iommu_domain domain;
113         struct mock_iommu_domain *parent;
114         u32 iotlb[MOCK_NESTED_DOMAIN_IOTLB_NUM];
115 };
116
/* Discriminator for the union in struct selftest_obj. */
117 enum selftest_obj_type {
118         TYPE_IDEV,
119 };
120
/*
 * Device on the iommufd mock bus; flags holds the MOCK_FLAGS_DEVICE_* bits
 * passed to mock_dev_create().
 */
121 struct mock_dev {
122         struct device dev;
123         unsigned long flags;
124 };
125
/*
 * iommufd object created by iommufd_test_mock_domain(). type selects the
 * active union member (only TYPE_IDEV exists here).
 */
126 struct selftest_obj {
127         struct iommufd_object obj;
128         enum selftest_obj_type type;
129
130         union {
                /* The bound iommufd device, its context and the mock device. */
131                 struct {
132                         struct iommufd_device *idev;
133                         struct iommufd_ctx *ictx;
134                         struct mock_dev *mock_dev;
135                 } idev;
136         };
137 };
138
139 static int mock_domain_nop_attach(struct iommu_domain *domain,
140                                   struct device *dev)
141 {
142         struct mock_dev *mdev = container_of(dev, struct mock_dev, dev);
143
144         if (domain->dirty_ops && (mdev->flags & MOCK_FLAGS_DEVICE_NO_DIRTY))
145                 return -EINVAL;
146
147         return 0;
148 }
149
150 static const struct iommu_domain_ops mock_blocking_ops = {
151         .attach_dev = mock_domain_nop_attach,
152 };
153
154 static struct iommu_domain mock_blocking_domain = {
155         .type = IOMMU_DOMAIN_BLOCKED,
156         .ops = &mock_blocking_ops,
157 };
158
159 static void *mock_domain_hw_info(struct device *dev, u32 *length, u32 *type)
160 {
161         struct iommu_test_hw_info *info;
162
163         info = kzalloc(sizeof(*info), GFP_KERNEL);
164         if (!info)
165                 return ERR_PTR(-ENOMEM);
166
167         info->test_reg = IOMMU_HW_INFO_SELFTEST_REGVAL;
168         *length = sizeof(*info);
169         *type = IOMMU_HW_INFO_TYPE_SELFTEST;
170
171         return info;
172 }
173
174 static int mock_domain_set_dirty_tracking(struct iommu_domain *domain,
175                                           bool enable)
176 {
177         struct mock_iommu_domain *mock =
178                 container_of(domain, struct mock_iommu_domain, domain);
179         unsigned long flags = mock->flags;
180
181         if (enable && !domain->dirty_ops)
182                 return -EINVAL;
183
184         /* No change? */
185         if (!(enable ^ !!(flags & MOCK_DIRTY_TRACK)))
186                 return 0;
187
188         flags = (enable ? flags | MOCK_DIRTY_TRACK : flags & ~MOCK_DIRTY_TRACK);
189
190         mock->flags = flags;
191         return 0;
192 }
193
194 static int mock_domain_read_and_clear_dirty(struct iommu_domain *domain,
195                                             unsigned long iova, size_t size,
196                                             unsigned long flags,
197                                             struct iommu_dirty_bitmap *dirty)
198 {
199         struct mock_iommu_domain *mock =
200                 container_of(domain, struct mock_iommu_domain, domain);
201         unsigned long i, max = size / MOCK_IO_PAGE_SIZE;
202         void *ent, *old;
203
204         if (!(mock->flags & MOCK_DIRTY_TRACK) && dirty->bitmap)
205                 return -EINVAL;
206
207         for (i = 0; i < max; i++) {
208                 unsigned long cur = iova + i * MOCK_IO_PAGE_SIZE;
209
210                 ent = xa_load(&mock->pfns, cur / MOCK_IO_PAGE_SIZE);
211                 if (ent && (xa_to_value(ent) & MOCK_PFN_DIRTY_IOVA)) {
212                         /* Clear dirty */
213                         if (!(flags & IOMMU_DIRTY_NO_CLEAR)) {
214                                 unsigned long val;
215
216                                 val = xa_to_value(ent) & ~MOCK_PFN_DIRTY_IOVA;
217                                 old = xa_store(&mock->pfns,
218                                                cur / MOCK_IO_PAGE_SIZE,
219                                                xa_mk_value(val), GFP_KERNEL);
220                                 WARN_ON_ONCE(ent != old);
221                         }
222                         iommu_dirty_bitmap_record(dirty, cur,
223                                                   MOCK_IO_PAGE_SIZE);
224                 }
225         }
226
227         return 0;
228 }
229
230 const struct iommu_dirty_ops dirty_ops = {
231         .set_dirty_tracking = mock_domain_set_dirty_tracking,
232         .read_and_clear_dirty = mock_domain_read_and_clear_dirty,
233 };
234
235 static struct iommu_domain *mock_domain_alloc_paging(struct device *dev)
236 {
237         struct mock_iommu_domain *mock;
238
239         mock = kzalloc(sizeof(*mock), GFP_KERNEL);
240         if (!mock)
241                 return NULL;
242         mock->domain.geometry.aperture_start = MOCK_APERTURE_START;
243         mock->domain.geometry.aperture_end = MOCK_APERTURE_LAST;
244         mock->domain.pgsize_bitmap = MOCK_IO_PAGE_SIZE;
245         mock->domain.ops = mock_ops.default_domain_ops;
246         mock->domain.type = IOMMU_DOMAIN_UNMANAGED;
247         xa_init(&mock->pfns);
248         return &mock->domain;
249 }
250
251 static struct iommu_domain *
252 __mock_domain_alloc_nested(struct mock_iommu_domain *mock_parent,
253                            const struct iommu_hwpt_selftest *user_cfg)
254 {
255         struct mock_iommu_domain_nested *mock_nested;
256         int i;
257
258         mock_nested = kzalloc(sizeof(*mock_nested), GFP_KERNEL);
259         if (!mock_nested)
260                 return ERR_PTR(-ENOMEM);
261         mock_nested->parent = mock_parent;
262         mock_nested->domain.ops = &domain_nested_ops;
263         mock_nested->domain.type = IOMMU_DOMAIN_NESTED;
264         for (i = 0; i < MOCK_NESTED_DOMAIN_IOTLB_NUM; i++)
265                 mock_nested->iotlb[i] = user_cfg->iotlb;
266         return &mock_nested->domain;
267 }
268
269 static struct iommu_domain *
270 mock_domain_alloc_user(struct device *dev, u32 flags,
271                        struct iommu_domain *parent,
272                        const struct iommu_user_data *user_data)
273 {
274         struct mock_iommu_domain *mock_parent;
275         struct iommu_hwpt_selftest user_cfg;
276         int rc;
277
278         /* must be mock_domain */
279         if (!parent) {
280                 struct mock_dev *mdev = container_of(dev, struct mock_dev, dev);
281                 bool has_dirty_flag = flags & IOMMU_HWPT_ALLOC_DIRTY_TRACKING;
282                 bool no_dirty_ops = mdev->flags & MOCK_FLAGS_DEVICE_NO_DIRTY;
283                 struct iommu_domain *domain;
284
285                 if (flags & (~(IOMMU_HWPT_ALLOC_NEST_PARENT |
286                                IOMMU_HWPT_ALLOC_DIRTY_TRACKING)))
287                         return ERR_PTR(-EOPNOTSUPP);
288                 if (user_data || (has_dirty_flag && no_dirty_ops))
289                         return ERR_PTR(-EOPNOTSUPP);
290                 domain = mock_domain_alloc_paging(NULL);
291                 if (!domain)
292                         return ERR_PTR(-ENOMEM);
293                 if (has_dirty_flag)
294                         container_of(domain, struct mock_iommu_domain, domain)
295                                 ->domain.dirty_ops = &dirty_ops;
296                 return domain;
297         }
298
299         /* must be mock_domain_nested */
300         if (user_data->type != IOMMU_HWPT_DATA_SELFTEST || flags)
301                 return ERR_PTR(-EOPNOTSUPP);
302         if (!parent || parent->ops != mock_ops.default_domain_ops)
303                 return ERR_PTR(-EINVAL);
304
305         mock_parent = container_of(parent, struct mock_iommu_domain, domain);
306         if (!mock_parent)
307                 return ERR_PTR(-EINVAL);
308
309         rc = iommu_copy_struct_from_user(&user_cfg, user_data,
310                                          IOMMU_HWPT_DATA_SELFTEST, iotlb);
311         if (rc)
312                 return ERR_PTR(rc);
313
314         return __mock_domain_alloc_nested(mock_parent, &user_cfg);
315 }
316
317 static void mock_domain_free(struct iommu_domain *domain)
318 {
319         struct mock_iommu_domain *mock =
320                 container_of(domain, struct mock_iommu_domain, domain);
321
322         WARN_ON(!xa_empty(&mock->pfns));
323         kfree(mock);
324 }
325
326 static int mock_domain_map_pages(struct iommu_domain *domain,
327                                  unsigned long iova, phys_addr_t paddr,
328                                  size_t pgsize, size_t pgcount, int prot,
329                                  gfp_t gfp, size_t *mapped)
330 {
331         struct mock_iommu_domain *mock =
332                 container_of(domain, struct mock_iommu_domain, domain);
333         unsigned long flags = MOCK_PFN_START_IOVA;
334         unsigned long start_iova = iova;
335
336         /*
337          * xarray does not reliably work with fault injection because it does a
338          * retry allocation, so put our own failure point.
339          */
340         if (iommufd_should_fail())
341                 return -ENOENT;
342
343         WARN_ON(iova % MOCK_IO_PAGE_SIZE);
344         WARN_ON(pgsize % MOCK_IO_PAGE_SIZE);
345         for (; pgcount; pgcount--) {
346                 size_t cur;
347
348                 for (cur = 0; cur != pgsize; cur += MOCK_IO_PAGE_SIZE) {
349                         void *old;
350
351                         if (pgcount == 1 && cur + MOCK_IO_PAGE_SIZE == pgsize)
352                                 flags = MOCK_PFN_LAST_IOVA;
353                         old = xa_store(&mock->pfns, iova / MOCK_IO_PAGE_SIZE,
354                                        xa_mk_value((paddr / MOCK_IO_PAGE_SIZE) |
355                                                    flags),
356                                        gfp);
357                         if (xa_is_err(old)) {
358                                 for (; start_iova != iova;
359                                      start_iova += MOCK_IO_PAGE_SIZE)
360                                         xa_erase(&mock->pfns,
361                                                  start_iova /
362                                                          MOCK_IO_PAGE_SIZE);
363                                 return xa_err(old);
364                         }
365                         WARN_ON(old);
366                         iova += MOCK_IO_PAGE_SIZE;
367                         paddr += MOCK_IO_PAGE_SIZE;
368                         *mapped += MOCK_IO_PAGE_SIZE;
369                         flags = 0;
370                 }
371         }
372         return 0;
373 }
374
375 static size_t mock_domain_unmap_pages(struct iommu_domain *domain,
376                                       unsigned long iova, size_t pgsize,
377                                       size_t pgcount,
378                                       struct iommu_iotlb_gather *iotlb_gather)
379 {
380         struct mock_iommu_domain *mock =
381                 container_of(domain, struct mock_iommu_domain, domain);
382         bool first = true;
383         size_t ret = 0;
384         void *ent;
385
386         WARN_ON(iova % MOCK_IO_PAGE_SIZE);
387         WARN_ON(pgsize % MOCK_IO_PAGE_SIZE);
388
389         for (; pgcount; pgcount--) {
390                 size_t cur;
391
392                 for (cur = 0; cur != pgsize; cur += MOCK_IO_PAGE_SIZE) {
393                         ent = xa_erase(&mock->pfns, iova / MOCK_IO_PAGE_SIZE);
394
395                         /*
396                          * iommufd generates unmaps that must be a strict
397                          * superset of the map's performend So every starting
398                          * IOVA should have been an iova passed to map, and the
399                          *
400                          * First IOVA must be present and have been a first IOVA
401                          * passed to map_pages
402                          */
403                         if (first) {
404                                 WARN_ON(ent && !(xa_to_value(ent) &
405                                                  MOCK_PFN_START_IOVA));
406                                 first = false;
407                         }
408                         if (pgcount == 1 && cur + MOCK_IO_PAGE_SIZE == pgsize)
409                                 WARN_ON(ent && !(xa_to_value(ent) &
410                                                  MOCK_PFN_LAST_IOVA));
411
412                         iova += MOCK_IO_PAGE_SIZE;
413                         ret += MOCK_IO_PAGE_SIZE;
414                 }
415         }
416         return ret;
417 }
418
419 static phys_addr_t mock_domain_iova_to_phys(struct iommu_domain *domain,
420                                             dma_addr_t iova)
421 {
422         struct mock_iommu_domain *mock =
423                 container_of(domain, struct mock_iommu_domain, domain);
424         void *ent;
425
426         WARN_ON(iova % MOCK_IO_PAGE_SIZE);
427         ent = xa_load(&mock->pfns, iova / MOCK_IO_PAGE_SIZE);
428         WARN_ON(!ent);
429         return (xa_to_value(ent) & MOCK_PFN_MASK) * MOCK_IO_PAGE_SIZE;
430 }
431
432 static bool mock_domain_capable(struct device *dev, enum iommu_cap cap)
433 {
434         struct mock_dev *mdev = container_of(dev, struct mock_dev, dev);
435
436         switch (cap) {
437         case IOMMU_CAP_CACHE_COHERENCY:
438                 return true;
439         case IOMMU_CAP_DIRTY_TRACKING:
440                 return !(mdev->flags & MOCK_FLAGS_DEVICE_NO_DIRTY);
441         default:
442                 break;
443         }
444
445         return false;
446 }
447
448 static struct iommu_device mock_iommu_device = {
449 };
450
451 static struct iommu_device *mock_probe_device(struct device *dev)
452 {
453         if (dev->bus != &iommufd_mock_bus_type.bus)
454                 return ERR_PTR(-ENODEV);
455         return &mock_iommu_device;
456 }
457
458 static const struct iommu_ops mock_ops = {
459         /*
460          * IOMMU_DOMAIN_BLOCKED cannot be returned from def_domain_type()
461          * because it is zero.
462          */
463         .default_domain = &mock_blocking_domain,
464         .blocked_domain = &mock_blocking_domain,
465         .owner = THIS_MODULE,
466         .pgsize_bitmap = MOCK_IO_PAGE_SIZE,
467         .hw_info = mock_domain_hw_info,
468         .domain_alloc_paging = mock_domain_alloc_paging,
469         .domain_alloc_user = mock_domain_alloc_user,
470         .capable = mock_domain_capable,
471         .device_group = generic_device_group,
472         .probe_device = mock_probe_device,
473         .default_domain_ops =
474                 &(struct iommu_domain_ops){
475                         .free = mock_domain_free,
476                         .attach_dev = mock_domain_nop_attach,
477                         .map_pages = mock_domain_map_pages,
478                         .unmap_pages = mock_domain_unmap_pages,
479                         .iova_to_phys = mock_domain_iova_to_phys,
480                 },
481 };
482
483 static void mock_domain_free_nested(struct iommu_domain *domain)
484 {
485         struct mock_iommu_domain_nested *mock_nested =
486                 container_of(domain, struct mock_iommu_domain_nested, domain);
487
488         kfree(mock_nested);
489 }
490
491 static int
492 mock_domain_cache_invalidate_user(struct iommu_domain *domain,
493                                   struct iommu_user_data_array *array)
494 {
495         struct mock_iommu_domain_nested *mock_nested =
496                 container_of(domain, struct mock_iommu_domain_nested, domain);
497         struct iommu_hwpt_invalidate_selftest inv;
498         u32 processed = 0;
499         int i = 0, j;
500         int rc = 0;
501
502         if (array->type != IOMMU_HWPT_INVALIDATE_DATA_SELFTEST) {
503                 rc = -EINVAL;
504                 goto out;
505         }
506
507         for ( ; i < array->entry_num; i++) {
508                 rc = iommu_copy_struct_from_user_array(&inv, array,
509                                                        IOMMU_HWPT_INVALIDATE_DATA_SELFTEST,
510                                                        i, iotlb_id);
511                 if (rc)
512                         break;
513
514                 if (inv.flags & ~IOMMU_TEST_INVALIDATE_FLAG_ALL) {
515                         rc = -EOPNOTSUPP;
516                         break;
517                 }
518
519                 if (inv.iotlb_id > MOCK_NESTED_DOMAIN_IOTLB_ID_MAX) {
520                         rc = -EINVAL;
521                         break;
522                 }
523
524                 if (inv.flags & IOMMU_TEST_INVALIDATE_FLAG_ALL) {
525                         /* Invalidate all mock iotlb entries and ignore iotlb_id */
526                         for (j = 0; j < MOCK_NESTED_DOMAIN_IOTLB_NUM; j++)
527                                 mock_nested->iotlb[j] = 0;
528                 } else {
529                         mock_nested->iotlb[inv.iotlb_id] = 0;
530                 }
531
532                 processed++;
533         }
534
535 out:
536         array->entry_num = processed;
537         return rc;
538 }
539
540 static struct iommu_domain_ops domain_nested_ops = {
541         .free = mock_domain_free_nested,
542         .attach_dev = mock_domain_nop_attach,
543         .cache_invalidate_user = mock_domain_cache_invalidate_user,
544 };
545
546 static inline struct iommufd_hw_pagetable *
547 __get_md_pagetable(struct iommufd_ucmd *ucmd, u32 mockpt_id, u32 hwpt_type)
548 {
549         struct iommufd_object *obj;
550
551         obj = iommufd_get_object(ucmd->ictx, mockpt_id, hwpt_type);
552         if (IS_ERR(obj))
553                 return ERR_CAST(obj);
554         return container_of(obj, struct iommufd_hw_pagetable, obj);
555 }
556
557 static inline struct iommufd_hw_pagetable *
558 get_md_pagetable(struct iommufd_ucmd *ucmd, u32 mockpt_id,
559                  struct mock_iommu_domain **mock)
560 {
561         struct iommufd_hw_pagetable *hwpt;
562
563         hwpt = __get_md_pagetable(ucmd, mockpt_id, IOMMUFD_OBJ_HWPT_PAGING);
564         if (IS_ERR(hwpt))
565                 return hwpt;
566         if (hwpt->domain->type != IOMMU_DOMAIN_UNMANAGED ||
567             hwpt->domain->ops != mock_ops.default_domain_ops) {
568                 iommufd_put_object(ucmd->ictx, &hwpt->obj);
569                 return ERR_PTR(-EINVAL);
570         }
571         *mock = container_of(hwpt->domain, struct mock_iommu_domain, domain);
572         return hwpt;
573 }
574
575 static inline struct iommufd_hw_pagetable *
576 get_md_pagetable_nested(struct iommufd_ucmd *ucmd, u32 mockpt_id,
577                         struct mock_iommu_domain_nested **mock_nested)
578 {
579         struct iommufd_hw_pagetable *hwpt;
580
581         hwpt = __get_md_pagetable(ucmd, mockpt_id, IOMMUFD_OBJ_HWPT_NESTED);
582         if (IS_ERR(hwpt))
583                 return hwpt;
584         if (hwpt->domain->type != IOMMU_DOMAIN_NESTED ||
585             hwpt->domain->ops != &domain_nested_ops) {
586                 iommufd_put_object(ucmd->ictx, &hwpt->obj);
587                 return ERR_PTR(-EINVAL);
588         }
589         *mock_nested = container_of(hwpt->domain,
590                                     struct mock_iommu_domain_nested, domain);
591         return hwpt;
592 }
593
594 static void mock_dev_release(struct device *dev)
595 {
596         struct mock_dev *mdev = container_of(dev, struct mock_dev, dev);
597
598         atomic_dec(&mock_dev_num);
599         kfree(mdev);
600 }
601
602 static struct mock_dev *mock_dev_create(unsigned long dev_flags)
603 {
604         struct mock_dev *mdev;
605         int rc;
606
607         if (dev_flags & ~(MOCK_FLAGS_DEVICE_NO_DIRTY))
608                 return ERR_PTR(-EINVAL);
609
610         mdev = kzalloc(sizeof(*mdev), GFP_KERNEL);
611         if (!mdev)
612                 return ERR_PTR(-ENOMEM);
613
614         device_initialize(&mdev->dev);
615         mdev->flags = dev_flags;
616         mdev->dev.release = mock_dev_release;
617         mdev->dev.bus = &iommufd_mock_bus_type.bus;
618
619         rc = dev_set_name(&mdev->dev, "iommufd_mock%u",
620                           atomic_inc_return(&mock_dev_num));
621         if (rc)
622                 goto err_put;
623
624         rc = device_add(&mdev->dev);
625         if (rc)
626                 goto err_put;
627         return mdev;
628
629 err_put:
630         put_device(&mdev->dev);
631         return ERR_PTR(rc);
632 }
633
634 static void mock_dev_destroy(struct mock_dev *mdev)
635 {
636         device_unregister(&mdev->dev);
637 }
638
639 bool iommufd_selftest_is_mock_dev(struct device *dev)
640 {
641         return dev->release == mock_dev_release;
642 }
643
644 /* Create an hw_pagetable with the mock domain so we can test the domain ops */
645 static int iommufd_test_mock_domain(struct iommufd_ucmd *ucmd,
646                                     struct iommu_test_cmd *cmd)
647 {
648         struct iommufd_device *idev;
649         struct selftest_obj *sobj;
650         u32 pt_id = cmd->id;
651         u32 dev_flags = 0;
652         u32 idev_id;
653         int rc;
654
655         sobj = iommufd_object_alloc(ucmd->ictx, sobj, IOMMUFD_OBJ_SELFTEST);
656         if (IS_ERR(sobj))
657                 return PTR_ERR(sobj);
658
659         sobj->idev.ictx = ucmd->ictx;
660         sobj->type = TYPE_IDEV;
661
662         if (cmd->op == IOMMU_TEST_OP_MOCK_DOMAIN_FLAGS)
663                 dev_flags = cmd->mock_domain_flags.dev_flags;
664
665         sobj->idev.mock_dev = mock_dev_create(dev_flags);
666         if (IS_ERR(sobj->idev.mock_dev)) {
667                 rc = PTR_ERR(sobj->idev.mock_dev);
668                 goto out_sobj;
669         }
670
671         idev = iommufd_device_bind(ucmd->ictx, &sobj->idev.mock_dev->dev,
672                                    &idev_id);
673         if (IS_ERR(idev)) {
674                 rc = PTR_ERR(idev);
675                 goto out_mdev;
676         }
677         sobj->idev.idev = idev;
678
679         rc = iommufd_device_attach(idev, &pt_id);
680         if (rc)
681                 goto out_unbind;
682
683         /* Userspace must destroy the device_id to destroy the object */
684         cmd->mock_domain.out_hwpt_id = pt_id;
685         cmd->mock_domain.out_stdev_id = sobj->obj.id;
686         cmd->mock_domain.out_idev_id = idev_id;
687         rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
688         if (rc)
689                 goto out_detach;
690         iommufd_object_finalize(ucmd->ictx, &sobj->obj);
691         return 0;
692
693 out_detach:
694         iommufd_device_detach(idev);
695 out_unbind:
696         iommufd_device_unbind(idev);
697 out_mdev:
698         mock_dev_destroy(sobj->idev.mock_dev);
699 out_sobj:
700         iommufd_object_abort(ucmd->ictx, &sobj->obj);
701         return rc;
702 }
703
704 /* Replace the mock domain with a manually allocated hw_pagetable */
705 static int iommufd_test_mock_domain_replace(struct iommufd_ucmd *ucmd,
706                                             unsigned int device_id, u32 pt_id,
707                                             struct iommu_test_cmd *cmd)
708 {
709         struct iommufd_object *dev_obj;
710         struct selftest_obj *sobj;
711         int rc;
712
713         /*
714          * Prefer to use the OBJ_SELFTEST because the destroy_rwsem will ensure
715          * it doesn't race with detach, which is not allowed.
716          */
717         dev_obj =
718                 iommufd_get_object(ucmd->ictx, device_id, IOMMUFD_OBJ_SELFTEST);
719         if (IS_ERR(dev_obj))
720                 return PTR_ERR(dev_obj);
721
722         sobj = container_of(dev_obj, struct selftest_obj, obj);
723         if (sobj->type != TYPE_IDEV) {
724                 rc = -EINVAL;
725                 goto out_dev_obj;
726         }
727
728         rc = iommufd_device_replace(sobj->idev.idev, &pt_id);
729         if (rc)
730                 goto out_dev_obj;
731
732         cmd->mock_domain_replace.pt_id = pt_id;
733         rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
734
735 out_dev_obj:
736         iommufd_put_object(ucmd->ictx, dev_obj);
737         return rc;
738 }
739
740 /* Add an additional reserved IOVA to the IOAS */
741 static int iommufd_test_add_reserved(struct iommufd_ucmd *ucmd,
742                                      unsigned int mockpt_id,
743                                      unsigned long start, size_t length)
744 {
745         struct iommufd_ioas *ioas;
746         int rc;
747
748         ioas = iommufd_get_ioas(ucmd->ictx, mockpt_id);
749         if (IS_ERR(ioas))
750                 return PTR_ERR(ioas);
751         down_write(&ioas->iopt.iova_rwsem);
752         rc = iopt_reserve_iova(&ioas->iopt, start, start + length - 1, NULL);
753         up_write(&ioas->iopt.iova_rwsem);
754         iommufd_put_object(ucmd->ictx, &ioas->obj);
755         return rc;
756 }
757
758 /* Check that every pfn under each iova matches the pfn under a user VA */
759 static int iommufd_test_md_check_pa(struct iommufd_ucmd *ucmd,
760                                     unsigned int mockpt_id, unsigned long iova,
761                                     size_t length, void __user *uptr)
762 {
763         struct iommufd_hw_pagetable *hwpt;
764         struct mock_iommu_domain *mock;
765         uintptr_t end;
766         int rc;
767
768         if (iova % MOCK_IO_PAGE_SIZE || length % MOCK_IO_PAGE_SIZE ||
769             (uintptr_t)uptr % MOCK_IO_PAGE_SIZE ||
770             check_add_overflow((uintptr_t)uptr, (uintptr_t)length, &end))
771                 return -EINVAL;
772
773         hwpt = get_md_pagetable(ucmd, mockpt_id, &mock);
774         if (IS_ERR(hwpt))
775                 return PTR_ERR(hwpt);
776
777         for (; length; length -= MOCK_IO_PAGE_SIZE) {
778                 struct page *pages[1];
779                 unsigned long pfn;
780                 long npages;
781                 void *ent;
782
783                 npages = get_user_pages_fast((uintptr_t)uptr & PAGE_MASK, 1, 0,
784                                              pages);
785                 if (npages < 0) {
786                         rc = npages;
787                         goto out_put;
788                 }
789                 if (WARN_ON(npages != 1)) {
790                         rc = -EFAULT;
791                         goto out_put;
792                 }
793                 pfn = page_to_pfn(pages[0]);
794                 put_page(pages[0]);
795
796                 ent = xa_load(&mock->pfns, iova / MOCK_IO_PAGE_SIZE);
797                 if (!ent ||
798                     (xa_to_value(ent) & MOCK_PFN_MASK) * MOCK_IO_PAGE_SIZE !=
799                             pfn * PAGE_SIZE + ((uintptr_t)uptr % PAGE_SIZE)) {
800                         rc = -EINVAL;
801                         goto out_put;
802                 }
803                 iova += MOCK_IO_PAGE_SIZE;
804                 uptr += MOCK_IO_PAGE_SIZE;
805         }
806         rc = 0;
807
808 out_put:
809         iommufd_put_object(ucmd->ictx, &hwpt->obj);
810         return rc;
811 }
812
813 /* Check that the page ref count matches, to look for missing pin/unpins */
814 static int iommufd_test_md_check_refs(struct iommufd_ucmd *ucmd,
815                                       void __user *uptr, size_t length,
816                                       unsigned int refs)
817 {
818         uintptr_t end;
819
820         if (length % PAGE_SIZE || (uintptr_t)uptr % PAGE_SIZE ||
821             check_add_overflow((uintptr_t)uptr, (uintptr_t)length, &end))
822                 return -EINVAL;
823
824         for (; length; length -= PAGE_SIZE) {
825                 struct page *pages[1];
826                 long npages;
827
828                 npages = get_user_pages_fast((uintptr_t)uptr, 1, 0, pages);
829                 if (npages < 0)
830                         return npages;
831                 if (WARN_ON(npages != 1))
832                         return -EFAULT;
833                 if (!PageCompound(pages[0])) {
834                         unsigned int count;
835
836                         count = page_ref_count(pages[0]);
837                         if (count / GUP_PIN_COUNTING_BIAS != refs) {
838                                 put_page(pages[0]);
839                                 return -EIO;
840                         }
841                 }
842                 put_page(pages[0]);
843                 uptr += PAGE_SIZE;
844         }
845         return 0;
846 }
847
848 static int iommufd_test_md_check_iotlb(struct iommufd_ucmd *ucmd,
849                                        u32 mockpt_id, unsigned int iotlb_id,
850                                        u32 iotlb)
851 {
852         struct mock_iommu_domain_nested *mock_nested;
853         struct iommufd_hw_pagetable *hwpt;
854         int rc = 0;
855
856         hwpt = get_md_pagetable_nested(ucmd, mockpt_id, &mock_nested);
857         if (IS_ERR(hwpt))
858                 return PTR_ERR(hwpt);
859
860         mock_nested = container_of(hwpt->domain,
861                                    struct mock_iommu_domain_nested, domain);
862
863         if (iotlb_id > MOCK_NESTED_DOMAIN_IOTLB_ID_MAX ||
864             mock_nested->iotlb[iotlb_id] != iotlb)
865                 rc = -EINVAL;
866         iommufd_put_object(ucmd->ictx, &hwpt->obj);
867         return rc;
868 }
869
/*
 * Per-fd state for a selftest iommufd_access. The object is owned by @file
 * and is freed by iommufd_test_staccess_release() when that file goes away.
 */
struct selftest_access {
	/* iommufd access; NULL until iommufd_test_create_access() succeeds */
	struct iommufd_access *access;
	/* anon-inode file whose private_data points back at this struct */
	struct file *file;
	/* held across pin/unpin; protects @items and @next_id */
	struct mutex lock;
	/* selftest_access_item entries, linked through items_elm */
	struct list_head items;
	/* id assigned to the next recorded item */
	unsigned int next_id;
	/* NOTE(review): not referenced in this part of the file */
	bool destroying;
};
878
/* One IOVA range pinned by iommufd_test_access_pages() */
struct selftest_access_item {
	/* link in selftest_access::items */
	struct list_head items_elm;
	/* start and length of the pinned IOVA range */
	unsigned long iova;
	size_t length;
	/* handle reported to userspace as access_pages.out_access_pages_id */
	unsigned int id;
};
885
/* Defined below; declared early so iommufd_access_get() can match f_op */
static const struct file_operations iommfd_test_staccess_fops;
887
888 static struct selftest_access *iommufd_access_get(int fd)
889 {
890         struct file *file;
891
892         file = fget(fd);
893         if (!file)
894                 return ERR_PTR(-EBADFD);
895
896         if (file->f_op != &iommfd_test_staccess_fops) {
897                 fput(file);
898                 return ERR_PTR(-EBADFD);
899         }
900         return file->private_data;
901 }
902
903 static void iommufd_test_access_unmap(void *data, unsigned long iova,
904                                       unsigned long length)
905 {
906         unsigned long iova_last = iova + length - 1;
907         struct selftest_access *staccess = data;
908         struct selftest_access_item *item;
909         struct selftest_access_item *tmp;
910
911         mutex_lock(&staccess->lock);
912         list_for_each_entry_safe(item, tmp, &staccess->items, items_elm) {
913                 if (iova > item->iova + item->length - 1 ||
914                     iova_last < item->iova)
915                         continue;
916                 list_del(&item->items_elm);
917                 iommufd_access_unpin_pages(staccess->access, item->iova,
918                                            item->length);
919                 kfree(item);
920         }
921         mutex_unlock(&staccess->lock);
922 }
923
924 static int iommufd_test_access_item_destroy(struct iommufd_ucmd *ucmd,
925                                             unsigned int access_id,
926                                             unsigned int item_id)
927 {
928         struct selftest_access_item *item;
929         struct selftest_access *staccess;
930
931         staccess = iommufd_access_get(access_id);
932         if (IS_ERR(staccess))
933                 return PTR_ERR(staccess);
934
935         mutex_lock(&staccess->lock);
936         list_for_each_entry(item, &staccess->items, items_elm) {
937                 if (item->id == item_id) {
938                         list_del(&item->items_elm);
939                         iommufd_access_unpin_pages(staccess->access, item->iova,
940                                                    item->length);
941                         mutex_unlock(&staccess->lock);
942                         kfree(item);
943                         fput(staccess->file);
944                         return 0;
945                 }
946         }
947         mutex_unlock(&staccess->lock);
948         fput(staccess->file);
949         return -ENOENT;
950 }
951
952 static int iommufd_test_staccess_release(struct inode *inode,
953                                          struct file *filep)
954 {
955         struct selftest_access *staccess = filep->private_data;
956
957         if (staccess->access) {
958                 iommufd_test_access_unmap(staccess, 0, ULONG_MAX);
959                 iommufd_access_destroy(staccess->access);
960         }
961         mutex_destroy(&staccess->lock);
962         kfree(staccess);
963         return 0;
964 }
965
/*
 * Access ops selected by MOCK_FLAGS_ACCESS_CREATE_NEEDS_PIN_PAGES.
 * iommufd_test_access_pages() refuses any access not created with these.
 */
static const struct iommufd_access_ops selftest_access_ops_pin = {
	.needs_pin_pages = 1,
	.unmap = iommufd_test_access_unmap,
};

/* Default access ops: same unmap callback, needs_pin_pages left unset */
static const struct iommufd_access_ops selftest_access_ops = {
	.unmap = iommufd_test_access_unmap,
};

/*
 * Only ->release is implemented. The fd serves as a handle that
 * iommufd_access_get() maps back to the selftest_access.
 */
static const struct file_operations iommfd_test_staccess_fops = {
	.release = iommufd_test_staccess_release,
};
978
979 static struct selftest_access *iommufd_test_alloc_access(void)
980 {
981         struct selftest_access *staccess;
982         struct file *filep;
983
984         staccess = kzalloc(sizeof(*staccess), GFP_KERNEL_ACCOUNT);
985         if (!staccess)
986                 return ERR_PTR(-ENOMEM);
987         INIT_LIST_HEAD(&staccess->items);
988         mutex_init(&staccess->lock);
989
990         filep = anon_inode_getfile("[iommufd_test_staccess]",
991                                    &iommfd_test_staccess_fops, staccess,
992                                    O_RDWR);
993         if (IS_ERR(filep)) {
994                 kfree(staccess);
995                 return ERR_CAST(filep);
996         }
997         staccess->file = filep;
998         return staccess;
999 }
1000
1001 static int iommufd_test_create_access(struct iommufd_ucmd *ucmd,
1002                                       unsigned int ioas_id, unsigned int flags)
1003 {
1004         struct iommu_test_cmd *cmd = ucmd->cmd;
1005         struct selftest_access *staccess;
1006         struct iommufd_access *access;
1007         u32 id;
1008         int fdno;
1009         int rc;
1010
1011         if (flags & ~MOCK_FLAGS_ACCESS_CREATE_NEEDS_PIN_PAGES)
1012                 return -EOPNOTSUPP;
1013
1014         staccess = iommufd_test_alloc_access();
1015         if (IS_ERR(staccess))
1016                 return PTR_ERR(staccess);
1017
1018         fdno = get_unused_fd_flags(O_CLOEXEC);
1019         if (fdno < 0) {
1020                 rc = -ENOMEM;
1021                 goto out_free_staccess;
1022         }
1023
1024         access = iommufd_access_create(
1025                 ucmd->ictx,
1026                 (flags & MOCK_FLAGS_ACCESS_CREATE_NEEDS_PIN_PAGES) ?
1027                         &selftest_access_ops_pin :
1028                         &selftest_access_ops,
1029                 staccess, &id);
1030         if (IS_ERR(access)) {
1031                 rc = PTR_ERR(access);
1032                 goto out_put_fdno;
1033         }
1034         rc = iommufd_access_attach(access, ioas_id);
1035         if (rc)
1036                 goto out_destroy;
1037         cmd->create_access.out_access_fd = fdno;
1038         rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
1039         if (rc)
1040                 goto out_destroy;
1041
1042         staccess->access = access;
1043         fd_install(fdno, staccess->file);
1044         return 0;
1045
1046 out_destroy:
1047         iommufd_access_destroy(access);
1048 out_put_fdno:
1049         put_unused_fd(fdno);
1050 out_free_staccess:
1051         fput(staccess->file);
1052         return rc;
1053 }
1054
1055 static int iommufd_test_access_replace_ioas(struct iommufd_ucmd *ucmd,
1056                                             unsigned int access_id,
1057                                             unsigned int ioas_id)
1058 {
1059         struct selftest_access *staccess;
1060         int rc;
1061
1062         staccess = iommufd_access_get(access_id);
1063         if (IS_ERR(staccess))
1064                 return PTR_ERR(staccess);
1065
1066         rc = iommufd_access_replace(staccess->access, ioas_id);
1067         fput(staccess->file);
1068         return rc;
1069 }
1070
1071 /* Check that the pages in a page array match the pages in the user VA */
1072 static int iommufd_test_check_pages(void __user *uptr, struct page **pages,
1073                                     size_t npages)
1074 {
1075         for (; npages; npages--) {
1076                 struct page *tmp_pages[1];
1077                 long rc;
1078
1079                 rc = get_user_pages_fast((uintptr_t)uptr, 1, 0, tmp_pages);
1080                 if (rc < 0)
1081                         return rc;
1082                 if (WARN_ON(rc != 1))
1083                         return -EFAULT;
1084                 put_page(tmp_pages[0]);
1085                 if (tmp_pages[0] != *pages)
1086                         return -EBADE;
1087                 pages++;
1088                 uptr += PAGE_SIZE;
1089         }
1090         return 0;
1091 }
1092
/*
 * Pin [@iova, @iova + @length) through the selftest access behind fd
 * @access_id and record the range as a new item on staccess->items. The
 * item's id is reported in cmd->access_pages.out_access_pages_id and is the
 * id IOMMU_TEST_OP_DESTROY_ACCESS_PAGES takes to unpin it again.
 *
 * @flags: MOCK_FLAGS_ACCESS_WRITE requests a writable pin;
 *         MOCK_FLAGS_ACCESS_SYZ converts @iova with
 *         iommufd_test_syz_conv_iova() first.
 * @uptr:  if non-NULL, the pinned pages are compared with the pages backing
 *         @uptr, rewound by @iova's offset within its page.
 *
 * Returns 0 or a negative errno. Any failure after a successful pin unpins
 * the range again before returning.
 */
static int iommufd_test_access_pages(struct iommufd_ucmd *ucmd,
				     unsigned int access_id, unsigned long iova,
				     size_t length, void __user *uptr,
				     u32 flags)
{
	struct iommu_test_cmd *cmd = ucmd->cmd;
	struct selftest_access_item *item;
	struct selftest_access *staccess;
	struct page **pages;
	size_t npages;
	int rc;

	/* Prevent syzkaller from triggering a WARN_ON in kvzalloc() */
	if (length > 16*1024*1024)
		return -ENOMEM;

	if (flags & ~(MOCK_FLAGS_ACCESS_WRITE | MOCK_FLAGS_ACCESS_SYZ))
		return -EOPNOTSUPP;

	staccess = iommufd_access_get(access_id);
	if (IS_ERR(staccess))
		return PTR_ERR(staccess);

	/* Pinning is only allowed on accesses created with the pin ops */
	if (staccess->access->ops != &selftest_access_ops_pin) {
		rc = -EOPNOTSUPP;
		goto out_put;
	}

	if (flags & MOCK_FLAGS_ACCESS_SYZ)
		iova = iommufd_test_syz_conv_iova(&staccess->access->ioas->iopt,
					&cmd->access_pages.iova);

	/* Number of CPU pages touched by [iova, iova + length) */
	npages = (ALIGN(iova + length, PAGE_SIZE) -
		  ALIGN_DOWN(iova, PAGE_SIZE)) /
		 PAGE_SIZE;
	pages = kvcalloc(npages, sizeof(*pages), GFP_KERNEL_ACCOUNT);
	if (!pages) {
		rc = -ENOMEM;
		goto out_put;
	}

	/*
	 * Drivers will need to think very carefully about this locking. The
	 * core code can do multiple unmaps instantaneously after
	 * iommufd_access_pin_pages() and *all* the unmaps must not return until
	 * the range is unpinned. This simple implementation puts a global lock
	 * around the pin, which may not suit drivers that want this to be a
	 * performance path. drivers that get this wrong will trigger WARN_ON
	 * races and cause EDEADLOCK failures to userspace.
	 */
	mutex_lock(&staccess->lock);
	rc = iommufd_access_pin_pages(staccess->access, iova, length, pages,
				      flags & MOCK_FLAGS_ACCESS_WRITE);
	if (rc)
		goto out_unlock;

	/* For syzkaller allow uptr to be NULL to skip this check */
	if (uptr) {
		rc = iommufd_test_check_pages(
			uptr - (iova - ALIGN_DOWN(iova, PAGE_SIZE)), pages,
			npages);
		if (rc)
			goto out_unaccess;
	}

	item = kzalloc(sizeof(*item), GFP_KERNEL_ACCOUNT);
	if (!item) {
		rc = -ENOMEM;
		goto out_unaccess;
	}

	/* The pin is now tracked by the item; unmap/destroy will unpin it */
	item->iova = iova;
	item->length = length;
	item->id = staccess->next_id++;
	list_add_tail(&item->items_elm, &staccess->items);

	cmd->access_pages.out_access_pages_id = item->id;
	rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
	if (rc)
		goto out_free_item;
	/* Success: the pages array was only needed for the check above */
	goto out_unlock;

out_free_item:
	list_del(&item->items_elm);
	kfree(item);
out_unaccess:
	iommufd_access_unpin_pages(staccess->access, iova, length);
out_unlock:
	mutex_unlock(&staccess->lock);
	kvfree(pages);
out_put:
	fput(staccess->file);
	return rc;
}
1187
1188 static int iommufd_test_access_rw(struct iommufd_ucmd *ucmd,
1189                                   unsigned int access_id, unsigned long iova,
1190                                   size_t length, void __user *ubuf,
1191                                   unsigned int flags)
1192 {
1193         struct iommu_test_cmd *cmd = ucmd->cmd;
1194         struct selftest_access *staccess;
1195         void *tmp;
1196         int rc;
1197
1198         /* Prevent syzkaller from triggering a WARN_ON in kvzalloc() */
1199         if (length > 16*1024*1024)
1200                 return -ENOMEM;
1201
1202         if (flags & ~(MOCK_ACCESS_RW_WRITE | MOCK_ACCESS_RW_SLOW_PATH |
1203                       MOCK_FLAGS_ACCESS_SYZ))
1204                 return -EOPNOTSUPP;
1205
1206         staccess = iommufd_access_get(access_id);
1207         if (IS_ERR(staccess))
1208                 return PTR_ERR(staccess);
1209
1210         tmp = kvzalloc(length, GFP_KERNEL_ACCOUNT);
1211         if (!tmp) {
1212                 rc = -ENOMEM;
1213                 goto out_put;
1214         }
1215
1216         if (flags & MOCK_ACCESS_RW_WRITE) {
1217                 if (copy_from_user(tmp, ubuf, length)) {
1218                         rc = -EFAULT;
1219                         goto out_free;
1220                 }
1221         }
1222
1223         if (flags & MOCK_FLAGS_ACCESS_SYZ)
1224                 iova = iommufd_test_syz_conv_iova(&staccess->access->ioas->iopt,
1225                                         &cmd->access_rw.iova);
1226
1227         rc = iommufd_access_rw(staccess->access, iova, tmp, length, flags);
1228         if (rc)
1229                 goto out_free;
1230         if (!(flags & MOCK_ACCESS_RW_WRITE)) {
1231                 if (copy_to_user(ubuf, tmp, length)) {
1232                         rc = -EFAULT;
1233                         goto out_free;
1234                 }
1235         }
1236
1237 out_free:
1238         kvfree(tmp);
1239 out_put:
1240         fput(staccess->file);
1241         return rc;
1242 }
/*
 * iommufd_test_access_rw() forwards the MOCK_ACCESS_RW_* bits unchanged to
 * iommufd_access_rw(), so they must equal the core's flag values.
 */
static_assert((unsigned int)MOCK_ACCESS_RW_WRITE == IOMMUFD_ACCESS_RW_WRITE);
static_assert((unsigned int)MOCK_ACCESS_RW_SLOW_PATH ==
	      __IOMMUFD_ACCESS_RW_SLOW_PATH);
1246
1247 static int iommufd_test_dirty(struct iommufd_ucmd *ucmd, unsigned int mockpt_id,
1248                               unsigned long iova, size_t length,
1249                               unsigned long page_size, void __user *uptr,
1250                               u32 flags)
1251 {
1252         unsigned long bitmap_size, i, max;
1253         struct iommu_test_cmd *cmd = ucmd->cmd;
1254         struct iommufd_hw_pagetable *hwpt;
1255         struct mock_iommu_domain *mock;
1256         int rc, count = 0;
1257         void *tmp;
1258
1259         if (!page_size || !length || iova % page_size || length % page_size ||
1260             !uptr)
1261                 return -EINVAL;
1262
1263         hwpt = get_md_pagetable(ucmd, mockpt_id, &mock);
1264         if (IS_ERR(hwpt))
1265                 return PTR_ERR(hwpt);
1266
1267         if (!(mock->flags & MOCK_DIRTY_TRACK)) {
1268                 rc = -EINVAL;
1269                 goto out_put;
1270         }
1271
1272         max = length / page_size;
1273         bitmap_size = max / BITS_PER_BYTE;
1274
1275         tmp = kvzalloc(bitmap_size, GFP_KERNEL_ACCOUNT);
1276         if (!tmp) {
1277                 rc = -ENOMEM;
1278                 goto out_put;
1279         }
1280
1281         if (copy_from_user(tmp, uptr, bitmap_size)) {
1282                 rc = -EFAULT;
1283                 goto out_free;
1284         }
1285
1286         for (i = 0; i < max; i++) {
1287                 unsigned long cur = iova + i * page_size;
1288                 void *ent, *old;
1289
1290                 if (!test_bit(i, (unsigned long *)tmp))
1291                         continue;
1292
1293                 ent = xa_load(&mock->pfns, cur / page_size);
1294                 if (ent) {
1295                         unsigned long val;
1296
1297                         val = xa_to_value(ent) | MOCK_PFN_DIRTY_IOVA;
1298                         old = xa_store(&mock->pfns, cur / page_size,
1299                                        xa_mk_value(val), GFP_KERNEL);
1300                         WARN_ON_ONCE(ent != old);
1301                         count++;
1302                 }
1303         }
1304
1305         cmd->dirty.out_nr_dirty = count;
1306         rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
1307 out_free:
1308         kvfree(tmp);
1309 out_put:
1310         iommufd_put_object(ucmd->ictx, &hwpt->obj);
1311         return rc;
1312 }
1313
1314 void iommufd_selftest_destroy(struct iommufd_object *obj)
1315 {
1316         struct selftest_obj *sobj = container_of(obj, struct selftest_obj, obj);
1317
1318         switch (sobj->type) {
1319         case TYPE_IDEV:
1320                 iommufd_device_detach(sobj->idev.idev);
1321                 iommufd_device_unbind(sobj->idev.idev);
1322                 mock_dev_destroy(sobj->idev.mock_dev);
1323                 break;
1324         }
1325 }
1326
/*
 * Dispatch an iommu_test_cmd to the handler selected by cmd->op. Each op
 * unpacks its own member of the command union into the handler's
 * arguments; user pointers are converted with u64_to_user_ptr().
 *
 * Returns the handler's result, or -EOPNOTSUPP for an unknown op.
 */
int iommufd_test(struct iommufd_ucmd *ucmd)
{
	struct iommu_test_cmd *cmd = ucmd->cmd;

	switch (cmd->op) {
	case IOMMU_TEST_OP_ADD_RESERVED:
		return iommufd_test_add_reserved(ucmd, cmd->id,
						 cmd->add_reserved.start,
						 cmd->add_reserved.length);
	case IOMMU_TEST_OP_MOCK_DOMAIN:
	case IOMMU_TEST_OP_MOCK_DOMAIN_FLAGS:
		return iommufd_test_mock_domain(ucmd, cmd);
	case IOMMU_TEST_OP_MOCK_DOMAIN_REPLACE:
		return iommufd_test_mock_domain_replace(
			ucmd, cmd->id, cmd->mock_domain_replace.pt_id, cmd);
	case IOMMU_TEST_OP_MD_CHECK_MAP:
		return iommufd_test_md_check_pa(
			ucmd, cmd->id, cmd->check_map.iova,
			cmd->check_map.length,
			u64_to_user_ptr(cmd->check_map.uptr));
	case IOMMU_TEST_OP_MD_CHECK_REFS:
		return iommufd_test_md_check_refs(
			ucmd, u64_to_user_ptr(cmd->check_refs.uptr),
			cmd->check_refs.length, cmd->check_refs.refs);
	case IOMMU_TEST_OP_MD_CHECK_IOTLB:
		return iommufd_test_md_check_iotlb(ucmd, cmd->id,
						   cmd->check_iotlb.id,
						   cmd->check_iotlb.iotlb);
	case IOMMU_TEST_OP_CREATE_ACCESS:
		return iommufd_test_create_access(ucmd, cmd->id,
						  cmd->create_access.flags);
	case IOMMU_TEST_OP_ACCESS_REPLACE_IOAS:
		return iommufd_test_access_replace_ioas(
			ucmd, cmd->id, cmd->access_replace_ioas.ioas_id);
	case IOMMU_TEST_OP_ACCESS_PAGES:
		return iommufd_test_access_pages(
			ucmd, cmd->id, cmd->access_pages.iova,
			cmd->access_pages.length,
			u64_to_user_ptr(cmd->access_pages.uptr),
			cmd->access_pages.flags);
	case IOMMU_TEST_OP_ACCESS_RW:
		return iommufd_test_access_rw(
			ucmd, cmd->id, cmd->access_rw.iova,
			cmd->access_rw.length,
			u64_to_user_ptr(cmd->access_rw.uptr),
			cmd->access_rw.flags);
	case IOMMU_TEST_OP_DESTROY_ACCESS_PAGES:
		return iommufd_test_access_item_destroy(
			ucmd, cmd->id, cmd->destroy_access_pages.access_pages_id);
	case IOMMU_TEST_OP_SET_TEMP_MEMORY_LIMIT:
		/* Protect _batch_init(): the limit cannot be less than elmsz */
		if (cmd->memory_limit.limit <
		    sizeof(unsigned long) + sizeof(u32))
			return -EINVAL;
		iommufd_test_memory_limit = cmd->memory_limit.limit;
		return 0;
	case IOMMU_TEST_OP_DIRTY:
		return iommufd_test_dirty(ucmd, cmd->id, cmd->dirty.iova,
					  cmd->dirty.length,
					  cmd->dirty.page_size,
					  u64_to_user_ptr(cmd->dirty.uptr),
					  cmd->dirty.flags);
	default:
		return -EOPNOTSUPP;
	}
}
1393
/*
 * Fault-injection hook. Returns the result of should_fail() on the
 * fail_iommufd attribute (created in iommufd_test_init()) for a size of 1.
 */
bool iommufd_should_fail(void)
{
	return should_fail(&fail_iommufd, 1);
}
1398
1399 int __init iommufd_test_init(void)
1400 {
1401         struct platform_device_info pdevinfo = {
1402                 .name = "iommufd_selftest_iommu",
1403         };
1404         int rc;
1405
1406         dbgfs_root =
1407                 fault_create_debugfs_attr("fail_iommufd", NULL, &fail_iommufd);
1408
1409         selftest_iommu_dev = platform_device_register_full(&pdevinfo);
1410         if (IS_ERR(selftest_iommu_dev)) {
1411                 rc = PTR_ERR(selftest_iommu_dev);
1412                 goto err_dbgfs;
1413         }
1414
1415         rc = bus_register(&iommufd_mock_bus_type.bus);
1416         if (rc)
1417                 goto err_platform;
1418
1419         rc = iommu_device_sysfs_add(&mock_iommu_device,
1420                                     &selftest_iommu_dev->dev, NULL, "%s",
1421                                     dev_name(&selftest_iommu_dev->dev));
1422         if (rc)
1423                 goto err_bus;
1424
1425         rc = iommu_device_register_bus(&mock_iommu_device, &mock_ops,
1426                                   &iommufd_mock_bus_type.bus,
1427                                   &iommufd_mock_bus_type.nb);
1428         if (rc)
1429                 goto err_sysfs;
1430         return 0;
1431
1432 err_sysfs:
1433         iommu_device_sysfs_remove(&mock_iommu_device);
1434 err_bus:
1435         bus_unregister(&iommufd_mock_bus_type.bus);
1436 err_platform:
1437         platform_device_unregister(selftest_iommu_dev);
1438 err_dbgfs:
1439         debugfs_remove_recursive(dbgfs_root);
1440         return rc;
1441 }
1442
1443 void iommufd_test_exit(void)
1444 {
1445         iommu_device_sysfs_remove(&mock_iommu_device);
1446         iommu_device_unregister_bus(&mock_iommu_device,
1447                                     &iommufd_mock_bus_type.bus,
1448                                     &iommufd_mock_bus_type.nb);
1449         bus_unregister(&iommufd_mock_bus_type.bus);
1450         platform_device_unregister(selftest_iommu_dev);
1451         debugfs_remove_recursive(dbgfs_root);
1452 }