Merge tag 'iommu-updates-v6.7' of git://git.kernel.org/pub/scm/linux/kernel/git/joro...
[sfrench/cifs-2.6.git] / drivers / iommu / iommufd / selftest.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES.
3  *
4  * Kernel side components to support tools/testing/selftests/iommu
5  */
6 #include <linux/slab.h>
7 #include <linux/iommu.h>
8 #include <linux/xarray.h>
9 #include <linux/file.h>
10 #include <linux/anon_inodes.h>
11 #include <linux/fault-inject.h>
12 #include <linux/platform_device.h>
13 #include <uapi/linux/iommufd.h>
14
15 #include "../iommu-priv.h"
16 #include "io_pagetable.h"
17 #include "iommufd_private.h"
18 #include "iommufd_test.h"
19
/*
 * Fault-injection attribute for the selftest.
 * NOTE(review): presumably consulted by iommufd_should_fail(), whose
 * definition is outside this view.
 */
static DECLARE_FAULT_ATTR(fail_iommufd);
/* debugfs root for the selftest; set up outside this view — TODO confirm */
static struct dentry *dbgfs_root;
/* Platform device for the mock IOMMU; set up outside this view */
static struct platform_device *selftest_iommu_dev;
/*
 * Forward declarations: mock_ops and domain_nested_ops are referenced by the
 * domain allocation helpers before their definitions further down.
 */
static const struct iommu_ops mock_ops;
static struct iommu_domain_ops domain_nested_ops;

/*
 * Non-static test knob.
 * NOTE(review): presumably read by core iommufd code in selftest builds;
 * its units/semantics are not visible here.
 */
size_t iommufd_test_memory_limit = 65536;
27
enum {
	/* Bit in mock_iommu_domain::flags: dirty tracking is enabled */
	MOCK_DIRTY_TRACK = 1,
	/* The mock IO page is half of PAGE_SIZE */
	MOCK_IO_PAGE_SIZE = PAGE_SIZE / 2,

	/*
	 * As with a real page table, alignment requires the low bits of the
	 * address to be zero. xarray value entries also require the high bit
	 * to be zero, so the pfns are stored shifted (paddr / MOCK_IO_PAGE_SIZE).
	 * The bits above MOCK_PFN_MASK are used for metadata.
	 */
	MOCK_PFN_MASK = ULONG_MAX / MOCK_IO_PAGE_SIZE,

	_MOCK_PFN_START = MOCK_PFN_MASK + 1,
	/*
	 * START and LAST share a single bit. mock_domain_map_pages() replaces
	 * the START flag with LAST on the final page of a call, so a
	 * single-page mapping carries only one flag. With a shared bit, the
	 * first/last checks in mock_domain_unmap_pages() still accept it.
	 */
	MOCK_PFN_START_IOVA = _MOCK_PFN_START,
	MOCK_PFN_LAST_IOVA = _MOCK_PFN_START,
	/* Marks an entry dirty; tested/cleared by read_and_clear_dirty */
	MOCK_PFN_DIRTY_IOVA = _MOCK_PFN_START << 1,
};
44
45 /*
46  * Syzkaller has trouble randomizing the correct iova to use since it is linked
47  * to the map ioctl's output, and it has no ide about that. So, simplify things.
48  * In syzkaller mode the 64 bit IOVA is converted into an nth area and offset
49  * value. This has a much smaller randomization space and syzkaller can hit it.
50  */
51 static unsigned long iommufd_test_syz_conv_iova(struct io_pagetable *iopt,
52                                                 u64 *iova)
53 {
54         struct syz_layout {
55                 __u32 nth_area;
56                 __u32 offset;
57         };
58         struct syz_layout *syz = (void *)iova;
59         unsigned int nth = syz->nth_area;
60         struct iopt_area *area;
61
62         down_read(&iopt->iova_rwsem);
63         for (area = iopt_area_iter_first(iopt, 0, ULONG_MAX); area;
64              area = iopt_area_iter_next(area, 0, ULONG_MAX)) {
65                 if (nth == 0) {
66                         up_read(&iopt->iova_rwsem);
67                         return iopt_area_iova(area) + syz->offset;
68                 }
69                 nth--;
70         }
71         up_read(&iopt->iova_rwsem);
72
73         return 0;
74 }
75
76 void iommufd_test_syz_conv_iova_id(struct iommufd_ucmd *ucmd,
77                                    unsigned int ioas_id, u64 *iova, u32 *flags)
78 {
79         struct iommufd_ioas *ioas;
80
81         if (!(*flags & MOCK_FLAGS_ACCESS_SYZ))
82                 return;
83         *flags &= ~(u32)MOCK_FLAGS_ACCESS_SYZ;
84
85         ioas = iommufd_get_ioas(ucmd->ictx, ioas_id);
86         if (IS_ERR(ioas))
87                 return;
88         *iova = iommufd_test_syz_conv_iova(&ioas->iopt, iova);
89         iommufd_put_object(&ioas->obj);
90 }
91
/* Paging (IOMMU_DOMAIN_UNMANAGED) domain of the mock IOMMU */
struct mock_iommu_domain {
	/* MOCK_DIRTY_TRACK */
	unsigned long flags;
	/* Embedded core domain; recovered with container_of() */
	struct iommu_domain domain;
	/* Index iova / MOCK_IO_PAGE_SIZE -> xa_mk_value(pfn | MOCK_PFN_* bits) */
	struct xarray pfns;
};
97
/* Nested domain of the mock IOMMU, stacked on a mock paging domain */
struct mock_iommu_domain_nested {
	struct iommu_domain domain;
	struct mock_iommu_domain *parent;
	/* Every slot is seeded from iommu_hwpt_selftest::iotlb at allocation */
	u32 iotlb[MOCK_NESTED_DOMAIN_IOTLB_NUM];
};
103
/* Selects the active union member of struct selftest_obj */
enum selftest_obj_type {
	TYPE_IDEV,
};
107
/* A fake device on the iommufd_mock bus, created by mock_dev_create() */
struct mock_dev {
	struct device dev;
	/* MOCK_FLAGS_DEVICE_* bits */
	unsigned long flags;
};
112
/*
 * Object of type IOMMUFD_OBJ_SELFTEST. @type selects the active union member;
 * only TYPE_IDEV exists in this view.
 */
struct selftest_obj {
	struct iommufd_object obj;
	enum selftest_obj_type type;

	union {
		/* TYPE_IDEV: a mock device bound to an iommufd context */
		struct {
			struct iommufd_device *idev;
			struct iommufd_ctx *ictx;
			struct mock_dev *mock_dev;
		} idev;
	};
};
125
126 static int mock_domain_nop_attach(struct iommu_domain *domain,
127                                   struct device *dev)
128 {
129         struct mock_dev *mdev = container_of(dev, struct mock_dev, dev);
130
131         if (domain->dirty_ops && (mdev->flags & MOCK_FLAGS_DEVICE_NO_DIRTY))
132                 return -EINVAL;
133
134         return 0;
135 }
136
/* Blocking-domain ops: attach only enforces the dirty-tracking device rule */
static const struct iommu_domain_ops mock_blocking_ops = {
	.attach_dev = mock_domain_nop_attach,
};

