Merge git://git.kernel.org/pub/scm/linux/kernel/git/pablo/nf
[sfrench/cifs-2.6.git] / drivers / vfio / vfio_iommu_type1.c
1 /*
2  * VFIO: IOMMU DMA mapping support for Type1 IOMMU
3  *
4  * Copyright (C) 2012 Red Hat, Inc.  All rights reserved.
5  *     Author: Alex Williamson <alex.williamson@redhat.com>
6  *
7  * This program is free software; you can redistribute it and/or modify
8  * it under the terms of the GNU General Public License version 2 as
9  * published by the Free Software Foundation.
10  *
11  * Derived from original vfio:
12  * Copyright 2010 Cisco Systems, Inc.  All rights reserved.
13  * Author: Tom Lyon, pugs@cisco.com
14  *
15  * We arbitrarily define a Type1 IOMMU as one matching the below code.
16  * It could be called the x86 IOMMU as it's designed for AMD-Vi & Intel
17  * VT-d, but that makes it harder to re-use as theoretically anyone
18  * implementing a similar IOMMU could make use of this.  We expect the
19  * IOMMU to support the IOMMU API and have few to no restrictions around
20  * the IOVA range that can be mapped.  The Type1 IOMMU is currently
21  * optimized for relatively static mappings of a userspace process with
22  * userpsace pages pinned into memory.  We also assume devices and IOMMU
23  * domains are PCI based as the IOMMU API is still centered around a
24  * device/bus interface rather than a group interface.
25  */
26
27 #include <linux/compat.h>
28 #include <linux/device.h>
29 #include <linux/fs.h>
30 #include <linux/iommu.h>
31 #include <linux/module.h>
32 #include <linux/mm.h>
33 #include <linux/rbtree.h>
34 #include <linux/sched/signal.h>
35 #include <linux/sched/mm.h>
36 #include <linux/slab.h>
37 #include <linux/uaccess.h>
38 #include <linux/vfio.h>
39 #include <linux/workqueue.h>
40 #include <linux/mdev.h>
41 #include <linux/notifier.h>
42 #include <linux/dma-iommu.h>
43 #include <linux/irqdomain.h>
44
45 #define DRIVER_VERSION  "0.2"
46 #define DRIVER_AUTHOR   "Alex Williamson <alex.williamson@redhat.com>"
47 #define DRIVER_DESC     "Type1 IOMMU driver for VFIO"
48
49 static bool allow_unsafe_interrupts;
50 module_param_named(allow_unsafe_interrupts,
51                    allow_unsafe_interrupts, bool, S_IRUGO | S_IWUSR);
52 MODULE_PARM_DESC(allow_unsafe_interrupts,
53                  "Enable VFIO IOMMU support for on platforms without interrupt remapping support.");
54
55 static bool disable_hugepages;
56 module_param_named(disable_hugepages,
57                    disable_hugepages, bool, S_IRUGO | S_IWUSR);
58 MODULE_PARM_DESC(disable_hugepages,
59                  "Disable VFIO IOMMU support for IOMMU hugepages.");
60
61 struct vfio_iommu {
62         struct list_head        domain_list;
63         struct vfio_domain      *external_domain; /* domain for external user */
64         struct mutex            lock;
65         struct rb_root          dma_list;
66         struct blocking_notifier_head notifier;
67         bool                    v2;
68         bool                    nesting;
69 };
70
71 struct vfio_domain {
72         struct iommu_domain     *domain;
73         struct list_head        next;
74         struct list_head        group_list;
75         int                     prot;           /* IOMMU_CACHE */
76         bool                    fgsp;           /* Fine-grained super pages */
77 };
78
79 struct vfio_dma {
80         struct rb_node          node;
81         dma_addr_t              iova;           /* Device address */
82         unsigned long           vaddr;          /* Process virtual addr */
83         size_t                  size;           /* Map size (bytes) */
84         int                     prot;           /* IOMMU_READ/WRITE */
85         bool                    iommu_mapped;
86         bool                    lock_cap;       /* capable(CAP_IPC_LOCK) */
87         struct task_struct      *task;
88         struct rb_root          pfn_list;       /* Ex-user pinned pfn list */
89 };
90
91 struct vfio_group {
92         struct iommu_group      *iommu_group;
93         struct list_head        next;
94 };
95
96 /*
97  * Guest RAM pinning working set or DMA target
98  */
99 struct vfio_pfn {
100         struct rb_node          node;
101         dma_addr_t              iova;           /* Device address */
102         unsigned long           pfn;            /* Host pfn */
103         atomic_t                ref_count;
104 };
105
106 struct vfio_regions {
107         struct list_head list;
108         dma_addr_t iova;
109         phys_addr_t phys;
110         size_t len;
111 };
112
113 #define IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu) \
114                                         (!list_empty(&iommu->domain_list))
115
116 static int put_pfn(unsigned long pfn, int prot);
117
118 /*
119  * This code handles mapping and unmapping of user data buffers
120  * into DMA'ble space using the IOMMU
121  */
122
123 static struct vfio_dma *vfio_find_dma(struct vfio_iommu *iommu,
124                                       dma_addr_t start, size_t size)
125 {
126         struct rb_node *node = iommu->dma_list.rb_node;
127
128         while (node) {
129                 struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
130
131                 if (start + size <= dma->iova)
132                         node = node->rb_left;
133                 else if (start >= dma->iova + dma->size)
134                         node = node->rb_right;
135                 else
136                         return dma;
137         }
138
139         return NULL;
140 }
141
142 static void vfio_link_dma(struct vfio_iommu *iommu, struct vfio_dma *new)
143 {
144         struct rb_node **link = &iommu->dma_list.rb_node, *parent = NULL;
145         struct vfio_dma *dma;
146
147         while (*link) {
148                 parent = *link;
149                 dma = rb_entry(parent, struct vfio_dma, node);
150
151                 if (new->iova + new->size <= dma->iova)
152                         link = &(*link)->rb_left;
153                 else
154                         link = &(*link)->rb_right;
155         }
156
157         rb_link_node(&new->node, parent, link);
158         rb_insert_color(&new->node, &iommu->dma_list);
159 }
160
161 static void vfio_unlink_dma(struct vfio_iommu *iommu, struct vfio_dma *old)
162 {
163         rb_erase(&old->node, &iommu->dma_list);
164 }
165
166 /*
167  * Helper Functions for host iova-pfn list
168  */
169 static struct vfio_pfn *vfio_find_vpfn(struct vfio_dma *dma, dma_addr_t iova)
170 {
171         struct vfio_pfn *vpfn;
172         struct rb_node *node = dma->pfn_list.rb_node;
173
174         while (node) {
175                 vpfn = rb_entry(node, struct vfio_pfn, node);
176
177                 if (iova < vpfn->iova)
178                         node = node->rb_left;
179                 else if (iova > vpfn->iova)
180                         node = node->rb_right;
181                 else
182                         return vpfn;
183         }
184         return NULL;
185 }
186
187 static void vfio_link_pfn(struct vfio_dma *dma,
188                           struct vfio_pfn *new)
189 {
190         struct rb_node **link, *parent = NULL;
191         struct vfio_pfn *vpfn;
192
193         link = &dma->pfn_list.rb_node;
194         while (*link) {
195                 parent = *link;
196                 vpfn = rb_entry(parent, struct vfio_pfn, node);
197
198                 if (new->iova < vpfn->iova)
199                         link = &(*link)->rb_left;
200                 else
201                         link = &(*link)->rb_right;
202         }
203
204         rb_link_node(&new->node, parent, link);
205         rb_insert_color(&new->node, &dma->pfn_list);
206 }
207
208 static void vfio_unlink_pfn(struct vfio_dma *dma, struct vfio_pfn *old)
209 {
210         rb_erase(&old->node, &dma->pfn_list);
211 }
212
213 static int vfio_add_to_pfn_list(struct vfio_dma *dma, dma_addr_t iova,
214                                 unsigned long pfn)
215 {
216         struct vfio_pfn *vpfn;
217
218         vpfn = kzalloc(sizeof(*vpfn), GFP_KERNEL);
219         if (!vpfn)
220                 return -ENOMEM;
221
222         vpfn->iova = iova;
223         vpfn->pfn = pfn;
224         atomic_set(&vpfn->ref_count, 1);
225         vfio_link_pfn(dma, vpfn);
226         return 0;
227 }
228
229 static void vfio_remove_from_pfn_list(struct vfio_dma *dma,
230                                       struct vfio_pfn *vpfn)
231 {
232         vfio_unlink_pfn(dma, vpfn);
233         kfree(vpfn);
234 }
235
236 static struct vfio_pfn *vfio_iova_get_vfio_pfn(struct vfio_dma *dma,
237                                                unsigned long iova)
238 {
239         struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
240
241         if (vpfn)
242                 atomic_inc(&vpfn->ref_count);
243         return vpfn;
244 }
245
246 static int vfio_iova_put_vfio_pfn(struct vfio_dma *dma, struct vfio_pfn *vpfn)
247 {
248         int ret = 0;
249
250         if (atomic_dec_and_test(&vpfn->ref_count)) {
251                 ret = put_pfn(vpfn->pfn, dma->prot);
252                 vfio_remove_from_pfn_list(dma, vpfn);
253         }
254         return ret;
255 }
256
257 static int vfio_lock_acct(struct vfio_dma *dma, long npage, bool async)
258 {
259         struct mm_struct *mm;
260         int ret;
261
262         if (!npage)
263                 return 0;
264
265         mm = async ? get_task_mm(dma->task) : dma->task->mm;
266         if (!mm)
267                 return -ESRCH; /* process exited */
268
269         ret = down_write_killable(&mm->mmap_sem);
270         if (!ret) {
271                 if (npage > 0) {
272                         if (!dma->lock_cap) {
273                                 unsigned long limit;
274
275                                 limit = task_rlimit(dma->task,
276                                                 RLIMIT_MEMLOCK) >> PAGE_SHIFT;
277
278                                 if (mm->locked_vm + npage > limit)
279                                         ret = -ENOMEM;
280                         }
281                 }
282
283                 if (!ret)
284                         mm->locked_vm += npage;
285
286                 up_write(&mm->mmap_sem);
287         }
288
289         if (async)
290                 mmput(mm);
291
292         return ret;
293 }
294
295 /*
296  * Some mappings aren't backed by a struct page, for example an mmap'd
297  * MMIO range for our own or another device.  These use a different
298  * pfn conversion and shouldn't be tracked as locked pages.
299  */
300 static bool is_invalid_reserved_pfn(unsigned long pfn)
301 {
302         if (pfn_valid(pfn)) {
303                 bool reserved;
304                 struct page *tail = pfn_to_page(pfn);
305                 struct page *head = compound_head(tail);
306                 reserved = !!(PageReserved(head));
307                 if (head != tail) {
308                         /*
309                          * "head" is not a dangling pointer
310                          * (compound_head takes care of that)
311                          * but the hugepage may have been split
312                          * from under us (and we may not hold a
313                          * reference count on the head page so it can
314                          * be reused before we run PageReferenced), so
315                          * we've to check PageTail before returning
316                          * what we just read.
317                          */
318                         smp_rmb();
319                         if (PageTail(tail))
320                                 return reserved;
321                 }
322                 return PageReserved(tail);
323         }
324
325         return true;
326 }
327
328 static int put_pfn(unsigned long pfn, int prot)
329 {
330         if (!is_invalid_reserved_pfn(pfn)) {
331                 struct page *page = pfn_to_page(pfn);
332                 if (prot & IOMMU_WRITE)
333                         SetPageDirty(page);
334                 put_page(page);
335                 return 1;
336         }
337         return 0;
338 }
339
340 static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
341                          int prot, unsigned long *pfn)
342 {
343         struct page *page[1];
344         struct vm_area_struct *vma;
345         struct vm_area_struct *vmas[1];
346         unsigned int flags = 0;
347         int ret;
348
349         if (prot & IOMMU_WRITE)
350                 flags |= FOLL_WRITE;
351
352         down_read(&mm->mmap_sem);
353         if (mm == current->mm) {
354                 ret = get_user_pages_longterm(vaddr, 1, flags, page, vmas);
355         } else {
356                 ret = get_user_pages_remote(NULL, mm, vaddr, 1, flags, page,
357                                             vmas, NULL);
358                 /*
359                  * The lifetime of a vaddr_get_pfn() page pin is
360                  * userspace-controlled. In the fs-dax case this could
361                  * lead to indefinite stalls in filesystem operations.
362                  * Disallow attempts to pin fs-dax pages via this
363                  * interface.
364                  */
365                 if (ret > 0 && vma_is_fsdax(vmas[0])) {
366                         ret = -EOPNOTSUPP;
367                         put_page(page[0]);
368                 }
369         }
370         up_read(&mm->mmap_sem);
371
372         if (ret == 1) {
373                 *pfn = page_to_pfn(page[0]);
374                 return 0;
375         }
376
377         down_read(&mm->mmap_sem);
378
379         vma = find_vma_intersection(mm, vaddr, vaddr + 1);
380
381         if (vma && vma->vm_flags & VM_PFNMAP) {
382                 *pfn = ((vaddr - vma->vm_start) >> PAGE_SHIFT) + vma->vm_pgoff;
383                 if (is_invalid_reserved_pfn(*pfn))
384                         ret = 0;
385         }
386
387         up_read(&mm->mmap_sem);
388         return ret;
389 }
390
391 /*
392  * Attempt to pin pages.  We really don't want to track all the pfns and
393  * the iommu can only map chunks of consecutive pfns anyway, so get the
394  * first page and all consecutive pages with the same locking.
395  */
396 static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
397                                   long npage, unsigned long *pfn_base,
398                                   unsigned long limit)
399 {
400         unsigned long pfn = 0;
401         long ret, pinned = 0, lock_acct = 0;
402         bool rsvd;
403         dma_addr_t iova = vaddr - dma->vaddr + dma->iova;
404
405         /* This code path is only user initiated */
406         if (!current->mm)
407                 return -ENODEV;
408
409         ret = vaddr_get_pfn(current->mm, vaddr, dma->prot, pfn_base);
410         if (ret)
411                 return ret;
412
413         pinned++;
414         rsvd = is_invalid_reserved_pfn(*pfn_base);
415
416         /*
417          * Reserved pages aren't counted against the user, externally pinned
418          * pages are already counted against the user.
419          */
420         if (!rsvd && !vfio_find_vpfn(dma, iova)) {
421                 if (!dma->lock_cap && current->mm->locked_vm + 1 > limit) {
422                         put_pfn(*pfn_base, dma->prot);
423                         pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n", __func__,
424                                         limit << PAGE_SHIFT);
425                         return -ENOMEM;
426                 }
427                 lock_acct++;
428         }
429
430         if (unlikely(disable_hugepages))
431                 goto out;
432
433         /* Lock all the consecutive pages from pfn_base */
434         for (vaddr += PAGE_SIZE, iova += PAGE_SIZE; pinned < npage;
435              pinned++, vaddr += PAGE_SIZE, iova += PAGE_SIZE) {
436                 ret = vaddr_get_pfn(current->mm, vaddr, dma->prot, &pfn);
437                 if (ret)
438                         break;
439
440                 if (pfn != *pfn_base + pinned ||
441                     rsvd != is_invalid_reserved_pfn(pfn)) {
442                         put_pfn(pfn, dma->prot);
443                         break;
444                 }
445
446                 if (!rsvd && !vfio_find_vpfn(dma, iova)) {
447                         if (!dma->lock_cap &&
448                             current->mm->locked_vm + lock_acct + 1 > limit) {
449                                 put_pfn(pfn, dma->prot);
450                                 pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n",
451                                         __func__, limit << PAGE_SHIFT);
452                                 ret = -ENOMEM;
453                                 goto unpin_out;
454                         }
455                         lock_acct++;
456                 }
457         }
458
459 out:
460         ret = vfio_lock_acct(dma, lock_acct, false);
461
462 unpin_out:
463         if (ret) {
464                 if (!rsvd) {
465                         for (pfn = *pfn_base ; pinned ; pfn++, pinned--)
466                                 put_pfn(pfn, dma->prot);
467                 }
468
469                 return ret;
470         }
471
472         return pinned;
473 }
474
475 static long vfio_unpin_pages_remote(struct vfio_dma *dma, dma_addr_t iova,
476                                     unsigned long pfn, long npage,
477                                     bool do_accounting)
478 {
479         long unlocked = 0, locked = 0;
480         long i;
481
482         for (i = 0; i < npage; i++, iova += PAGE_SIZE) {
483                 if (put_pfn(pfn++, dma->prot)) {
484                         unlocked++;
485                         if (vfio_find_vpfn(dma, iova))
486                                 locked++;
487                 }
488         }
489
490         if (do_accounting)
491                 vfio_lock_acct(dma, locked - unlocked, true);
492
493         return unlocked;
494 }
495
496 static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
497                                   unsigned long *pfn_base, bool do_accounting)
498 {
499         struct mm_struct *mm;
500         int ret;
501
502         mm = get_task_mm(dma->task);
503         if (!mm)
504                 return -ENODEV;
505
506         ret = vaddr_get_pfn(mm, vaddr, dma->prot, pfn_base);
507         if (!ret && do_accounting && !is_invalid_reserved_pfn(*pfn_base)) {
508                 ret = vfio_lock_acct(dma, 1, true);
509                 if (ret) {
510                         put_pfn(*pfn_base, dma->prot);
511                         if (ret == -ENOMEM)
512                                 pr_warn("%s: Task %s (%d) RLIMIT_MEMLOCK "
513                                         "(%ld) exceeded\n", __func__,
514                                         dma->task->comm, task_pid_nr(dma->task),
515                                         task_rlimit(dma->task, RLIMIT_MEMLOCK));
516                 }
517         }
518
519         mmput(mm);
520         return ret;
521 }
522
523 static int vfio_unpin_page_external(struct vfio_dma *dma, dma_addr_t iova,
524                                     bool do_accounting)
525 {
526         int unlocked;
527         struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
528
529         if (!vpfn)
530                 return 0;
531
532         unlocked = vfio_iova_put_vfio_pfn(dma, vpfn);
533
534         if (do_accounting)
535                 vfio_lock_acct(dma, -unlocked, true);
536
537         return unlocked;
538 }
539
540 static int vfio_iommu_type1_pin_pages(void *iommu_data,
541                                       unsigned long *user_pfn,
542                                       int npage, int prot,
543                                       unsigned long *phys_pfn)
544 {
545         struct vfio_iommu *iommu = iommu_data;
546         int i, j, ret;
547         unsigned long remote_vaddr;
548         struct vfio_dma *dma;
549         bool do_accounting;
550
551         if (!iommu || !user_pfn || !phys_pfn)
552                 return -EINVAL;
553
554         /* Supported for v2 version only */
555         if (!iommu->v2)
556                 return -EACCES;
557
558         mutex_lock(&iommu->lock);
559
560         /* Fail if notifier list is empty */
561         if ((!iommu->external_domain) || (!iommu->notifier.head)) {
562                 ret = -EINVAL;
563                 goto pin_done;
564         }
565
566         /*
567          * If iommu capable domain exist in the container then all pages are
568          * already pinned and accounted. Accouting should be done if there is no
569          * iommu capable domain in the container.
570          */
571         do_accounting = !IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu);
572
573         for (i = 0; i < npage; i++) {
574                 dma_addr_t iova;
575                 struct vfio_pfn *vpfn;
576
577                 iova = user_pfn[i] << PAGE_SHIFT;
578                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
579                 if (!dma) {
580                         ret = -EINVAL;
581                         goto pin_unwind;
582                 }
583
584                 if ((dma->prot & prot) != prot) {
585                         ret = -EPERM;
586                         goto pin_unwind;
587                 }
588
589                 vpfn = vfio_iova_get_vfio_pfn(dma, iova);
590                 if (vpfn) {
591                         phys_pfn[i] = vpfn->pfn;
592                         continue;
593                 }
594
595                 remote_vaddr = dma->vaddr + iova - dma->iova;
596                 ret = vfio_pin_page_external(dma, remote_vaddr, &phys_pfn[i],
597                                              do_accounting);
598                 if (ret)
599                         goto pin_unwind;
600
601                 ret = vfio_add_to_pfn_list(dma, iova, phys_pfn[i]);
602                 if (ret) {
603                         vfio_unpin_page_external(dma, iova, do_accounting);
604                         goto pin_unwind;
605                 }
606         }
607
608         ret = i;
609         goto pin_done;
610
611 pin_unwind:
612         phys_pfn[i] = 0;
613         for (j = 0; j < i; j++) {
614                 dma_addr_t iova;
615
616                 iova = user_pfn[j] << PAGE_SHIFT;
617                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
618                 vfio_unpin_page_external(dma, iova, do_accounting);
619                 phys_pfn[j] = 0;
620         }
621 pin_done:
622         mutex_unlock(&iommu->lock);
623         return ret;
624 }
625
626 static int vfio_iommu_type1_unpin_pages(void *iommu_data,
627                                         unsigned long *user_pfn,
628                                         int npage)
629 {
630         struct vfio_iommu *iommu = iommu_data;
631         bool do_accounting;
632         int i;
633
634         if (!iommu || !user_pfn)
635                 return -EINVAL;
636
637         /* Supported for v2 version only */
638         if (!iommu->v2)
639                 return -EACCES;
640
641         mutex_lock(&iommu->lock);
642
643         if (!iommu->external_domain) {
644                 mutex_unlock(&iommu->lock);
645                 return -EINVAL;
646         }
647
648         do_accounting = !IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu);
649         for (i = 0; i < npage; i++) {
650                 struct vfio_dma *dma;
651                 dma_addr_t iova;
652
653                 iova = user_pfn[i] << PAGE_SHIFT;
654                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
655                 if (!dma)
656                         goto unpin_exit;
657                 vfio_unpin_page_external(dma, iova, do_accounting);
658         }
659
660 unpin_exit:
661         mutex_unlock(&iommu->lock);
662         return i > npage ? npage : (i > 0 ? i : -EINVAL);
663 }
664
665 static long vfio_sync_unpin(struct vfio_dma *dma, struct vfio_domain *domain,
666                                 struct list_head *regions)
667 {
668         long unlocked = 0;
669         struct vfio_regions *entry, *next;
670
671         iommu_tlb_sync(domain->domain);
672
673         list_for_each_entry_safe(entry, next, regions, list) {
674                 unlocked += vfio_unpin_pages_remote(dma,
675                                                     entry->iova,
676                                                     entry->phys >> PAGE_SHIFT,
677                                                     entry->len >> PAGE_SHIFT,
678                                                     false);
679                 list_del(&entry->list);
680                 kfree(entry);
681         }
682
683         cond_resched();
684
685         return unlocked;
686 }
687
688 /*
689  * Generally, VFIO needs to unpin remote pages after each IOTLB flush.
690  * Therefore, when using IOTLB flush sync interface, VFIO need to keep track
691  * of these regions (currently using a list).
692  *
693  * This value specifies maximum number of regions for each IOTLB flush sync.
694  */
695 #define VFIO_IOMMU_TLB_SYNC_MAX         512
696
697 static size_t unmap_unpin_fast(struct vfio_domain *domain,
698                                struct vfio_dma *dma, dma_addr_t *iova,
699                                size_t len, phys_addr_t phys, long *unlocked,
700                                struct list_head *unmapped_list,
701                                int *unmapped_cnt)
702 {
703         size_t unmapped = 0;
704         struct vfio_regions *entry = kzalloc(sizeof(*entry), GFP_KERNEL);
705
706         if (entry) {
707                 unmapped = iommu_unmap_fast(domain->domain, *iova, len);
708
709                 if (!unmapped) {
710                         kfree(entry);
711                 } else {
712                         iommu_tlb_range_add(domain->domain, *iova, unmapped);
713                         entry->iova = *iova;
714                         entry->phys = phys;
715                         entry->len  = unmapped;
716                         list_add_tail(&entry->list, unmapped_list);
717
718                         *iova += unmapped;
719                         (*unmapped_cnt)++;
720                 }
721         }
722
723         /*
724          * Sync if the number of fast-unmap regions hits the limit
725          * or in case of errors.
726          */
727         if (*unmapped_cnt >= VFIO_IOMMU_TLB_SYNC_MAX || !unmapped) {
728                 *unlocked += vfio_sync_unpin(dma, domain,
729                                              unmapped_list);
730                 *unmapped_cnt = 0;
731         }
732
733         return unmapped;
734 }
735
736 static size_t unmap_unpin_slow(struct vfio_domain *domain,
737                                struct vfio_dma *dma, dma_addr_t *iova,
738                                size_t len, phys_addr_t phys,
739                                long *unlocked)
740 {
741         size_t unmapped = iommu_unmap(domain->domain, *iova, len);
742
743         if (unmapped) {
744                 *unlocked += vfio_unpin_pages_remote(dma, *iova,
745                                                      phys >> PAGE_SHIFT,
746                                                      unmapped >> PAGE_SHIFT,
747                                                      false);
748                 *iova += unmapped;
749                 cond_resched();
750         }
751         return unmapped;
752 }
753
754 static long vfio_unmap_unpin(struct vfio_iommu *iommu, struct vfio_dma *dma,
755                              bool do_accounting)
756 {
757         dma_addr_t iova = dma->iova, end = dma->iova + dma->size;
758         struct vfio_domain *domain, *d;
759         LIST_HEAD(unmapped_region_list);
760         int unmapped_region_cnt = 0;
761         long unlocked = 0;
762
763         if (!dma->size)
764                 return 0;
765
766         if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
767                 return 0;
768
769         /*
770          * We use the IOMMU to track the physical addresses, otherwise we'd
771          * need a much more complicated tracking system.  Unfortunately that
772          * means we need to use one of the iommu domains to figure out the
773          * pfns to unpin.  The rest need to be unmapped in advance so we have
774          * no iommu translations remaining when the pages are unpinned.
775          */
776         domain = d = list_first_entry(&iommu->domain_list,
777                                       struct vfio_domain, next);
778
779         list_for_each_entry_continue(d, &iommu->domain_list, next) {
780                 iommu_unmap(d->domain, dma->iova, dma->size);
781                 cond_resched();
782         }
783
784         while (iova < end) {
785                 size_t unmapped, len;
786                 phys_addr_t phys, next;
787
788                 phys = iommu_iova_to_phys(domain->domain, iova);
789                 if (WARN_ON(!phys)) {
790                         iova += PAGE_SIZE;
791                         continue;
792                 }
793
794                 /*
795                  * To optimize for fewer iommu_unmap() calls, each of which
796                  * may require hardware cache flushing, try to find the
797                  * largest contiguous physical memory chunk to unmap.
798                  */
799                 for (len = PAGE_SIZE;
800                      !domain->fgsp && iova + len < end; len += PAGE_SIZE) {
801                         next = iommu_iova_to_phys(domain->domain, iova + len);
802                         if (next != phys + len)
803                                 break;
804                 }
805
806                 /*
807                  * First, try to use fast unmap/unpin. In case of failure,
808                  * switch to slow unmap/unpin path.
809                  */
810                 unmapped = unmap_unpin_fast(domain, dma, &iova, len, phys,
811                                             &unlocked, &unmapped_region_list,
812                                             &unmapped_region_cnt);
813                 if (!unmapped) {
814                         unmapped = unmap_unpin_slow(domain, dma, &iova, len,
815                                                     phys, &unlocked);
816                         if (WARN_ON(!unmapped))
817                                 break;
818                 }
819         }
820
821         dma->iommu_mapped = false;
822
823         if (unmapped_region_cnt)
824                 unlocked += vfio_sync_unpin(dma, domain, &unmapped_region_list);
825
826         if (do_accounting) {
827                 vfio_lock_acct(dma, -unlocked, true);
828                 return 0;
829         }
830         return unlocked;
831 }
832
833 static void vfio_remove_dma(struct vfio_iommu *iommu, struct vfio_dma *dma)
834 {
835         vfio_unmap_unpin(iommu, dma, true);
836         vfio_unlink_dma(iommu, dma);
837         put_task_struct(dma->task);
838         kfree(dma);
839 }
840
841 static unsigned long vfio_pgsize_bitmap(struct vfio_iommu *iommu)
842 {
843         struct vfio_domain *domain;
844         unsigned long bitmap = ULONG_MAX;
845
846         mutex_lock(&iommu->lock);
847         list_for_each_entry(domain, &iommu->domain_list, next)
848                 bitmap &= domain->domain->pgsize_bitmap;
849         mutex_unlock(&iommu->lock);
850
851         /*
852          * In case the IOMMU supports page sizes smaller than PAGE_SIZE
853          * we pretend PAGE_SIZE is supported and hide sub-PAGE_SIZE sizes.
854          * That way the user will be able to map/unmap buffers whose size/
855          * start address is aligned with PAGE_SIZE. Pinning code uses that
856          * granularity while iommu driver can use the sub-PAGE_SIZE size
857          * to map the buffer.
858          */
859         if (bitmap & ~PAGE_MASK) {
860                 bitmap &= PAGE_MASK;
861                 bitmap |= PAGE_SIZE;
862         }
863
864         return bitmap;
865 }
866
867 static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
868                              struct vfio_iommu_type1_dma_unmap *unmap)
869 {
870         uint64_t mask;
871         struct vfio_dma *dma, *dma_last = NULL;
872         size_t unmapped = 0;
873         int ret = 0, retries = 0;
874
875         mask = ((uint64_t)1 << __ffs(vfio_pgsize_bitmap(iommu))) - 1;
876
877         if (unmap->iova & mask)
878                 return -EINVAL;
879         if (!unmap->size || unmap->size & mask)
880                 return -EINVAL;
881         if (unmap->iova + unmap->size < unmap->iova ||
882             unmap->size > SIZE_MAX)
883                 return -EINVAL;
884
885         WARN_ON(mask & PAGE_MASK);
886 again:
887         mutex_lock(&iommu->lock);
888
889         /*
890          * vfio-iommu-type1 (v1) - User mappings were coalesced together to
891          * avoid tracking individual mappings.  This means that the granularity
892          * of the original mapping was lost and the user was allowed to attempt
893          * to unmap any range.  Depending on the contiguousness of physical
894          * memory and page sizes supported by the IOMMU, arbitrary unmaps may
895          * or may not have worked.  We only guaranteed unmap granularity
896          * matching the original mapping; even though it was untracked here,
897          * the original mappings are reflected in IOMMU mappings.  This
898          * resulted in a couple unusual behaviors.  First, if a range is not
899          * able to be unmapped, ex. a set of 4k pages that was mapped as a
900          * 2M hugepage into the IOMMU, the unmap ioctl returns success but with
901          * a zero sized unmap.  Also, if an unmap request overlaps the first
902          * address of a hugepage, the IOMMU will unmap the entire hugepage.
903          * This also returns success and the returned unmap size reflects the
904          * actual size unmapped.
905          *
906          * We attempt to maintain compatibility with this "v1" interface, but
907          * we take control out of the hands of the IOMMU.  Therefore, an unmap
908          * request offset from the beginning of the original mapping will
909          * return success with zero sized unmap.  And an unmap request covering
910          * the first iova of mapping will unmap the entire range.
911          *
912          * The v2 version of this interface intends to be more deterministic.
913          * Unmap requests must fully cover previous mappings.  Multiple
914          * mappings may still be unmaped by specifying large ranges, but there
915          * must not be any previous mappings bisected by the range.  An error
916          * will be returned if these conditions are not met.  The v2 interface
917          * will only return success and a size of zero if there were no
918          * mappings within the range.
919          */
920         if (iommu->v2) {
921                 dma = vfio_find_dma(iommu, unmap->iova, 1);
922                 if (dma && dma->iova != unmap->iova) {
923                         ret = -EINVAL;
924                         goto unlock;
925                 }
926                 dma = vfio_find_dma(iommu, unmap->iova + unmap->size - 1, 0);
927                 if (dma && dma->iova + dma->size != unmap->iova + unmap->size) {
928                         ret = -EINVAL;
929                         goto unlock;
930                 }
931         }
932
933         while ((dma = vfio_find_dma(iommu, unmap->iova, unmap->size))) {
934                 if (!iommu->v2 && unmap->iova > dma->iova)
935                         break;
936                 /*
937                  * Task with same address space who mapped this iova range is
938                  * allowed to unmap the iova range.
939                  */
940                 if (dma->task->mm != current->mm)
941                         break;
942
943                 if (!RB_EMPTY_ROOT(&dma->pfn_list)) {
944                         struct vfio_iommu_type1_dma_unmap nb_unmap;
945
946                         if (dma_last == dma) {
947                                 BUG_ON(++retries > 10);
948                         } else {
949                                 dma_last = dma;
950                                 retries = 0;
951                         }
952
953                         nb_unmap.iova = dma->iova;
954                         nb_unmap.size = dma->size;
955
956                         /*
957                          * Notify anyone (mdev vendor drivers) to invalidate and
958                          * unmap iovas within the range we're about to unmap.
959                          * Vendor drivers MUST unpin pages in response to an
960                          * invalidation.
961                          */
962                         mutex_unlock(&iommu->lock);
963                         blocking_notifier_call_chain(&iommu->notifier,
964                                                     VFIO_IOMMU_NOTIFY_DMA_UNMAP,
965                                                     &nb_unmap);
966                         goto again;
967                 }
968                 unmapped += dma->size;
969                 vfio_remove_dma(iommu, dma);
970         }
971
972 unlock:
973         mutex_unlock(&iommu->lock);
974
975         /* Report how much was unmapped */
976         unmap->size = unmapped;
977
978         return ret;
979 }
980
981 /*
982  * Turns out AMD IOMMU has a page table bug where it won't map large pages
983  * to a region that previously mapped smaller pages.  This should be fixed
984  * soon, so this is just a temporary workaround to break mappings down into
985  * PAGE_SIZE.  Better to map smaller pages than nothing.
986  */
987 static int map_try_harder(struct vfio_domain *domain, dma_addr_t iova,
988                           unsigned long pfn, long npage, int prot)
989 {
990         long i;
991         int ret = 0;
992
993         for (i = 0; i < npage; i++, pfn++, iova += PAGE_SIZE) {
994                 ret = iommu_map(domain->domain, iova,
995                                 (phys_addr_t)pfn << PAGE_SHIFT,
996                                 PAGE_SIZE, prot | domain->prot);
997                 if (ret)
998                         break;
999         }
1000
1001         for (; i < npage && i > 0; i--, iova -= PAGE_SIZE)
1002                 iommu_unmap(domain->domain, iova, PAGE_SIZE);
1003
1004         return ret;
1005 }
1006
1007 static int vfio_iommu_map(struct vfio_iommu *iommu, dma_addr_t iova,
1008                           unsigned long pfn, long npage, int prot)
1009 {
1010         struct vfio_domain *d;
1011         int ret;
1012
1013         list_for_each_entry(d, &iommu->domain_list, next) {
1014                 ret = iommu_map(d->domain, iova, (phys_addr_t)pfn << PAGE_SHIFT,
1015                                 npage << PAGE_SHIFT, prot | d->prot);
1016                 if (ret) {
1017                         if (ret != -EBUSY ||
1018                             map_try_harder(d, iova, pfn, npage, prot))
1019                                 goto unwind;
1020                 }
1021
1022                 cond_resched();
1023         }
1024
1025         return 0;
1026
1027 unwind:
1028         list_for_each_entry_continue_reverse(d, &iommu->domain_list, next)
1029                 iommu_unmap(d->domain, iova, npage << PAGE_SHIFT);
1030
1031         return ret;
1032 }
1033
1034 static int vfio_pin_map_dma(struct vfio_iommu *iommu, struct vfio_dma *dma,
1035                             size_t map_size)
1036 {
1037         dma_addr_t iova = dma->iova;
1038         unsigned long vaddr = dma->vaddr;
1039         size_t size = map_size;
1040         long npage;
1041         unsigned long pfn, limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1042         int ret = 0;
1043
1044         while (size) {
1045                 /* Pin a contiguous chunk of memory */
1046                 npage = vfio_pin_pages_remote(dma, vaddr + dma->size,
1047                                               size >> PAGE_SHIFT, &pfn, limit);
1048                 if (npage <= 0) {
1049                         WARN_ON(!npage);
1050                         ret = (int)npage;
1051                         break;
1052                 }
1053
1054                 /* Map it! */
1055                 ret = vfio_iommu_map(iommu, iova + dma->size, pfn, npage,
1056                                      dma->prot);
1057                 if (ret) {
1058                         vfio_unpin_pages_remote(dma, iova + dma->size, pfn,
1059                                                 npage, true);
1060                         break;
1061                 }
1062
1063                 size -= npage << PAGE_SHIFT;
1064                 dma->size += npage << PAGE_SHIFT;
1065         }
1066
1067         dma->iommu_mapped = true;
1068
1069         if (ret)
1070                 vfio_remove_dma(iommu, dma);
1071
1072         return ret;
1073 }
1074
1075 static int vfio_dma_do_map(struct vfio_iommu *iommu,
1076                            struct vfio_iommu_type1_dma_map *map)
1077 {
1078         dma_addr_t iova = map->iova;
1079         unsigned long vaddr = map->vaddr;
1080         size_t size = map->size;
1081         int ret = 0, prot = 0;
1082         uint64_t mask;
1083         struct vfio_dma *dma;
1084
1085         /* Verify that none of our __u64 fields overflow */
1086         if (map->size != size || map->vaddr != vaddr || map->iova != iova)
1087                 return -EINVAL;
1088
1089         mask = ((uint64_t)1 << __ffs(vfio_pgsize_bitmap(iommu))) - 1;
1090
1091         WARN_ON(mask & PAGE_MASK);
1092
1093         /* READ/WRITE from device perspective */
1094         if (map->flags & VFIO_DMA_MAP_FLAG_WRITE)
1095                 prot |= IOMMU_WRITE;
1096         if (map->flags & VFIO_DMA_MAP_FLAG_READ)
1097                 prot |= IOMMU_READ;
1098
1099         if (!prot || !size || (size | iova | vaddr) & mask)
1100                 return -EINVAL;
1101
1102         /* Don't allow IOVA or virtual address wrap */
1103         if (iova + size - 1 < iova || vaddr + size - 1 < vaddr)
1104                 return -EINVAL;
1105
1106         mutex_lock(&iommu->lock);
1107
1108         if (vfio_find_dma(iommu, iova, size)) {
1109                 ret = -EEXIST;
1110                 goto out_unlock;
1111         }
1112
1113         dma = kzalloc(sizeof(*dma), GFP_KERNEL);
1114         if (!dma) {
1115                 ret = -ENOMEM;
1116                 goto out_unlock;
1117         }
1118
1119         dma->iova = iova;
1120         dma->vaddr = vaddr;
1121         dma->prot = prot;
1122
1123         /*
1124          * We need to be able to both add to a task's locked memory and test
1125          * against the locked memory limit and we need to be able to do both
1126          * outside of this call path as pinning can be asynchronous via the
1127          * external interfaces for mdev devices.  RLIMIT_MEMLOCK requires a
1128          * task_struct and VM locked pages requires an mm_struct, however
1129          * holding an indefinite mm reference is not recommended, therefore we
1130          * only hold a reference to a task.  We could hold a reference to
1131          * current, however QEMU uses this call path through vCPU threads,
1132          * which can be killed resulting in a NULL mm and failure in the unmap
1133          * path when called via a different thread.  Avoid this problem by
1134          * using the group_leader as threads within the same group require
1135          * both CLONE_THREAD and CLONE_VM and will therefore use the same
1136          * mm_struct.
1137          *
1138          * Previously we also used the task for testing CAP_IPC_LOCK at the
1139          * time of pinning and accounting, however has_capability() makes use
1140          * of real_cred, a copy-on-write field, so we can't guarantee that it
1141          * matches group_leader, or in fact that it might not change by the
1142          * time it's evaluated.  If a process were to call MAP_DMA with
1143          * CAP_IPC_LOCK but later drop it, it doesn't make sense that they
1144          * possibly see different results for an iommu_mapped vfio_dma vs
1145          * externally mapped.  Therefore track CAP_IPC_LOCK in vfio_dma at the
1146          * time of calling MAP_DMA.
1147          */
1148         get_task_struct(current->group_leader);
1149         dma->task = current->group_leader;
1150         dma->lock_cap = capable(CAP_IPC_LOCK);
1151
1152         dma->pfn_list = RB_ROOT;
1153
1154         /* Insert zero-sized and grow as we map chunks of it */
1155         vfio_link_dma(iommu, dma);
1156
1157         /* Don't pin and map if container doesn't contain IOMMU capable domain*/
1158         if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
1159                 dma->size = size;
1160         else
1161                 ret = vfio_pin_map_dma(iommu, dma, size);
1162
1163 out_unlock:
1164         mutex_unlock(&iommu->lock);
1165         return ret;
1166 }
1167
1168 static int vfio_bus_type(struct device *dev, void *data)
1169 {
1170         struct bus_type **bus = data;
1171
1172         if (*bus && *bus != dev->bus)
1173                 return -EINVAL;
1174
1175         *bus = dev->bus;
1176
1177         return 0;
1178 }
1179
1180 static int vfio_iommu_replay(struct vfio_iommu *iommu,
1181                              struct vfio_domain *domain)
1182 {
1183         struct vfio_domain *d;
1184         struct rb_node *n;
1185         unsigned long limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1186         int ret;
1187
1188         /* Arbitrarily pick the first domain in the list for lookups */
1189         d = list_first_entry(&iommu->domain_list, struct vfio_domain, next);
1190         n = rb_first(&iommu->dma_list);
1191
1192         for (; n; n = rb_next(n)) {
1193                 struct vfio_dma *dma;
1194                 dma_addr_t iova;
1195
1196                 dma = rb_entry(n, struct vfio_dma, node);
1197                 iova = dma->iova;
1198
1199                 while (iova < dma->iova + dma->size) {
1200                         phys_addr_t phys;
1201                         size_t size;
1202
1203                         if (dma->iommu_mapped) {
1204                                 phys_addr_t p;
1205                                 dma_addr_t i;
1206
1207                                 phys = iommu_iova_to_phys(d->domain, iova);
1208
1209                                 if (WARN_ON(!phys)) {
1210                                         iova += PAGE_SIZE;
1211                                         continue;
1212                                 }
1213
1214                                 size = PAGE_SIZE;
1215                                 p = phys + size;
1216                                 i = iova + size;
1217                                 while (i < dma->iova + dma->size &&
1218                                        p == iommu_iova_to_phys(d->domain, i)) {
1219                                         size += PAGE_SIZE;
1220                                         p += PAGE_SIZE;
1221                                         i += PAGE_SIZE;
1222                                 }
1223                         } else {
1224                                 unsigned long pfn;
1225                                 unsigned long vaddr = dma->vaddr +
1226                                                      (iova - dma->iova);
1227                                 size_t n = dma->iova + dma->size - iova;
1228                                 long npage;
1229
1230                                 npage = vfio_pin_pages_remote(dma, vaddr,
1231                                                               n >> PAGE_SHIFT,
1232                                                               &pfn, limit);
1233                                 if (npage <= 0) {
1234                                         WARN_ON(!npage);
1235                                         ret = (int)npage;
1236                                         return ret;
1237                                 }
1238
1239                                 phys = pfn << PAGE_SHIFT;
1240                                 size = npage << PAGE_SHIFT;
1241                         }
1242
1243                         ret = iommu_map(domain->domain, iova, phys,
1244                                         size, dma->prot | domain->prot);
1245                         if (ret)
1246                                 return ret;
1247
1248                         iova += size;
1249                 }
1250                 dma->iommu_mapped = true;
1251         }
1252         return 0;
1253 }
1254
1255 /*
1256  * We change our unmap behavior slightly depending on whether the IOMMU
1257  * supports fine-grained superpages.  IOMMUs like AMD-Vi will use a superpage
1258  * for practically any contiguous power-of-two mapping we give it.  This means
1259  * we don't need to look for contiguous chunks ourselves to make unmapping
1260  * more efficient.  On IOMMUs with coarse-grained super pages, like Intel VT-d
1261  * with discrete 2M/1G/512G/1T superpages, identifying contiguous chunks
1262  * significantly boosts non-hugetlbfs mappings and doesn't seem to hurt when
1263  * hugetlbfs is in use.
1264  */
1265 static void vfio_test_domain_fgsp(struct vfio_domain *domain)
1266 {
1267         struct page *pages;
1268         int ret, order = get_order(PAGE_SIZE * 2);
1269
1270         pages = alloc_pages(GFP_KERNEL | __GFP_ZERO, order);
1271         if (!pages)
1272                 return;
1273
1274         ret = iommu_map(domain->domain, 0, page_to_phys(pages), PAGE_SIZE * 2,
1275                         IOMMU_READ | IOMMU_WRITE | domain->prot);
1276         if (!ret) {
1277                 size_t unmapped = iommu_unmap(domain->domain, 0, PAGE_SIZE);
1278
1279                 if (unmapped == PAGE_SIZE)
1280                         iommu_unmap(domain->domain, PAGE_SIZE, PAGE_SIZE);
1281                 else
1282                         domain->fgsp = true;
1283         }
1284
1285         __free_pages(pages, order);
1286 }
1287
1288 static struct vfio_group *find_iommu_group(struct vfio_domain *domain,
1289                                            struct iommu_group *iommu_group)
1290 {
1291         struct vfio_group *g;
1292
1293         list_for_each_entry(g, &domain->group_list, next) {
1294                 if (g->iommu_group == iommu_group)
1295                         return g;
1296         }
1297
1298         return NULL;
1299 }
1300
1301 static bool vfio_iommu_has_sw_msi(struct iommu_group *group, phys_addr_t *base)
1302 {
1303         struct list_head group_resv_regions;
1304         struct iommu_resv_region *region, *next;
1305         bool ret = false;
1306
1307         INIT_LIST_HEAD(&group_resv_regions);
1308         iommu_get_group_resv_regions(group, &group_resv_regions);
1309         list_for_each_entry(region, &group_resv_regions, list) {
1310                 /*
1311                  * The presence of any 'real' MSI regions should take
1312                  * precedence over the software-managed one if the
1313                  * IOMMU driver happens to advertise both types.
1314                  */
1315                 if (region->type == IOMMU_RESV_MSI) {
1316                         ret = false;
1317                         break;
1318                 }
1319
1320                 if (region->type == IOMMU_RESV_SW_MSI) {
1321                         *base = region->start;
1322                         ret = true;
1323                 }
1324         }
1325         list_for_each_entry_safe(region, next, &group_resv_regions, list)
1326                 kfree(region);
1327         return ret;
1328 }
1329
1330 static int vfio_iommu_type1_attach_group(void *iommu_data,
1331                                          struct iommu_group *iommu_group)
1332 {
1333         struct vfio_iommu *iommu = iommu_data;
1334         struct vfio_group *group;
1335         struct vfio_domain *domain, *d;
1336         struct bus_type *bus = NULL, *mdev_bus;
1337         int ret;
1338         bool resv_msi, msi_remap;
1339         phys_addr_t resv_msi_base;
1340
1341         mutex_lock(&iommu->lock);
1342
1343         list_for_each_entry(d, &iommu->domain_list, next) {
1344                 if (find_iommu_group(d, iommu_group)) {
1345                         mutex_unlock(&iommu->lock);
1346                         return -EINVAL;
1347                 }
1348         }
1349
1350         if (iommu->external_domain) {
1351                 if (find_iommu_group(iommu->external_domain, iommu_group)) {
1352                         mutex_unlock(&iommu->lock);
1353                         return -EINVAL;
1354                 }
1355         }
1356
1357         group = kzalloc(sizeof(*group), GFP_KERNEL);
1358         domain = kzalloc(sizeof(*domain), GFP_KERNEL);
1359         if (!group || !domain) {
1360                 ret = -ENOMEM;
1361                 goto out_free;
1362         }
1363
1364         group->iommu_group = iommu_group;
1365
1366         /* Determine bus_type in order to allocate a domain */
1367         ret = iommu_group_for_each_dev(iommu_group, &bus, vfio_bus_type);
1368         if (ret)
1369                 goto out_free;
1370
1371         mdev_bus = symbol_get(mdev_bus_type);
1372
1373         if (mdev_bus) {
1374                 if ((bus == mdev_bus) && !iommu_present(bus)) {
1375                         symbol_put(mdev_bus_type);
1376                         if (!iommu->external_domain) {
1377                                 INIT_LIST_HEAD(&domain->group_list);
1378                                 iommu->external_domain = domain;
1379                         } else
1380                                 kfree(domain);
1381
1382                         list_add(&group->next,
1383                                  &iommu->external_domain->group_list);
1384                         mutex_unlock(&iommu->lock);
1385                         return 0;
1386                 }
1387                 symbol_put(mdev_bus_type);
1388         }
1389
1390         domain->domain = iommu_domain_alloc(bus);
1391         if (!domain->domain) {
1392                 ret = -EIO;
1393                 goto out_free;
1394         }
1395
1396         if (iommu->nesting) {
1397                 int attr = 1;
1398
1399                 ret = iommu_domain_set_attr(domain->domain, DOMAIN_ATTR_NESTING,
1400                                             &attr);
1401                 if (ret)
1402                         goto out_domain;
1403         }
1404
1405         ret = iommu_attach_group(domain->domain, iommu_group);
1406         if (ret)
1407                 goto out_domain;
1408
1409         resv_msi = vfio_iommu_has_sw_msi(iommu_group, &resv_msi_base);
1410
1411         INIT_LIST_HEAD(&domain->group_list);
1412         list_add(&group->next, &domain->group_list);
1413
1414         msi_remap = irq_domain_check_msi_remap() ||
1415                     iommu_capable(bus, IOMMU_CAP_INTR_REMAP);
1416
1417         if (!allow_unsafe_interrupts && !msi_remap) {
1418                 pr_warn("%s: No interrupt remapping support.  Use the module param \"allow_unsafe_interrupts\" to enable VFIO IOMMU support on this platform\n",
1419                        __func__);
1420                 ret = -EPERM;
1421                 goto out_detach;
1422         }
1423
1424         if (iommu_capable(bus, IOMMU_CAP_CACHE_COHERENCY))
1425                 domain->prot |= IOMMU_CACHE;
1426
1427         /*
1428          * Try to match an existing compatible domain.  We don't want to
1429          * preclude an IOMMU driver supporting multiple bus_types and being
1430          * able to include different bus_types in the same IOMMU domain, so
1431          * we test whether the domains use the same iommu_ops rather than
1432          * testing if they're on the same bus_type.
1433          */
1434         list_for_each_entry(d, &iommu->domain_list, next) {
1435                 if (d->domain->ops == domain->domain->ops &&
1436                     d->prot == domain->prot) {
1437                         iommu_detach_group(domain->domain, iommu_group);
1438                         if (!iommu_attach_group(d->domain, iommu_group)) {
1439                                 list_add(&group->next, &d->group_list);
1440                                 iommu_domain_free(domain->domain);
1441                                 kfree(domain);
1442                                 mutex_unlock(&iommu->lock);
1443                                 return 0;
1444                         }
1445
1446                         ret = iommu_attach_group(domain->domain, iommu_group);
1447                         if (ret)
1448                                 goto out_domain;
1449                 }
1450         }
1451
1452         vfio_test_domain_fgsp(domain);
1453
1454         /* replay mappings on new domains */
1455         ret = vfio_iommu_replay(iommu, domain);
1456         if (ret)
1457                 goto out_detach;
1458
1459         if (resv_msi) {
1460                 ret = iommu_get_msi_cookie(domain->domain, resv_msi_base);
1461                 if (ret)
1462                         goto out_detach;
1463         }
1464
1465         list_add(&domain->next, &iommu->domain_list);
1466
1467         mutex_unlock(&iommu->lock);
1468
1469         return 0;
1470
1471 out_detach:
1472         iommu_detach_group(domain->domain, iommu_group);
1473 out_domain:
1474         iommu_domain_free(domain->domain);
1475 out_free:
1476         kfree(domain);
1477         kfree(group);
1478         mutex_unlock(&iommu->lock);
1479         return ret;
1480 }
1481
1482 static void vfio_iommu_unmap_unpin_all(struct vfio_iommu *iommu)
1483 {
1484         struct rb_node *node;
1485
1486         while ((node = rb_first(&iommu->dma_list)))
1487                 vfio_remove_dma(iommu, rb_entry(node, struct vfio_dma, node));
1488 }
1489
1490 static void vfio_iommu_unmap_unpin_reaccount(struct vfio_iommu *iommu)
1491 {
1492         struct rb_node *n, *p;
1493
1494         n = rb_first(&iommu->dma_list);
1495         for (; n; n = rb_next(n)) {
1496                 struct vfio_dma *dma;
1497                 long locked = 0, unlocked = 0;
1498
1499                 dma = rb_entry(n, struct vfio_dma, node);
1500                 unlocked += vfio_unmap_unpin(iommu, dma, false);
1501                 p = rb_first(&dma->pfn_list);
1502                 for (; p; p = rb_next(p)) {
1503                         struct vfio_pfn *vpfn = rb_entry(p, struct vfio_pfn,
1504                                                          node);
1505
1506                         if (!is_invalid_reserved_pfn(vpfn->pfn))
1507                                 locked++;
1508                 }
1509                 vfio_lock_acct(dma, locked - unlocked, true);
1510         }
1511 }
1512
1513 static void vfio_sanity_check_pfn_list(struct vfio_iommu *iommu)
1514 {
1515         struct rb_node *n;
1516
1517         n = rb_first(&iommu->dma_list);
1518         for (; n; n = rb_next(n)) {
1519                 struct vfio_dma *dma;
1520
1521                 dma = rb_entry(n, struct vfio_dma, node);
1522
1523                 if (WARN_ON(!RB_EMPTY_ROOT(&dma->pfn_list)))
1524                         break;
1525         }
1526         /* mdev vendor driver must unregister notifier */
1527         WARN_ON(iommu->notifier.head);
1528 }
1529
1530 static void vfio_iommu_type1_detach_group(void *iommu_data,
1531                                           struct iommu_group *iommu_group)
1532 {
1533         struct vfio_iommu *iommu = iommu_data;
1534         struct vfio_domain *domain;
1535         struct vfio_group *group;
1536
1537         mutex_lock(&iommu->lock);
1538
1539         if (iommu->external_domain) {
1540                 group = find_iommu_group(iommu->external_domain, iommu_group);
1541                 if (group) {
1542                         list_del(&group->next);
1543                         kfree(group);
1544
1545                         if (list_empty(&iommu->external_domain->group_list)) {
1546                                 vfio_sanity_check_pfn_list(iommu);
1547
1548                                 if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
1549                                         vfio_iommu_unmap_unpin_all(iommu);
1550
1551                                 kfree(iommu->external_domain);
1552                                 iommu->external_domain = NULL;
1553                         }
1554                         goto detach_group_done;
1555                 }
1556         }
1557
1558         list_for_each_entry(domain, &iommu->domain_list, next) {
1559                 group = find_iommu_group(domain, iommu_group);
1560                 if (!group)
1561                         continue;
1562
1563                 iommu_detach_group(domain->domain, iommu_group);
1564                 list_del(&group->next);
1565                 kfree(group);
1566                 /*
1567                  * Group ownership provides privilege, if the group list is
1568                  * empty, the domain goes away. If it's the last domain with
1569                  * iommu and external domain doesn't exist, then all the
1570                  * mappings go away too. If it's the last domain with iommu and
1571                  * external domain exist, update accounting
1572                  */
1573                 if (list_empty(&domain->group_list)) {
1574                         if (list_is_singular(&iommu->domain_list)) {
1575                                 if (!iommu->external_domain)
1576                                         vfio_iommu_unmap_unpin_all(iommu);
1577                                 else
1578                                         vfio_iommu_unmap_unpin_reaccount(iommu);
1579                         }
1580                         iommu_domain_free(domain->domain);
1581                         list_del(&domain->next);
1582                         kfree(domain);
1583                 }
1584                 break;
1585         }
1586
1587 detach_group_done:
1588         mutex_unlock(&iommu->lock);
1589 }
1590
1591 static void *vfio_iommu_type1_open(unsigned long arg)
1592 {
1593         struct vfio_iommu *iommu;
1594
1595         iommu = kzalloc(sizeof(*iommu), GFP_KERNEL);
1596         if (!iommu)
1597                 return ERR_PTR(-ENOMEM);
1598
1599         switch (arg) {
1600         case VFIO_TYPE1_IOMMU:
1601                 break;
1602         case VFIO_TYPE1_NESTING_IOMMU:
1603                 iommu->nesting = true;
1604         case VFIO_TYPE1v2_IOMMU:
1605                 iommu->v2 = true;
1606                 break;
1607         default:
1608                 kfree(iommu);
1609                 return ERR_PTR(-EINVAL);
1610         }
1611
1612         INIT_LIST_HEAD(&iommu->domain_list);
1613         iommu->dma_list = RB_ROOT;
1614         mutex_init(&iommu->lock);
1615         BLOCKING_INIT_NOTIFIER_HEAD(&iommu->notifier);
1616
1617         return iommu;
1618 }
1619
1620 static void vfio_release_domain(struct vfio_domain *domain, bool external)
1621 {
1622         struct vfio_group *group, *group_tmp;
1623
1624         list_for_each_entry_safe(group, group_tmp,
1625                                  &domain->group_list, next) {
1626                 if (!external)
1627                         iommu_detach_group(domain->domain, group->iommu_group);
1628                 list_del(&group->next);
1629                 kfree(group);
1630         }
1631
1632         if (!external)
1633                 iommu_domain_free(domain->domain);
1634 }
1635
1636 static void vfio_iommu_type1_release(void *iommu_data)
1637 {
1638         struct vfio_iommu *iommu = iommu_data;
1639         struct vfio_domain *domain, *domain_tmp;
1640
1641         if (iommu->external_domain) {
1642                 vfio_release_domain(iommu->external_domain, true);
1643                 vfio_sanity_check_pfn_list(iommu);
1644                 kfree(iommu->external_domain);
1645         }
1646
1647         vfio_iommu_unmap_unpin_all(iommu);
1648
1649         list_for_each_entry_safe(domain, domain_tmp,
1650                                  &iommu->domain_list, next) {
1651                 vfio_release_domain(domain, false);
1652                 list_del(&domain->next);
1653                 kfree(domain);
1654         }
1655         kfree(iommu);
1656 }
1657
1658 static int vfio_domains_have_iommu_cache(struct vfio_iommu *iommu)
1659 {
1660         struct vfio_domain *domain;
1661         int ret = 1;
1662
1663         mutex_lock(&iommu->lock);
1664         list_for_each_entry(domain, &iommu->domain_list, next) {
1665                 if (!(domain->prot & IOMMU_CACHE)) {
1666                         ret = 0;
1667                         break;
1668                 }
1669         }
1670         mutex_unlock(&iommu->lock);
1671
1672         return ret;
1673 }
1674
1675 static long vfio_iommu_type1_ioctl(void *iommu_data,
1676                                    unsigned int cmd, unsigned long arg)
1677 {
1678         struct vfio_iommu *iommu = iommu_data;
1679         unsigned long minsz;
1680
1681         if (cmd == VFIO_CHECK_EXTENSION) {
1682                 switch (arg) {
1683                 case VFIO_TYPE1_IOMMU:
1684                 case VFIO_TYPE1v2_IOMMU:
1685                 case VFIO_TYPE1_NESTING_IOMMU:
1686                         return 1;
1687                 case VFIO_DMA_CC_IOMMU:
1688                         if (!iommu)
1689                                 return 0;
1690                         return vfio_domains_have_iommu_cache(iommu);
1691                 default:
1692                         return 0;
1693                 }
1694         } else if (cmd == VFIO_IOMMU_GET_INFO) {
1695                 struct vfio_iommu_type1_info info;
1696
1697                 minsz = offsetofend(struct vfio_iommu_type1_info, iova_pgsizes);
1698
1699                 if (copy_from_user(&info, (void __user *)arg, minsz))
1700                         return -EFAULT;
1701
1702                 if (info.argsz < minsz)
1703                         return -EINVAL;
1704
1705                 info.flags = VFIO_IOMMU_INFO_PGSIZES;
1706
1707                 info.iova_pgsizes = vfio_pgsize_bitmap(iommu);
1708
1709                 return copy_to_user((void __user *)arg, &info, minsz) ?
1710                         -EFAULT : 0;
1711
1712         } else if (cmd == VFIO_IOMMU_MAP_DMA) {
1713                 struct vfio_iommu_type1_dma_map map;
1714                 uint32_t mask = VFIO_DMA_MAP_FLAG_READ |
1715                                 VFIO_DMA_MAP_FLAG_WRITE;
1716
1717                 minsz = offsetofend(struct vfio_iommu_type1_dma_map, size);
1718
1719                 if (copy_from_user(&map, (void __user *)arg, minsz))
1720                         return -EFAULT;
1721
1722                 if (map.argsz < minsz || map.flags & ~mask)
1723                         return -EINVAL;
1724
1725                 return vfio_dma_do_map(iommu, &map);
1726
1727         } else if (cmd == VFIO_IOMMU_UNMAP_DMA) {
1728                 struct vfio_iommu_type1_dma_unmap unmap;
1729                 long ret;
1730
1731                 minsz = offsetofend(struct vfio_iommu_type1_dma_unmap, size);
1732
1733                 if (copy_from_user(&unmap, (void __user *)arg, minsz))
1734                         return -EFAULT;
1735
1736                 if (unmap.argsz < minsz || unmap.flags)
1737                         return -EINVAL;
1738
1739                 ret = vfio_dma_do_unmap(iommu, &unmap);
1740                 if (ret)
1741                         return ret;
1742
1743                 return copy_to_user((void __user *)arg, &unmap, minsz) ?
1744                         -EFAULT : 0;
1745         }
1746
1747         return -ENOTTY;
1748 }
1749
1750 static int vfio_iommu_type1_register_notifier(void *iommu_data,
1751                                               unsigned long *events,
1752                                               struct notifier_block *nb)
1753 {
1754         struct vfio_iommu *iommu = iommu_data;
1755
1756         /* clear known events */
1757         *events &= ~VFIO_IOMMU_NOTIFY_DMA_UNMAP;
1758
1759         /* refuse to register if still events remaining */
1760         if (*events)
1761                 return -EINVAL;
1762
1763         return blocking_notifier_chain_register(&iommu->notifier, nb);
1764 }
1765
1766 static int vfio_iommu_type1_unregister_notifier(void *iommu_data,
1767                                                 struct notifier_block *nb)
1768 {
1769         struct vfio_iommu *iommu = iommu_data;
1770
1771         return blocking_notifier_chain_unregister(&iommu->notifier, nb);
1772 }
1773
1774 static const struct vfio_iommu_driver_ops vfio_iommu_driver_ops_type1 = {
1775         .name                   = "vfio-iommu-type1",
1776         .owner                  = THIS_MODULE,
1777         .open                   = vfio_iommu_type1_open,
1778         .release                = vfio_iommu_type1_release,
1779         .ioctl                  = vfio_iommu_type1_ioctl,
1780         .attach_group           = vfio_iommu_type1_attach_group,
1781         .detach_group           = vfio_iommu_type1_detach_group,
1782         .pin_pages              = vfio_iommu_type1_pin_pages,
1783         .unpin_pages            = vfio_iommu_type1_unpin_pages,
1784         .register_notifier      = vfio_iommu_type1_register_notifier,
1785         .unregister_notifier    = vfio_iommu_type1_unregister_notifier,
1786 };
1787
1788 static int __init vfio_iommu_type1_init(void)
1789 {
1790         return vfio_register_iommu_driver(&vfio_iommu_driver_ops_type1);
1791 }
1792
1793 static void __exit vfio_iommu_type1_cleanup(void)
1794 {
1795         vfio_unregister_iommu_driver(&vfio_iommu_driver_ops_type1);
1796 }
1797
1798 module_init(vfio_iommu_type1_init);
1799 module_exit(vfio_iommu_type1_cleanup);
1800
1801 MODULE_VERSION(DRIVER_VERSION);
1802 MODULE_LICENSE("GPL v2");
1803 MODULE_AUTHOR(DRIVER_AUTHOR);
1804 MODULE_DESCRIPTION(DRIVER_DESC);