/*
 * Statically allocated blocking domain; mock_ops uses it both as the default
 * domain and as the blocked domain.
 */
static struct iommu_domain mock_blocking_domain = {
	.type = IOMMU_DOMAIN_BLOCKED,
	.ops = &mock_blocking_ops,
};
145
146 static void *mock_domain_hw_info(struct device *dev, u32 *length, u32 *type)
147 {
148         struct iommu_test_hw_info *info;
149
150         info = kzalloc(sizeof(*info), GFP_KERNEL);
151         if (!info)
152                 return ERR_PTR(-ENOMEM);
153
154         info->test_reg = IOMMU_HW_INFO_SELFTEST_REGVAL;
155         *length = sizeof(*info);
156         *type = IOMMU_HW_INFO_TYPE_SELFTEST;
157
158         return info;
159 }
160
161 static int mock_domain_set_dirty_tracking(struct iommu_domain *domain,
162                                           bool enable)
163 {
164         struct mock_iommu_domain *mock =
165                 container_of(domain, struct mock_iommu_domain, domain);
166         unsigned long flags = mock->flags;
167
168         if (enable && !domain->dirty_ops)
169                 return -EINVAL;
170
171         /* No change? */
172         if (!(enable ^ !!(flags & MOCK_DIRTY_TRACK)))
173                 return 0;
174
175         flags = (enable ? flags | MOCK_DIRTY_TRACK : flags & ~MOCK_DIRTY_TRACK);
176
177         mock->flags = flags;
178         return 0;
179 }
180
181 static int mock_domain_read_and_clear_dirty(struct iommu_domain *domain,
182                                             unsigned long iova, size_t size,
183                                             unsigned long flags,
184                                             struct iommu_dirty_bitmap *dirty)
185 {
186         struct mock_iommu_domain *mock =
187                 container_of(domain, struct mock_iommu_domain, domain);
188         unsigned long i, max = size / MOCK_IO_PAGE_SIZE;
189         void *ent, *old;
190
191         if (!(mock->flags & MOCK_DIRTY_TRACK) && dirty->bitmap)
192                 return -EINVAL;
193
194         for (i = 0; i < max; i++) {
195                 unsigned long cur = iova + i * MOCK_IO_PAGE_SIZE;
196
197                 ent = xa_load(&mock->pfns, cur / MOCK_IO_PAGE_SIZE);
198                 if (ent && (xa_to_value(ent) & MOCK_PFN_DIRTY_IOVA)) {
199                         /* Clear dirty */
200                         if (!(flags & IOMMU_DIRTY_NO_CLEAR)) {
201                                 unsigned long val;
202
203                                 val = xa_to_value(ent) & ~MOCK_PFN_DIRTY_IOVA;
204                                 old = xa_store(&mock->pfns,
205                                                cur / MOCK_IO_PAGE_SIZE,
206                                                xa_mk_value(val), GFP_KERNEL);
207                                 WARN_ON_ONCE(ent != old);
208                         }
209                         iommu_dirty_bitmap_record(dirty, cur,
210                                                   MOCK_IO_PAGE_SIZE);
211                 }
212         }
213
214         return 0;
215 }
216
217 const struct iommu_dirty_ops dirty_ops = {
218         .set_dirty_tracking = mock_domain_set_dirty_tracking,
219         .read_and_clear_dirty = mock_domain_read_and_clear_dirty,
220 };
221
222 static struct iommu_domain *mock_domain_alloc_paging(struct device *dev)
223 {
224         struct mock_iommu_domain *mock;
225
226         mock = kzalloc(sizeof(*mock), GFP_KERNEL);
227         if (!mock)
228                 return NULL;
229         mock->domain.geometry.aperture_start = MOCK_APERTURE_START;
230         mock->domain.geometry.aperture_end = MOCK_APERTURE_LAST;
231         mock->domain.pgsize_bitmap = MOCK_IO_PAGE_SIZE;
232         mock->domain.ops = mock_ops.default_domain_ops;
233         mock->domain.type = IOMMU_DOMAIN_UNMANAGED;
234         xa_init(&mock->pfns);
235         return &mock->domain;
236 }
237
238 static struct iommu_domain *
239 __mock_domain_alloc_nested(struct mock_iommu_domain *mock_parent,
240                            const struct iommu_hwpt_selftest *user_cfg)
241 {
242         struct mock_iommu_domain_nested *mock_nested;
243         int i;
244
245         mock_nested = kzalloc(sizeof(*mock_nested), GFP_KERNEL);
246         if (!mock_nested)
247                 return ERR_PTR(-ENOMEM);
248         mock_nested->parent = mock_parent;
249         mock_nested->domain.ops = &domain_nested_ops;
250         mock_nested->domain.type = IOMMU_DOMAIN_NESTED;
251         for (i = 0; i < MOCK_NESTED_DOMAIN_IOTLB_NUM; i++)
252                 mock_nested->iotlb[i] = user_cfg->iotlb;
253         return &mock_nested->domain;
254 }
255
256 static struct iommu_domain *
257 mock_domain_alloc_user(struct device *dev, u32 flags,
258                        struct iommu_domain *parent,
259                        const struct iommu_user_data *user_data)
260 {
261         struct mock_iommu_domain *mock_parent;
262         struct iommu_hwpt_selftest user_cfg;
263         int rc;
264
265         /* must be mock_domain */
266         if (!parent) {
267                 struct mock_dev *mdev = container_of(dev, struct mock_dev, dev);
268                 bool has_dirty_flag = flags & IOMMU_HWPT_ALLOC_DIRTY_TRACKING;
269                 bool no_dirty_ops = mdev->flags & MOCK_FLAGS_DEVICE_NO_DIRTY;
270                 struct iommu_domain *domain;
271
272                 if (flags & (~(IOMMU_HWPT_ALLOC_NEST_PARENT |
273                                IOMMU_HWPT_ALLOC_DIRTY_TRACKING)))
274                         return ERR_PTR(-EOPNOTSUPP);
275                 if (user_data || (has_dirty_flag && no_dirty_ops))
276                         return ERR_PTR(-EOPNOTSUPP);
277                 domain = mock_domain_alloc_paging(NULL);
278                 if (!domain)
279                         return ERR_PTR(-ENOMEM);
280                 if (has_dirty_flag)
281                         container_of(domain, struct mock_iommu_domain, domain)
282                                 ->domain.dirty_ops = &dirty_ops;
283                 return domain;
284         }
285
286         /* must be mock_domain_nested */
287         if (user_data->type != IOMMU_HWPT_DATA_SELFTEST || flags)
288                 return ERR_PTR(-EOPNOTSUPP);
289         if (!parent || parent->ops != mock_ops.default_domain_ops)
290                 return ERR_PTR(-EINVAL);
291
292         mock_parent = container_of(parent, struct mock_iommu_domain, domain);
293         if (!mock_parent)
294                 return ERR_PTR(-EINVAL);
295
296         rc = iommu_copy_struct_from_user(&user_cfg, user_data,
297                                          IOMMU_HWPT_DATA_SELFTEST, iotlb);
298         if (rc)
299                 return ERR_PTR(rc);
300
301         return __mock_domain_alloc_nested(mock_parent, &user_cfg);
302 }
303
304 static void mock_domain_free(struct iommu_domain *domain)
305 {
306         struct mock_iommu_domain *mock =
307                 container_of(domain, struct mock_iommu_domain, domain);
308
309         WARN_ON(!xa_empty(&mock->pfns));
310         kfree(mock);
311 }
312
313 static int mock_domain_map_pages(struct iommu_domain *domain,
314                                  unsigned long iova, phys_addr_t paddr,
315                                  size_t pgsize, size_t pgcount, int prot,
316                                  gfp_t gfp, size_t *mapped)
317 {
318         struct mock_iommu_domain *mock =
319                 container_of(domain, struct mock_iommu_domain, domain);
320         unsigned long flags = MOCK_PFN_START_IOVA;
321         unsigned long start_iova = iova;
322
323         /*
324          * xarray does not reliably work with fault injection because it does a
325          * retry allocation, so put our own failure point.
326          */
327         if (iommufd_should_fail())
328                 return -ENOENT;
329
330         WARN_ON(iova % MOCK_IO_PAGE_SIZE);
331         WARN_ON(pgsize % MOCK_IO_PAGE_SIZE);
332         for (; pgcount; pgcount--) {
333                 size_t cur;
334
335                 for (cur = 0; cur != pgsize; cur += MOCK_IO_PAGE_SIZE) {
336                         void *old;
337
338                         if (pgcount == 1 && cur + MOCK_IO_PAGE_SIZE == pgsize)
339                                 flags = MOCK_PFN_LAST_IOVA;
340                         old = xa_store(&mock->pfns, iova / MOCK_IO_PAGE_SIZE,
341                                        xa_mk_value((paddr / MOCK_IO_PAGE_SIZE) |
342                                                    flags),
343                                        gfp);
344                         if (xa_is_err(old)) {
345                                 for (; start_iova != iova;
346                                      start_iova += MOCK_IO_PAGE_SIZE)
347                                         xa_erase(&mock->pfns,
348                                                  start_iova /
349                                                          MOCK_IO_PAGE_SIZE);
350                                 return xa_err(old);
351                         }
352                         WARN_ON(old);
353                         iova += MOCK_IO_PAGE_SIZE;
354                         paddr += MOCK_IO_PAGE_SIZE;
355                         *mapped += MOCK_IO_PAGE_SIZE;
356                         flags = 0;
357                 }
358         }
359         return 0;
360 }
361
362 static size_t mock_domain_unmap_pages(struct iommu_domain *domain,
363                                       unsigned long iova, size_t pgsize,
364                                       size_t pgcount,
365                                       struct iommu_iotlb_gather *iotlb_gather)
366 {
367         struct mock_iommu_domain *mock =
368                 container_of(domain, struct mock_iommu_domain, domain);
369         bool first = true;
370         size_t ret = 0;
371         void *ent;
372
373         WARN_ON(iova % MOCK_IO_PAGE_SIZE);
374         WARN_ON(pgsize % MOCK_IO_PAGE_SIZE);
375
376         for (; pgcount; pgcount--) {
377                 size_t cur;
378
379                 for (cur = 0; cur != pgsize; cur += MOCK_IO_PAGE_SIZE) {
380                         ent = xa_erase(&mock->pfns, iova / MOCK_IO_PAGE_SIZE);
381
382                         /*
383                          * iommufd generates unmaps that must be a strict
384                          * superset of the map's performend So every starting
385                          * IOVA should have been an iova passed to map, and the
386                          *
387                          * First IOVA must be present and have been a first IOVA
388                          * passed to map_pages
389                          */
390                         if (first) {
391                                 WARN_ON(ent && !(xa_to_value(ent) &
392                                                  MOCK_PFN_START_IOVA));
393                                 first = false;
394                         }
395                         if (pgcount == 1 && cur + MOCK_IO_PAGE_SIZE == pgsize)
396                                 WARN_ON(ent && !(xa_to_value(ent) &
397                                                  MOCK_PFN_LAST_IOVA));
398
399                         iova += MOCK_IO_PAGE_SIZE;
400                         ret += MOCK_IO_PAGE_SIZE;
401                 }
402         }
403         return ret;
404 }
405
406 static phys_addr_t mock_domain_iova_to_phys(struct iommu_domain *domain,
407                                             dma_addr_t iova)
408 {
409         struct mock_iommu_domain *mock =
410                 container_of(domain, struct mock_iommu_domain, domain);
411         void *ent;
412
413         WARN_ON(iova % MOCK_IO_PAGE_SIZE);
414         ent = xa_load(&mock->pfns, iova / MOCK_IO_PAGE_SIZE);
415         WARN_ON(!ent);
416         return (xa_to_value(ent) & MOCK_PFN_MASK) * MOCK_IO_PAGE_SIZE;
417 }
418
419 static bool mock_domain_capable(struct device *dev, enum iommu_cap cap)
420 {
421         struct mock_dev *mdev = container_of(dev, struct mock_dev, dev);
422
423         switch (cap) {
424         case IOMMU_CAP_CACHE_COHERENCY:
425                 return true;
426         case IOMMU_CAP_DIRTY_TRACKING:
427                 return !(mdev->flags & MOCK_FLAGS_DEVICE_NO_DIRTY);
428         default:
429                 break;
430         }
431
432         return false;
433 }
434
435 static struct iommu_device mock_iommu_device = {
436 };
437
438 static struct iommu_device *mock_probe_device(struct device *dev)
439 {
440         return &mock_iommu_device;
441 }
442
443 static const struct iommu_ops mock_ops = {
444         /*
445          * IOMMU_DOMAIN_BLOCKED cannot be returned from def_domain_type()
446          * because it is zero.
447          */
448         .default_domain = &mock_blocking_domain,
449         .blocked_domain = &mock_blocking_domain,
450         .owner = THIS_MODULE,
451         .pgsize_bitmap = MOCK_IO_PAGE_SIZE,
452         .hw_info = mock_domain_hw_info,
453         .domain_alloc_paging = mock_domain_alloc_paging,
454         .domain_alloc_user = mock_domain_alloc_user,
455         .capable = mock_domain_capable,
456         .device_group = generic_device_group,
457         .probe_device = mock_probe_device,
458         .default_domain_ops =
459                 &(struct iommu_domain_ops){
460                         .free = mock_domain_free,
461                         .attach_dev = mock_domain_nop_attach,
462                         .map_pages = mock_domain_map_pages,
463                         .unmap_pages = mock_domain_unmap_pages,
464                         .iova_to_phys = mock_domain_iova_to_phys,
465                 },
466 };
467
468 static void mock_domain_free_nested(struct iommu_domain *domain)
469 {
470         struct mock_iommu_domain_nested *mock_nested =
471                 container_of(domain, struct mock_iommu_domain_nested, domain);
472
473         kfree(mock_nested);
474 }
475
476 static struct iommu_domain_ops domain_nested_ops = {
477         .free = mock_domain_free_nested,
478         .attach_dev = mock_domain_nop_attach,
479 };
480
481 static inline struct iommufd_hw_pagetable *
482 __get_md_pagetable(struct iommufd_ucmd *ucmd, u32 mockpt_id, u32 hwpt_type)
483 {
484         struct iommufd_object *obj;
485
486         obj = iommufd_get_object(ucmd->ictx, mockpt_id, hwpt_type);
487         if (IS_ERR(obj))
488                 return ERR_CAST(obj);
489         return container_of(obj, struct iommufd_hw_pagetable, obj);
490 }
491
492 static inline struct iommufd_hw_pagetable *
493 get_md_pagetable(struct iommufd_ucmd *ucmd, u32 mockpt_id,
494                  struct mock_iommu_domain **mock)
495 {
496         struct iommufd_hw_pagetable *hwpt;
497
498         hwpt = __get_md_pagetable(ucmd, mockpt_id, IOMMUFD_OBJ_HWPT_PAGING);
499         if (IS_ERR(hwpt))
500                 return hwpt;
501         if (hwpt->domain->type != IOMMU_DOMAIN_UNMANAGED ||
502             hwpt->domain->ops != mock_ops.default_domain_ops) {
503                 iommufd_put_object(&hwpt->obj);
504                 return ERR_PTR(-EINVAL);
505         }
506         *mock = container_of(hwpt->domain, struct mock_iommu_domain, domain);
507         return hwpt;
508 }
509
510 static inline struct iommufd_hw_pagetable *
511 get_md_pagetable_nested(struct iommufd_ucmd *ucmd, u32 mockpt_id,
512                         struct mock_iommu_domain_nested **mock_nested)
513 {
514         struct iommufd_hw_pagetable *hwpt;
515
516         hwpt = __get_md_pagetable(ucmd, mockpt_id, IOMMUFD_OBJ_HWPT_NESTED);
517         if (IS_ERR(hwpt))
518                 return hwpt;
519         if (hwpt->domain->type != IOMMU_DOMAIN_NESTED ||
520             hwpt->domain->ops != &domain_nested_ops) {
521                 iommufd_put_object(&hwpt->obj);
522                 return ERR_PTR(-EINVAL);
523         }
524         *mock_nested = container_of(hwpt->domain,
525                                     struct mock_iommu_domain_nested, domain);
526         return hwpt;
527 }
528
/* A bus_type with a notifier_block kept alongside it */
struct mock_bus_type {
	struct bus_type bus;
	/* NOTE(review): registered/used outside this view — TODO confirm */
	struct notifier_block nb;
};

/* The bus every mock_dev is placed on by mock_dev_create() */
static struct mock_bus_type iommufd_mock_bus_type = {
	.bus = {
		.name = "iommufd_mock",
	},
};
539
540 static atomic_t mock_dev_num;
541
542 static void mock_dev_release(struct device *dev)
543 {
544         struct mock_dev *mdev = container_of(dev, struct mock_dev, dev);
545
546         atomic_dec(&mock_dev_num);
547         kfree(mdev);
548 }
549
550 static struct mock_dev *mock_dev_create(unsigned long dev_flags)
551 {
552         struct mock_dev *mdev;
553         int rc;
554
555         if (dev_flags & ~(MOCK_FLAGS_DEVICE_NO_DIRTY))
556                 return ERR_PTR(-EINVAL);
557
558         mdev = kzalloc(sizeof(*mdev), GFP_KERNEL);
559         if (!mdev)
560                 return ERR_PTR(-ENOMEM);
561
562         device_initialize(&mdev->dev);
563         mdev->flags = dev_flags;
564         mdev->dev.release = mock_dev_release;
565         mdev->dev.bus = &iommufd_mock_bus_type.bus;
566
567         rc = dev_set_name(&mdev->dev, "iommufd_mock%u",
568                           atomic_inc_return(&mock_dev_num));
569         if (rc)
570                 goto err_put;
571
572         rc = device_add(&mdev->dev);
573         if (rc)
574                 goto err_put;
575         return mdev;
576
577 err_put:
578         put_device(&mdev->dev);
579         return ERR_PTR(rc);
580 }
581
582 static void mock_dev_destroy(struct mock_dev *mdev)
583 {
584         device_unregister(&mdev->dev);
585 }
586
587 bool iommufd_selftest_is_mock_dev(struct device *dev)
588 {
589         return dev->release == mock_dev_release;
590 }
591
592 /* Create an hw_pagetable with the mock domain so we can test the domain ops */
593 static int iommufd_test_mock_domain(struct iommufd_ucmd *ucmd,
594                                     struct iommu_test_cmd *cmd)
595 {
596         struct iommufd_device *idev;
597         struct selftest_obj *sobj;
598         u32 pt_id = cmd->id;
599         u32 dev_flags = 0;
600         u32 idev_id;
601         int rc;
602
603         sobj = iommufd_object_alloc(ucmd->ictx, sobj, IOMMUFD_OBJ_SELFTEST);
604         if (IS_ERR(sobj))
605                 return PTR_ERR(sobj);
606
607         sobj->idev.ictx = ucmd->ictx;
608         sobj->type = TYPE_IDEV;
609
610         if (cmd->op == IOMMU_TEST_OP_MOCK_DOMAIN_FLAGS)
611                 dev_flags = cmd->mock_domain_flags.dev_flags;
612
613         sobj->idev.mock_dev = mock_dev_create(dev_flags);
614         if (IS_ERR(sobj->idev.mock_dev)) {
615                 rc = PTR_ERR(sobj->idev.mock_dev);
616                 goto out_sobj;
617         }
618
619         idev = iommufd_device_bind(ucmd->ictx, &sobj->idev.mock_dev->dev,
620                                    &idev_id);
621         if (IS_ERR(idev)) {
622                 rc = PTR_ERR(idev);
623                 goto out_mdev;
624         }
625         sobj->idev.idev = idev;
626
627         rc = iommufd_device_attach(idev, &pt_id);
628         if (rc)
629                 goto out_unbind;
630
631         /* Userspace must destroy the device_id to destroy the object */
632         cmd->mock_domain.out_hwpt_id = pt_id;
633         cmd->mock_domain.out_stdev_id = sobj->obj.id;
634         cmd->mock_domain.out_idev_id = idev_id;
635         rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
636         if (rc)
637                 goto out_detach;
638         iommufd_object_finalize(ucmd->ictx, &sobj->obj);
639         return 0;
640
641 out_detach:
642         iommufd_device_detach(idev);
643 out_unbind:
644         iommufd_device_unbind(idev);
645 out_mdev:
646         mock_dev_destroy(sobj->idev.mock_dev);
647 out_sobj:
648         iommufd_object_abort(ucmd->ictx, &sobj->obj);
649         return rc;
650 }
651
652 /* Replace the mock domain with a manually allocated hw_pagetable */
653 static int iommufd_test_mock_domain_replace(struct iommufd_ucmd *ucmd,
654                                             unsigned int device_id, u32 pt_id,
655                                             struct iommu_test_cmd *cmd)
656 {
657         struct iommufd_object *dev_obj;
658         struct selftest_obj *sobj;
659         int rc;
660
661         /*
662          * Prefer to use the OBJ_SELFTEST because the destroy_rwsem will ensure
663          * it doesn't race with detach, which is not allowed.
664          */
665         dev_obj =
666                 iommufd_get_object(ucmd->ictx, device_id, IOMMUFD_OBJ_SELFTEST);
667         if (IS_ERR(dev_obj))
668                 return PTR_ERR(dev_obj);
669
670         sobj = container_of(dev_obj, struct selftest_obj, obj);
671         if (sobj->type != TYPE_IDEV) {
672                 rc = -EINVAL;
673                 goto out_dev_obj;
674         }
675
676         rc = iommufd_device_replace(sobj->idev.idev, &pt_id);
677         if (rc)
678                 goto out_dev_obj;
679
680         cmd->mock_domain_replace.pt_id = pt_id;
681         rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
682
683 out_dev_obj:
684         iommufd_put_object(dev_obj);
685         return rc;
686 }
687
688 /* Add an additional reserved IOVA to the IOAS */
689 static int iommufd_test_add_reserved(struct iommufd_ucmd *ucmd,
690                                      unsigned int mockpt_id,
691                                      unsigned long start, size_t length)
692 {
693         struct iommufd_ioas *ioas;
694         int rc;
695
696         ioas = iommufd_get_ioas(ucmd->ictx, mockpt_id);
697         if (IS_ERR(ioas))
698                 return PTR_ERR(ioas);
699         down_write(&ioas->iopt.iova_rwsem);
700         rc = iopt_reserve_iova(&ioas->iopt, start, start + length - 1, NULL);
701         up_write(&ioas->iopt.iova_rwsem);
702         iommufd_put_object(&ioas->obj);
703         return rc;
704 }
705
706 /* Check that every pfn under each iova matches the pfn under a user VA */
707 static int iommufd_test_md_check_pa(struct iommufd_ucmd *ucmd,
708                                     unsigned int mockpt_id, unsigned long iova,
709                                     size_t length, void __user *uptr)
710 {
711         struct iommufd_hw_pagetable *hwpt;
712         struct mock_iommu_domain *mock;
713         uintptr_t end;
714         int rc;
715
716         if (iova % MOCK_IO_PAGE_SIZE || length % MOCK_IO_PAGE_SIZE ||
717             (uintptr_t)uptr % MOCK_IO_PAGE_SIZE ||
718             check_add_overflow((uintptr_t)uptr, (uintptr_t)length, &end))
719                 return -EINVAL;
720
721         hwpt = get_md_pagetable(ucmd, mockpt_id, &mock);
722         if (IS_ERR(hwpt))
723                 return PTR_ERR(hwpt);
724
725         for (; length; length -= MOCK_IO_PAGE_SIZE) {
726                 struct page *pages[1];
727                 unsigned long pfn;
728                 long npages;
729                 void *ent;
730
731                 npages = get_user_pages_fast((uintptr_t)uptr & PAGE_MASK, 1, 0,
732                                              pages);
733                 if (npages < 0) {
734                         rc = npages;
735                         goto out_put;
736                 }
737                 if (WARN_ON(npages != 1)) {
738                         rc = -EFAULT;
739                         goto out_put;
740                 }
741                 pfn = page_to_pfn(pages[0]);
742                 put_page(pages[0]);
743
744                 ent = xa_load(&mock->pfns, iova / MOCK_IO_PAGE_SIZE);
745                 if (!ent ||
746                     (xa_to_value(ent) & MOCK_PFN_MASK) * MOCK_IO_PAGE_SIZE !=
747                             pfn * PAGE_SIZE + ((uintptr_t)uptr % PAGE_SIZE)) {
748                         rc = -EINVAL;
749                         goto out_put;
750                 }
751                 iova += MOCK_IO_PAGE_SIZE;
752                 uptr += MOCK_IO_PAGE_SIZE;
753         }
754         rc = 0;
755
756 out_put:
757         iommufd_put_object(&hwpt->obj);
758         return rc;
759 }
760
761 /* Check that the page ref count matches, to look for missing pin/unpins */
762 static int iommufd_test_md_check_refs(struct iommufd_ucmd *ucmd,
763                                       void __user *uptr, size_t length,
764                                       unsigned int refs)
765 {
766         uintptr_t end;
767
768         if (length % PAGE_SIZE || (uintptr_t)uptr % PAGE_SIZE ||
769             check_add_overflow((uintptr_t)uptr, (uintptr_t)length, &end))
770                 return -EINVAL;
771
772         for (; length; length -= PAGE_SIZE) {
773                 struct page *pages[1];
774                 long npages;
775
776                 npages = get_user_pages_fast((uintptr_t)uptr, 1, 0, pages);
777                 if (npages < 0)
778                         return npages;
779                 if (WARN_ON(npages != 1))
780                         return -EFAULT;
781                 if (!PageCompound(pages[0])) {
782                         unsigned int count;
783
784                         count = page_ref_count(pages[0]);
785                         if (count / GUP_PIN_COUNTING_BIAS != refs) {
786                                 put_page(pages[0]);
787                                 return -EIO;
788                         }
789                 }
790                 put_page(pages[0]);
791                 uptr += PAGE_SIZE;
792         }
793         return 0;
794 }
795
/*
 * State behind a selftest access file descriptor: the iommufd access plus the
 * list of ranges pinned through it.
 */
struct selftest_access {
	struct iommufd_access *access;
	struct file *file;
	/* Protects @items */
	struct mutex lock;
	/* List of selftest_access_item, linked via items_elm */
	struct list_head items;
	/* NOTE(review): presumably the next item id to hand out — TODO confirm */
	unsigned int next_id;
	/* NOTE(review): not used within this view */
	bool destroying;
};
804
/* One pinned IOVA range recorded on selftest_access::items */
struct selftest_access_item {
	struct list_head items_elm;
	unsigned long iova;
	size_t length;
	/* Matched by iommufd_test_access_item_destroy() */
	unsigned int id;
};
811
812 static const struct file_operations iommfd_test_staccess_fops;
813
814 static struct selftest_access *iommufd_access_get(int fd)
815 {
816         struct file *file;
817
818         file = fget(fd);
819         if (!file)
820                 return ERR_PTR(-EBADFD);
821
822         if (file->f_op != &iommfd_test_staccess_fops) {
823                 fput(file);
824                 return ERR_PTR(-EBADFD);
825         }
826         return file->private_data;
827 }
828
829 static void iommufd_test_access_unmap(void *data, unsigned long iova,
830                                       unsigned long length)
831 {
832         unsigned long iova_last = iova + length - 1;
833         struct selftest_access *staccess = data;
834         struct selftest_access_item *item;
835         struct selftest_access_item *tmp;
836
837         mutex_lock(&staccess->lock);
838         list_for_each_entry_safe(item, tmp, &staccess->items, items_elm) {
839                 if (iova > item->iova + item->length - 1 ||
840                     iova_last < item->iova)
841                         continue;
842                 list_del(&item->items_elm);
843                 iommufd_access_unpin_pages(staccess->access, item->iova,
844                                            item->length);
845                 kfree(item);
846         }
847         mutex_unlock(&staccess->lock);
848 }
849
850 static int iommufd_test_access_item_destroy(struct iommufd_ucmd *ucmd,
851                                             unsigned int access_id,
852                                             unsigned int item_id)
853 {
854         struct selftest_access_item *item;
855         struct selftest_access *staccess;
856
857         staccess = iommufd_access_get(access_id);
858         if (IS_ERR(staccess))
859                 return PTR_ERR(staccess);
860
861         mutex_lock(&staccess->lock);
862         list_for_each_entry(item, &staccess->items, items_elm) {
863                 if (item->id == item_id) {
864                         list_del(&item->items_elm);
865                         iommufd_access_unpin_pages(staccess->access, item->iova,
866                                                    item->length);
867                         mutex_unlock(&staccess->lock);
868                         kfree(item);
869                         fput(staccess->file);
870                         return 0;
871                 }
872         }
873         mutex_unlock(&staccess->lock);
874         fput(staccess->file);
875         return -ENOENT;
876 }
877
878 static int iommufd_test_staccess_release(struct inode *inode,
879                                          struct file *filep)
880 {
881         struct selftest_access *staccess = filep->private_data;
882
883         if (staccess->access) {
884                 iommufd_test_access_unmap(staccess, 0, ULONG_MAX);
885                 iommufd_access_destroy(staccess->access);
886         }
887         mutex_destroy(&staccess->lock);
888         kfree(staccess);
889         return 0;
890 }
891
/*
 * Ops for accesses created with MOCK_FLAGS_ACCESS_CREATE_NEEDS_PIN_PAGES;
 * only these may be used with IOMMU_TEST_OP_ACCESS_PAGES.
 */
static const struct iommufd_access_ops selftest_access_ops_pin = {
	.needs_pin_pages = 1,
	.unmap = iommufd_test_access_unmap,
};

/* Ops for accesses created without the pin-pages flag */
static const struct iommufd_access_ops selftest_access_ops = {
	.unmap = iommufd_test_access_unmap,
};

/* Access fds only implement release, which tears the access down */
static const struct file_operations iommfd_test_staccess_fops = {
	.release = iommufd_test_staccess_release,
};
904
905 static struct selftest_access *iommufd_test_alloc_access(void)
906 {
907         struct selftest_access *staccess;
908         struct file *filep;
909
910         staccess = kzalloc(sizeof(*staccess), GFP_KERNEL_ACCOUNT);
911         if (!staccess)
912                 return ERR_PTR(-ENOMEM);
913         INIT_LIST_HEAD(&staccess->items);
914         mutex_init(&staccess->lock);
915
916         filep = anon_inode_getfile("[iommufd_test_staccess]",
917                                    &iommfd_test_staccess_fops, staccess,
918                                    O_RDWR);
919         if (IS_ERR(filep)) {
920                 kfree(staccess);
921                 return ERR_CAST(filep);
922         }
923         staccess->file = filep;
924         return staccess;
925 }
926
927 static int iommufd_test_create_access(struct iommufd_ucmd *ucmd,
928                                       unsigned int ioas_id, unsigned int flags)
929 {
930         struct iommu_test_cmd *cmd = ucmd->cmd;
931         struct selftest_access *staccess;
932         struct iommufd_access *access;
933         u32 id;
934         int fdno;
935         int rc;
936
937         if (flags & ~MOCK_FLAGS_ACCESS_CREATE_NEEDS_PIN_PAGES)
938                 return -EOPNOTSUPP;
939
940         staccess = iommufd_test_alloc_access();
941         if (IS_ERR(staccess))
942                 return PTR_ERR(staccess);
943
944         fdno = get_unused_fd_flags(O_CLOEXEC);
945         if (fdno < 0) {
946                 rc = -ENOMEM;
947                 goto out_free_staccess;
948         }
949
950         access = iommufd_access_create(
951                 ucmd->ictx,
952                 (flags & MOCK_FLAGS_ACCESS_CREATE_NEEDS_PIN_PAGES) ?
953                         &selftest_access_ops_pin :
954                         &selftest_access_ops,
955                 staccess, &id);
956         if (IS_ERR(access)) {
957                 rc = PTR_ERR(access);
958                 goto out_put_fdno;
959         }
960         rc = iommufd_access_attach(access, ioas_id);
961         if (rc)
962                 goto out_destroy;
963         cmd->create_access.out_access_fd = fdno;
964         rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
965         if (rc)
966                 goto out_destroy;
967
968         staccess->access = access;
969         fd_install(fdno, staccess->file);
970         return 0;
971
972 out_destroy:
973         iommufd_access_destroy(access);
974 out_put_fdno:
975         put_unused_fd(fdno);
976 out_free_staccess:
977         fput(staccess->file);
978         return rc;
979 }
980
981 static int iommufd_test_access_replace_ioas(struct iommufd_ucmd *ucmd,
982                                             unsigned int access_id,
983                                             unsigned int ioas_id)
984 {
985         struct selftest_access *staccess;
986         int rc;
987
988         staccess = iommufd_access_get(access_id);
989         if (IS_ERR(staccess))
990                 return PTR_ERR(staccess);
991
992         rc = iommufd_access_replace(staccess->access, ioas_id);
993         fput(staccess->file);
994         return rc;
995 }
996
997 /* Check that the pages in a page array match the pages in the user VA */
998 static int iommufd_test_check_pages(void __user *uptr, struct page **pages,
999                                     size_t npages)
1000 {
1001         for (; npages; npages--) {
1002                 struct page *tmp_pages[1];
1003                 long rc;
1004
1005                 rc = get_user_pages_fast((uintptr_t)uptr, 1, 0, tmp_pages);
1006                 if (rc < 0)
1007                         return rc;
1008                 if (WARN_ON(rc != 1))
1009                         return -EFAULT;
1010                 put_page(tmp_pages[0]);
1011                 if (tmp_pages[0] != *pages)
1012                         return -EBADE;
1013                 pages++;
1014                 uptr += PAGE_SIZE;
1015         }
1016         return 0;
1017 }
1018
/*
 * IOMMU_TEST_OP_ACCESS_PAGES: pin [@iova, @iova + @length) through the access
 * behind fd @access_id and record the range as a new item on staccess->items.
 *
 * The access must use selftest_access_ops_pin; otherwise -EOPNOTSUPP.
 * MOCK_FLAGS_ACCESS_WRITE is forwarded as the write argument of
 * iommufd_access_pin_pages(). MOCK_FLAGS_ACCESS_SYZ remaps @iova with
 * iommufd_test_syz_conv_iova(). If @uptr is non-NULL, the pinned pages are
 * compared against the pages backing @uptr.
 *
 * On success the new item id is returned in
 * cmd->access_pages.out_access_pages_id. The range stays pinned until the
 * item is removed by IOMMU_TEST_OP_DESTROY_ACCESS_PAGES, by the unmap
 * callback, or by closing the access fd.
 */
static int iommufd_test_access_pages(struct iommufd_ucmd *ucmd,
				     unsigned int access_id, unsigned long iova,
				     size_t length, void __user *uptr,
				     u32 flags)
{
	struct iommu_test_cmd *cmd = ucmd->cmd;
	struct selftest_access_item *item;
	struct selftest_access *staccess;
	struct page **pages;
	size_t npages;
	int rc;

	/* Prevent syzkaller from triggering a WARN_ON in kvzalloc() */
	if (length > 16*1024*1024)
		return -ENOMEM;

	if (flags & ~(MOCK_FLAGS_ACCESS_WRITE | MOCK_FLAGS_ACCESS_SYZ))
		return -EOPNOTSUPP;

	staccess = iommufd_access_get(access_id);
	if (IS_ERR(staccess))
		return PTR_ERR(staccess);

	/* Only accesses created with NEEDS_PIN_PAGES may pin */
	if (staccess->access->ops != &selftest_access_ops_pin) {
		rc = -EOPNOTSUPP;
		goto out_put;
	}

	if (flags & MOCK_FLAGS_ACCESS_SYZ)
		iova = iommufd_test_syz_conv_iova(&staccess->access->ioas->iopt,
					&cmd->access_pages.iova);

	/* Whole PAGE_SIZE pages spanned by the range, incl. partial ends */
	npages = (ALIGN(iova + length, PAGE_SIZE) -
		  ALIGN_DOWN(iova, PAGE_SIZE)) /
		 PAGE_SIZE;
	pages = kvcalloc(npages, sizeof(*pages), GFP_KERNEL_ACCOUNT);
	if (!pages) {
		rc = -ENOMEM;
		goto out_put;
	}

	/*
	 * Drivers will need to think very carefully about this locking. The
	 * core code can do multiple unmaps instantaneously after
	 * iommufd_access_pin_pages() and *all* the unmaps must not return until
	 * the range is unpinned. This simple implementation puts a global lock
	 * around the pin, which may not suit drivers that want this to be a
	 * performance path. Drivers that get this wrong will trigger WARN_ON
	 * races and cause EDEADLOCK failures to userspace.
	 */
	mutex_lock(&staccess->lock);
	rc = iommufd_access_pin_pages(staccess->access, iova, length, pages,
				      flags & MOCK_FLAGS_ACCESS_WRITE);
	if (rc)
		goto out_unlock;

	/* For syzkaller allow uptr to be NULL to skip this check */
	if (uptr) {
		/* Step uptr back by iova's in-page offset to align with pages[0] */
		rc = iommufd_test_check_pages(
			uptr - (iova - ALIGN_DOWN(iova, PAGE_SIZE)), pages,
			npages);
		if (rc)
			goto out_unaccess;
	}

	item = kzalloc(sizeof(*item), GFP_KERNEL_ACCOUNT);
	if (!item) {
		rc = -ENOMEM;
		goto out_unaccess;
	}

	/* The pin is tracked by (iova, length); the pages array is scratch */
	item->iova = iova;
	item->length = length;
	item->id = staccess->next_id++;
	list_add_tail(&item->items_elm, &staccess->items);

	cmd->access_pages.out_access_pages_id = item->id;
	rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
	if (rc)
		goto out_free_item;
	goto out_unlock;

	/* Error unwinding, in reverse order of setup */
out_free_item:
	list_del(&item->items_elm);
	kfree(item);
out_unaccess:
	iommufd_access_unpin_pages(staccess->access, iova, length);
out_unlock:
	mutex_unlock(&staccess->lock);
	kvfree(pages);
out_put:
	fput(staccess->file);
	return rc;
}
1113
1114 static int iommufd_test_access_rw(struct iommufd_ucmd *ucmd,
1115                                   unsigned int access_id, unsigned long iova,
1116                                   size_t length, void __user *ubuf,
1117                                   unsigned int flags)
1118 {
1119         struct iommu_test_cmd *cmd = ucmd->cmd;
1120         struct selftest_access *staccess;
1121         void *tmp;
1122         int rc;
1123
1124         /* Prevent syzkaller from triggering a WARN_ON in kvzalloc() */
1125         if (length > 16*1024*1024)
1126                 return -ENOMEM;
1127
1128         if (flags & ~(MOCK_ACCESS_RW_WRITE | MOCK_ACCESS_RW_SLOW_PATH |
1129                       MOCK_FLAGS_ACCESS_SYZ))
1130                 return -EOPNOTSUPP;
1131
1132         staccess = iommufd_access_get(access_id);
1133         if (IS_ERR(staccess))
1134                 return PTR_ERR(staccess);
1135
1136         tmp = kvzalloc(length, GFP_KERNEL_ACCOUNT);
1137         if (!tmp) {
1138                 rc = -ENOMEM;
1139                 goto out_put;
1140         }
1141
1142         if (flags & MOCK_ACCESS_RW_WRITE) {
1143                 if (copy_from_user(tmp, ubuf, length)) {
1144                         rc = -EFAULT;
1145                         goto out_free;
1146                 }
1147         }
1148
1149         if (flags & MOCK_FLAGS_ACCESS_SYZ)
1150                 iova = iommufd_test_syz_conv_iova(&staccess->access->ioas->iopt,
1151                                         &cmd->access_rw.iova);
1152
1153         rc = iommufd_access_rw(staccess->access, iova, tmp, length, flags);
1154         if (rc)
1155                 goto out_free;
1156         if (!(flags & MOCK_ACCESS_RW_WRITE)) {
1157                 if (copy_to_user(ubuf, tmp, length)) {
1158                         rc = -EFAULT;
1159                         goto out_free;
1160                 }
1161         }
1162
1163 out_free:
1164         kvfree(tmp);
1165 out_put:
1166         fput(staccess->file);
1167         return rc;
1168 }
/*
 * iommufd_test_access_rw() hands the user's flags straight to
 * iommufd_access_rw(), so the mock flag values must equal the core ones.
 */
static_assert((unsigned int)MOCK_ACCESS_RW_WRITE == IOMMUFD_ACCESS_RW_WRITE);
static_assert((unsigned int)MOCK_ACCESS_RW_SLOW_PATH ==
	      __IOMMUFD_ACCESS_RW_SLOW_PATH);
1172
1173 static int iommufd_test_dirty(struct iommufd_ucmd *ucmd, unsigned int mockpt_id,
1174                               unsigned long iova, size_t length,
1175                               unsigned long page_size, void __user *uptr,
1176                               u32 flags)
1177 {
1178         unsigned long bitmap_size, i, max;
1179         struct iommu_test_cmd *cmd = ucmd->cmd;
1180         struct iommufd_hw_pagetable *hwpt;
1181         struct mock_iommu_domain *mock;
1182         int rc, count = 0;
1183         void *tmp;
1184
1185         if (!page_size || !length || iova % page_size || length % page_size ||
1186             !uptr)
1187                 return -EINVAL;
1188
1189         hwpt = get_md_pagetable(ucmd, mockpt_id, &mock);
1190         if (IS_ERR(hwpt))
1191                 return PTR_ERR(hwpt);
1192
1193         if (!(mock->flags & MOCK_DIRTY_TRACK)) {
1194                 rc = -EINVAL;
1195                 goto out_put;
1196         }
1197
1198         max = length / page_size;
1199         bitmap_size = max / BITS_PER_BYTE;
1200
1201         tmp = kvzalloc(bitmap_size, GFP_KERNEL_ACCOUNT);
1202         if (!tmp) {
1203                 rc = -ENOMEM;
1204                 goto out_put;
1205         }
1206
1207         if (copy_from_user(tmp, uptr, bitmap_size)) {
1208                 rc = -EFAULT;
1209                 goto out_free;
1210         }
1211
1212         for (i = 0; i < max; i++) {
1213                 unsigned long cur = iova + i * page_size;
1214                 void *ent, *old;
1215
1216                 if (!test_bit(i, (unsigned long *)tmp))
1217                         continue;
1218
1219                 ent = xa_load(&mock->pfns, cur / page_size);
1220                 if (ent) {
1221                         unsigned long val;
1222
1223                         val = xa_to_value(ent) | MOCK_PFN_DIRTY_IOVA;
1224                         old = xa_store(&mock->pfns, cur / page_size,
1225                                        xa_mk_value(val), GFP_KERNEL);
1226                         WARN_ON_ONCE(ent != old);
1227                         count++;
1228                 }
1229         }
1230
1231         cmd->dirty.out_nr_dirty = count;
1232         rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
1233 out_free:
1234         kvfree(tmp);
1235 out_put:
1236         iommufd_put_object(&hwpt->obj);
1237         return rc;
1238 }
1239
1240 void iommufd_selftest_destroy(struct iommufd_object *obj)
1241 {
1242         struct selftest_obj *sobj = container_of(obj, struct selftest_obj, obj);
1243
1244         switch (sobj->type) {
1245         case TYPE_IDEV:
1246                 iommufd_device_detach(sobj->idev.idev);
1247                 iommufd_device_unbind(sobj->idev.idev);
1248                 mock_dev_destroy(sobj->idev.mock_dev);
1249                 break;
1250         }
1251 }
1252
1253 int iommufd_test(struct iommufd_ucmd *ucmd)
1254 {
1255         struct iommu_test_cmd *cmd = ucmd->cmd;
1256
1257         switch (cmd->op) {
1258         case IOMMU_TEST_OP_ADD_RESERVED:
1259                 return iommufd_test_add_reserved(ucmd, cmd->id,
1260                                                  cmd->add_reserved.start,
1261                                                  cmd->add_reserved.length);
1262         case IOMMU_TEST_OP_MOCK_DOMAIN:
1263         case IOMMU_TEST_OP_MOCK_DOMAIN_FLAGS:
1264                 return iommufd_test_mock_domain(ucmd, cmd);
1265         case IOMMU_TEST_OP_MOCK_DOMAIN_REPLACE:
1266                 return iommufd_test_mock_domain_replace(
1267                         ucmd, cmd->id, cmd->mock_domain_replace.pt_id, cmd);
1268         case IOMMU_TEST_OP_MD_CHECK_MAP:
1269                 return iommufd_test_md_check_pa(
1270                         ucmd, cmd->id, cmd->check_map.iova,
1271                         cmd->check_map.length,
1272                         u64_to_user_ptr(cmd->check_map.uptr));
1273         case IOMMU_TEST_OP_MD_CHECK_REFS:
1274                 return iommufd_test_md_check_refs(
1275                         ucmd, u64_to_user_ptr(cmd->check_refs.uptr),
1276                         cmd->check_refs.length, cmd->check_refs.refs);
1277         case IOMMU_TEST_OP_CREATE_ACCESS:
1278                 return iommufd_test_create_access(ucmd, cmd->id,
1279                                                   cmd->create_access.flags);
1280         case IOMMU_TEST_OP_ACCESS_REPLACE_IOAS:
1281                 return iommufd_test_access_replace_ioas(
1282                         ucmd, cmd->id, cmd->access_replace_ioas.ioas_id);
1283         case IOMMU_TEST_OP_ACCESS_PAGES:
1284                 return iommufd_test_access_pages(
1285                         ucmd, cmd->id, cmd->access_pages.iova,
1286                         cmd->access_pages.length,
1287                         u64_to_user_ptr(cmd->access_pages.uptr),
1288                         cmd->access_pages.flags);
1289         case IOMMU_TEST_OP_ACCESS_RW:
1290                 return iommufd_test_access_rw(
1291                         ucmd, cmd->id, cmd->access_rw.iova,
1292                         cmd->access_rw.length,
1293                         u64_to_user_ptr(cmd->access_rw.uptr),
1294                         cmd->access_rw.flags);
1295         case IOMMU_TEST_OP_DESTROY_ACCESS_PAGES:
1296                 return iommufd_test_access_item_destroy(
1297                         ucmd, cmd->id, cmd->destroy_access_pages.access_pages_id);
1298         case IOMMU_TEST_OP_SET_TEMP_MEMORY_LIMIT:
1299                 /* Protect _batch_init(), can not be less than elmsz */
1300                 if (cmd->memory_limit.limit <
1301                     sizeof(unsigned long) + sizeof(u32))
1302                         return -EINVAL;
1303                 iommufd_test_memory_limit = cmd->memory_limit.limit;
1304                 return 0;
1305         case IOMMU_TEST_OP_DIRTY:
1306                 return iommufd_test_dirty(ucmd, cmd->id, cmd->dirty.iova,
1307                                           cmd->dirty.length,
1308                                           cmd->dirty.page_size,
1309                                           u64_to_user_ptr(cmd->dirty.uptr),
1310                                           cmd->dirty.flags);
1311         default:
1312                 return -EOPNOTSUPP;
1313         }
1314 }
1315
1316 bool iommufd_should_fail(void)
1317 {
1318         return should_fail(&fail_iommufd, 1);
1319 }
1320
1321 int __init iommufd_test_init(void)
1322 {
1323         struct platform_device_info pdevinfo = {
1324                 .name = "iommufd_selftest_iommu",
1325         };
1326         int rc;
1327
1328         dbgfs_root =
1329                 fault_create_debugfs_attr("fail_iommufd", NULL, &fail_iommufd);
1330
1331         selftest_iommu_dev = platform_device_register_full(&pdevinfo);
1332         if (IS_ERR(selftest_iommu_dev)) {
1333                 rc = PTR_ERR(selftest_iommu_dev);
1334                 goto err_dbgfs;
1335         }
1336
1337         rc = bus_register(&iommufd_mock_bus_type.bus);
1338         if (rc)
1339                 goto err_platform;
1340
1341         rc = iommu_device_sysfs_add(&mock_iommu_device,
1342                                     &selftest_iommu_dev->dev, NULL, "%s",
1343                                     dev_name(&selftest_iommu_dev->dev));
1344         if (rc)
1345                 goto err_bus;
1346
1347         rc = iommu_device_register_bus(&mock_iommu_device, &mock_ops,
1348                                   &iommufd_mock_bus_type.bus,
1349                                   &iommufd_mock_bus_type.nb);
1350         if (rc)
1351                 goto err_sysfs;
1352         return 0;
1353
1354 err_sysfs:
1355         iommu_device_sysfs_remove(&mock_iommu_device);
1356 err_bus:
1357         bus_unregister(&iommufd_mock_bus_type.bus);
1358 err_platform:
1359         platform_device_unregister(selftest_iommu_dev);
1360 err_dbgfs:
1361         debugfs_remove_recursive(dbgfs_root);
1362         return rc;
1363 }
1364
1365 void iommufd_test_exit(void)
1366 {
1367         iommu_device_sysfs_remove(&mock_iommu_device);
1368         iommu_device_unregister_bus(&mock_iommu_device,
1369                                     &iommufd_mock_bus_type.bus,
1370                                     &iommufd_mock_bus_type.nb);
1371         bus_unregister(&iommufd_mock_bus_type.bus);
1372         platform_device_unregister(selftest_iommu_dev);
1373         debugfs_remove_recursive(dbgfs_root);
1374 }