Merge git://git.kernel.org/pub/scm/linux/kernel/git/davem/net
index 320545b98ff55997029476f32e09dbf5d4f5f009..486dc394a5a3cd1fe226e215717631619c8a4195 100644
--- a/mm/hmm.c
+++ b/mm/hmm.c
@@ -160,6 +160,32 @@ static void hmm_invalidate_range(struct hmm *hmm,
        up_read(&hmm->mirrors_sem);
 }
 
+static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm)
+{
+       struct hmm_mirror *mirror;
+       struct hmm *hmm = mm->hmm;
+
+       down_write(&hmm->mirrors_sem);
+       mirror = list_first_entry_or_null(&hmm->mirrors, struct hmm_mirror,
+                                         list);
+       while (mirror) {
+               list_del_init(&mirror->list);
+               if (mirror->ops->release) {
+                       /*
+                        * Drop mirrors_sem so the callback can wait on any
+                        * pending work that might itself trigger an
+                        * mmu_notifier callback and thus deadlock with us.
+                        */
+                       up_write(&hmm->mirrors_sem);
+                       mirror->ops->release(mirror);
+                       down_write(&hmm->mirrors_sem);
+               }
+               mirror = list_first_entry_or_null(&hmm->mirrors,
+                                                 struct hmm_mirror, list);
+       }
+       up_write(&hmm->mirrors_sem);
+}
+
 static void hmm_invalidate_range_start(struct mmu_notifier *mn,
                                       struct mm_struct *mm,
                                       unsigned long start,
@@ -185,6 +211,7 @@ static void hmm_invalidate_range_end(struct mmu_notifier *mn,
 }
 
 static const struct mmu_notifier_ops hmm_mmu_notifier_ops = {
+       .release                = hmm_release,
        .invalidate_range_start = hmm_invalidate_range_start,
        .invalidate_range_end   = hmm_invalidate_range_end,
 };
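
With .release wired into hmm_mmu_notifier_ops, a mirror's driver now gets exactly one callback when the address space goes away. Because hmm_release() drops mirrors_sem around that call, the callback is allowed to block, for example to flush work that could itself take mmu_notifier paths. Below is a minimal, illustrative sketch of a driver-side release callback; the mydrv_* names and the per-mm workqueue are hypothetical and not part of this patch.

#include <linux/hmm.h>
#include <linux/workqueue.h>

/* Hypothetical per-mm driver state (illustration only). */
struct mydrv_mm {
	struct hmm_mirror	mirror;
	struct workqueue_struct	*wq;
};

static void mydrv_mirror_release(struct hmm_mirror *mirror)
{
	struct mydrv_mm *dmm = container_of(mirror, struct mydrv_mm, mirror);

	/*
	 * Blocking here is safe: hmm_release() dropped mirrors_sem before
	 * invoking us, so waiting on work that itself takes mmu_notifier
	 * paths cannot deadlock on that semaphore.
	 */
	flush_workqueue(dmm->wq);
}

static const struct hmm_mirror_ops mydrv_mirror_ops = {
	.release = mydrv_mirror_release,
	/* A real driver must also provide .sync_cpu_device_pagetables. */
};
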
@@ -206,13 +233,24 @@ int hmm_mirror_register(struct hmm_mirror *mirror, struct mm_struct *mm)
        if (!mm || !mirror || !mirror->ops)
                return -EINVAL;
 
+again:
        mirror->hmm = hmm_register(mm);
        if (!mirror->hmm)
                return -ENOMEM;
 
        down_write(&mirror->hmm->mirrors_sem);
-       list_add(&mirror->list, &mirror->hmm->mirrors);
-       up_write(&mirror->hmm->mirrors_sem);
+       if (mirror->hmm->mm == NULL) {
+               /*
+                * A racing hmm_mirror_unregister() is about to destroy the hmm
+                * struct. Try again to allocate a new one.
+                */
+               up_write(&mirror->hmm->mirrors_sem);
+               mirror->hmm = NULL;
+               goto again;
+       } else {
+               list_add(&mirror->list, &mirror->hmm->mirrors);
+               up_write(&mirror->hmm->mirrors_sem);
+       }
 
        return 0;
 }
@@ -227,11 +265,32 @@ EXPORT_SYMBOL(hmm_mirror_register);
  */
 void hmm_mirror_unregister(struct hmm_mirror *mirror)
 {
-       struct hmm *hmm = mirror->hmm;
+       bool should_unregister = false;
+       struct mm_struct *mm;
+       struct hmm *hmm;
 
+       if (mirror->hmm == NULL)
+               return;
+
+       hmm = mirror->hmm;
        down_write(&hmm->mirrors_sem);
-       list_del(&mirror->list);
+       list_del_init(&mirror->list);
+       should_unregister = list_empty(&hmm->mirrors);
+       mirror->hmm = NULL;
+       mm = hmm->mm;
+       hmm->mm = NULL;
        up_write(&hmm->mirrors_sem);
+
+       if (!should_unregister || mm == NULL)
+               return;
+
+       spin_lock(&mm->page_table_lock);
+       if (mm->hmm == hmm)
+               mm->hmm = NULL;
+       spin_unlock(&mm->page_table_lock);
+
+       mmu_notifier_unregister_no_release(&hmm->mmu_notifier, mm);
+       kfree(hmm);
 }
 EXPORT_SYMBOL(hmm_mirror_unregister);
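
Taken together, hmm_mirror_register() and hmm_mirror_unregister() now own the lifetime of the per-mm struct hmm: the last unregister detaches it from the mm and frees it, and a concurrent register that observes the teardown (hmm->mm == NULL) simply retries with a fresh struct. From the driver side the calling convention is unchanged; a minimal sketch, reusing the hypothetical mydrv_* types from the sketch above:

/* Illustration only: bind/unbind a device context to a process address space. */
static int mydrv_mm_bind(struct mydrv_mm *dmm, struct mm_struct *mm)
{
	dmm->mirror.ops = &mydrv_mirror_ops;
	return hmm_mirror_register(&dmm->mirror, mm);
}

static void mydrv_mm_unbind(struct mydrv_mm *dmm)
{
	hmm_mirror_unregister(&dmm->mirror);
}
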
 
@@ -240,110 +299,275 @@ struct hmm_vma_walk {
        unsigned long           last;
        bool                    fault;
        bool                    block;
-       bool                    write;
 };
 
-static int hmm_vma_do_fault(struct mm_walk *walk,
-                           unsigned long addr,
-                           hmm_pfn_t *pfn)
+static int hmm_vma_do_fault(struct mm_walk *walk, unsigned long addr,
+                           bool write_fault, uint64_t *pfn)
 {
        unsigned int flags = FAULT_FLAG_ALLOW_RETRY | FAULT_FLAG_REMOTE;
        struct hmm_vma_walk *hmm_vma_walk = walk->private;
+       struct hmm_range *range = hmm_vma_walk->range;
        struct vm_area_struct *vma = walk->vma;
        int r;
 
        flags |= hmm_vma_walk->block ? 0 : FAULT_FLAG_ALLOW_RETRY;
-       flags |= hmm_vma_walk->write ? FAULT_FLAG_WRITE : 0;
+       flags |= write_fault ? FAULT_FLAG_WRITE : 0;
        r = handle_mm_fault(vma, addr, flags);
        if (r & VM_FAULT_RETRY)
                return -EBUSY;
        if (r & VM_FAULT_ERROR) {
-               *pfn = HMM_PFN_ERROR;
+               *pfn = range->values[HMM_PFN_ERROR];
                return -EFAULT;
        }
 
        return -EAGAIN;
 }
 
-static void hmm_pfns_special(hmm_pfn_t *pfns,
-                            unsigned long addr,
-                            unsigned long end)
-{
-       for (; addr < end; addr += PAGE_SIZE, pfns++)
-               *pfns = HMM_PFN_SPECIAL;
-}
-
 static int hmm_pfns_bad(unsigned long addr,
                        unsigned long end,
                        struct mm_walk *walk)
 {
-       struct hmm_range *range = walk->private;
-       hmm_pfn_t *pfns = range->pfns;
+       struct hmm_vma_walk *hmm_vma_walk = walk->private;
+       struct hmm_range *range = hmm_vma_walk->range;
+       uint64_t *pfns = range->pfns;
        unsigned long i;
 
        i = (addr - range->start) >> PAGE_SHIFT;
        for (; addr < end; addr += PAGE_SIZE, i++)
-               pfns[i] = HMM_PFN_ERROR;
+               pfns[i] = range->values[HMM_PFN_ERROR];
 
        return 0;
 }
 
-static void hmm_pfns_clear(hmm_pfn_t *pfns,
-                          unsigned long addr,
-                          unsigned long end)
-{
-       for (; addr < end; addr += PAGE_SIZE, pfns++)
-               *pfns = 0;
-}
-
-static int hmm_vma_walk_hole(unsigned long addr,
-                            unsigned long end,
-                            struct mm_walk *walk)
+/*
+ * hmm_vma_walk_hole_() - handle a range lacking valid pmd or pte(s)
+ * @addr: range virtual start address (inclusive)
+ * @end: range virtual end address (exclusive)
+ * @fault: should we fault or not ?
+ * @write_fault: write fault ?
+ * @walk: mm_walk structure
+ * Returns: 0 on success, -EAGAIN after page fault, or page fault error
+ *
+ * This function will be called whenever pmd_none() or pte_none() returns true,
+ * or whenever there is no page directory covering the virtual address range.
+ */
+static int hmm_vma_walk_hole_(unsigned long addr, unsigned long end,
+                             bool fault, bool write_fault,
+                             struct mm_walk *walk)
 {
        struct hmm_vma_walk *hmm_vma_walk = walk->private;
        struct hmm_range *range = hmm_vma_walk->range;
-       hmm_pfn_t *pfns = range->pfns;
+       uint64_t *pfns = range->pfns;
        unsigned long i;
 
        hmm_vma_walk->last = addr;
        i = (addr - range->start) >> PAGE_SHIFT;
        for (; addr < end; addr += PAGE_SIZE, i++) {
-               pfns[i] = HMM_PFN_EMPTY;
-               if (hmm_vma_walk->fault) {
+               pfns[i] = range->values[HMM_PFN_NONE];
+               if (fault || write_fault) {
                        int ret;
 
-                       ret = hmm_vma_do_fault(walk, addr, &pfns[i]);
+                       ret = hmm_vma_do_fault(walk, addr, write_fault,
+                                              &pfns[i]);
                        if (ret != -EAGAIN)
                                return ret;
                }
        }
 
-       return hmm_vma_walk->fault ? -EAGAIN : 0;
+       return (fault || write_fault) ? -EAGAIN : 0;
 }
 
-static int hmm_vma_walk_clear(unsigned long addr,
-                             unsigned long end,
-                             struct mm_walk *walk)
+static inline void hmm_pte_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
+                                     uint64_t pfns, uint64_t cpu_flags,
+                                     bool *fault, bool *write_fault)
 {
-       struct hmm_vma_walk *hmm_vma_walk = walk->private;
        struct hmm_range *range = hmm_vma_walk->range;
-       hmm_pfn_t *pfns = range->pfns;
+
+       *fault = *write_fault = false;
+       if (!hmm_vma_walk->fault)
+               return;
+
+       /* We aren't asked to do anything ... */
+       if (!(pfns & range->flags[HMM_PFN_VALID]))
+               return;
+       /* If this is device memory, then only fault if explicitly requested */
+       if ((cpu_flags & range->flags[HMM_PFN_DEVICE_PRIVATE])) {
+               /* Do we fault on device memory ? */
+               if (pfns & range->flags[HMM_PFN_DEVICE_PRIVATE]) {
+                       *write_fault = pfns & range->flags[HMM_PFN_WRITE];
+                       *fault = true;
+               }
+               return;
+       }
+
+       /* If CPU page table is not valid then we need to fault */
+       *fault = !(cpu_flags & range->flags[HMM_PFN_VALID]);
+       /* Need to write fault ? */
+       if ((pfns & range->flags[HMM_PFN_WRITE]) &&
+           !(cpu_flags & range->flags[HMM_PFN_WRITE])) {
+               *write_fault = true;
+               *fault = true;
+       }
+}
+
+static void hmm_range_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
+                                const uint64_t *pfns, unsigned long npages,
+                                uint64_t cpu_flags, bool *fault,
+                                bool *write_fault)
+{
        unsigned long i;
 
-       hmm_vma_walk->last = addr;
+       if (!hmm_vma_walk->fault) {
+               *fault = *write_fault = false;
+               return;
+       }
+
+       for (i = 0; i < npages; ++i) {
+               hmm_pte_need_fault(hmm_vma_walk, pfns[i], cpu_flags,
+                                  fault, write_fault);
+               if ((*fault) || (*write_fault))
+                       return;
+       }
+}
+
+static int hmm_vma_walk_hole(unsigned long addr, unsigned long end,
+                            struct mm_walk *walk)
+{
+       struct hmm_vma_walk *hmm_vma_walk = walk->private;
+       struct hmm_range *range = hmm_vma_walk->range;
+       bool fault, write_fault;
+       unsigned long i, npages;
+       uint64_t *pfns;
+
        i = (addr - range->start) >> PAGE_SHIFT;
-       for (; addr < end; addr += PAGE_SIZE, i++) {
-               pfns[i] = 0;
-               if (hmm_vma_walk->fault) {
-                       int ret;
+       npages = (end - addr) >> PAGE_SHIFT;
+       pfns = &range->pfns[i];
+       hmm_range_need_fault(hmm_vma_walk, pfns, npages,
+                            0, &fault, &write_fault);
+       return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
+}
 
-                       ret = hmm_vma_do_fault(walk, addr, &pfns[i]);
-                       if (ret != -EAGAIN)
-                               return ret;
+static inline uint64_t pmd_to_hmm_pfn_flags(struct hmm_range *range, pmd_t pmd)
+{
+       if (pmd_protnone(pmd))
+               return 0;
+       return pmd_write(pmd) ? range->flags[HMM_PFN_VALID] |
+                               range->flags[HMM_PFN_WRITE] :
+                               range->flags[HMM_PFN_VALID];
+}
+
+static int hmm_vma_handle_pmd(struct mm_walk *walk,
+                             unsigned long addr,
+                             unsigned long end,
+                             uint64_t *pfns,
+                             pmd_t pmd)
+{
+       struct hmm_vma_walk *hmm_vma_walk = walk->private;
+       struct hmm_range *range = hmm_vma_walk->range;
+       unsigned long pfn, npages, i;
+       bool fault, write_fault;
+       uint64_t cpu_flags;
+
+       npages = (end - addr) >> PAGE_SHIFT;
+       cpu_flags = pmd_to_hmm_pfn_flags(range, pmd);
+       hmm_range_need_fault(hmm_vma_walk, pfns, npages, cpu_flags,
+                            &fault, &write_fault);
+
+       if (pmd_protnone(pmd) || fault || write_fault)
+               return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
+
+       pfn = pmd_pfn(pmd) + pte_index(addr);
+       for (i = 0; addr < end; addr += PAGE_SIZE, i++, pfn++)
+               pfns[i] = hmm_pfn_from_pfn(range, pfn) | cpu_flags;
+       hmm_vma_walk->last = end;
+       return 0;
+}
+
+static inline uint64_t pte_to_hmm_pfn_flags(struct hmm_range *range, pte_t pte)
+{
+       if (pte_none(pte) || !pte_present(pte))
+               return 0;
+       return pte_write(pte) ? range->flags[HMM_PFN_VALID] |
+                               range->flags[HMM_PFN_WRITE] :
+                               range->flags[HMM_PFN_VALID];
+}
+
+static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
+                             unsigned long end, pmd_t *pmdp, pte_t *ptep,
+                             uint64_t *pfn)
+{
+       struct hmm_vma_walk *hmm_vma_walk = walk->private;
+       struct hmm_range *range = hmm_vma_walk->range;
+       struct vm_area_struct *vma = walk->vma;
+       bool fault, write_fault;
+       uint64_t cpu_flags;
+       pte_t pte = *ptep;
+       uint64_t orig_pfn = *pfn;
+
+       *pfn = range->values[HMM_PFN_NONE];
+       cpu_flags = pte_to_hmm_pfn_flags(range, pte);
+       hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags,
+                          &fault, &write_fault);
+
+       if (pte_none(pte)) {
+               if (fault || write_fault)
+                       goto fault;
+               return 0;
+       }
+
+       if (!pte_present(pte)) {
+               swp_entry_t entry = pte_to_swp_entry(pte);
+
+               if (!non_swap_entry(entry)) {
+                       if (fault || write_fault)
+                               goto fault;
+                       return 0;
                }
+
+               /*
+                * This is a special swap entry: ignore migration, handle
+                * device-private memory, and report anything else as an error.
+                */
+               if (is_device_private_entry(entry)) {
+                       cpu_flags = range->flags[HMM_PFN_VALID] |
+                               range->flags[HMM_PFN_DEVICE_PRIVATE];
+                       cpu_flags |= is_write_device_private_entry(entry) ?
+                               range->flags[HMM_PFN_WRITE] : 0;
+                       hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags,
+                                          &fault, &write_fault);
+                       if (fault || write_fault)
+                               goto fault;
+                       *pfn = hmm_pfn_from_pfn(range, swp_offset(entry));
+                       *pfn |= cpu_flags;
+                       return 0;
+               }
+
+               if (is_migration_entry(entry)) {
+                       if (fault || write_fault) {
+                               pte_unmap(ptep);
+                               hmm_vma_walk->last = addr;
+                               migration_entry_wait(vma->vm_mm,
+                                                    pmdp, addr);
+                               return -EAGAIN;
+                       }
+                       return 0;
+               }
+
+               /* Report error for everything else */
+               *pfn = range->values[HMM_PFN_ERROR];
+               return -EFAULT;
        }
 
-       return hmm_vma_walk->fault ? -EAGAIN : 0;
+       if (fault || write_fault)
+               goto fault;
+
+       *pfn = hmm_pfn_from_pfn(range, pte_pfn(pte)) | cpu_flags;
+       return 0;
+
+fault:
+       pte_unmap(ptep);
+       /* Fault any virtual address we were asked to fault */
+       return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
 }
 
 static int hmm_vma_walk_pmd(pmd_t *pmdp,
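
Note that the walk code above no longer hard-codes any pfn bits: what counts as valid, writable, device-private, error, none or special is looked up in range->flags[] and range->values[], which the device driver supplies. Below is a minimal sketch of such encoding tables, assuming the HMM_PFN_* enum indices added to include/linux/hmm.h by this series; the concrete bit positions are made up for illustration.

/* Illustration only: one possible driver encoding of HMM pfn entries. */
static const uint64_t mydrv_pfn_flags[HMM_PFN_FLAG_MAX] = {
	[HMM_PFN_VALID]          = 1ULL << 0,
	[HMM_PFN_WRITE]          = 1ULL << 1,
	[HMM_PFN_DEVICE_PRIVATE] = 1ULL << 2,
};

static const uint64_t mydrv_pfn_values[HMM_PFN_VALUE_MAX] = {
	[HMM_PFN_ERROR]   = 1ULL << 3,
	[HMM_PFN_NONE]    = 0,
	[HMM_PFN_SPECIAL] = 1ULL << 4,
};

The matching hmm_pfn_from_pfn() helper in this series shifts the CPU pfn left by range->pfn_shift before ORing in flags[HMM_PFN_VALID], so a driver picks a pfn_shift large enough to keep its flag and value bits clear of the pfn itself.
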
@@ -353,26 +577,20 @@ static int hmm_vma_walk_pmd(pmd_t *pmdp,
 {
        struct hmm_vma_walk *hmm_vma_walk = walk->private;
        struct hmm_range *range = hmm_vma_walk->range;
-       struct vm_area_struct *vma = walk->vma;
-       hmm_pfn_t *pfns = range->pfns;
+       uint64_t *pfns = range->pfns;
        unsigned long addr = start, i;
-       bool write_fault;
-       hmm_pfn_t flag;
        pte_t *ptep;
 
        i = (addr - range->start) >> PAGE_SHIFT;
-       flag = vma->vm_flags & VM_READ ? HMM_PFN_READ : 0;
-       write_fault = hmm_vma_walk->fault & hmm_vma_walk->write;
 
 again:
        if (pmd_none(*pmdp))
                return hmm_vma_walk_hole(start, end, walk);
 
-       if (pmd_huge(*pmdp) && vma->vm_flags & VM_HUGETLB)
+       if (pmd_huge(*pmdp) && (range->vma->vm_flags & VM_HUGETLB))
                return hmm_pfns_bad(start, end, walk);
 
        if (pmd_devmap(*pmdp) || pmd_trans_huge(*pmdp)) {
-               unsigned long pfn;
                pmd_t pmd;
 
                /*
@@ -388,17 +606,8 @@ again:
                barrier();
                if (!pmd_devmap(pmd) && !pmd_trans_huge(pmd))
                        goto again;
-               if (pmd_protnone(pmd))
-                       return hmm_vma_walk_clear(start, end, walk);
 
-               if (write_fault && !pmd_write(pmd))
-                       return hmm_vma_walk_clear(start, end, walk);
-
-               pfn = pmd_pfn(pmd) + pte_index(addr);
-               flag |= pmd_write(pmd) ? HMM_PFN_WRITE : 0;
-               for (; addr < end; addr += PAGE_SIZE, i++, pfn++)
-                       pfns[i] = hmm_pfn_t_from_pfn(pfn) | flag;
-               return 0;
+               return hmm_vma_handle_pmd(walk, addr, end, &pfns[i], pmd);
        }
 
        if (pmd_bad(*pmdp))
@@ -406,79 +615,43 @@ again:
 
        ptep = pte_offset_map(pmdp, addr);
        for (; addr < end; addr += PAGE_SIZE, ptep++, i++) {
-               pte_t pte = *ptep;
-
-               pfns[i] = 0;
+               int r;
 
-               if (pte_none(pte)) {
-                       pfns[i] = HMM_PFN_EMPTY;
-                       if (hmm_vma_walk->fault)
-                               goto fault;
-                       continue;
+               r = hmm_vma_handle_pte(walk, addr, end, pmdp, ptep, &pfns[i]);
+               if (r) {
+                       /* hmm_vma_handle_pte() already did the pte_unmap() */
+                       hmm_vma_walk->last = addr;
+                       return r;
                }
-
-               if (!pte_present(pte)) {
-                       swp_entry_t entry = pte_to_swp_entry(pte);
-
-                       if (!non_swap_entry(entry)) {
-                               if (hmm_vma_walk->fault)
-                                       goto fault;
-                               continue;
-                       }
-
-                       /*
-                        * This is a special swap entry, ignore migration, use
-                        * device and report anything else as error.
-                        */
-                       if (is_device_private_entry(entry)) {
-                               pfns[i] = hmm_pfn_t_from_pfn(swp_offset(entry));
-                               if (is_write_device_private_entry(entry)) {
-                                       pfns[i] |= HMM_PFN_WRITE;
-                               } else if (write_fault)
-                                       goto fault;
-                               pfns[i] |= HMM_PFN_DEVICE_UNADDRESSABLE;
-                               pfns[i] |= flag;
-                       } else if (is_migration_entry(entry)) {
-                               if (hmm_vma_walk->fault) {
-                                       pte_unmap(ptep);
-                                       hmm_vma_walk->last = addr;
-                                       migration_entry_wait(vma->vm_mm,
-                                                            pmdp, addr);
-                                       return -EAGAIN;
-                               }
-                               continue;
-                       } else {
-                               /* Report error for everything else */
-                               pfns[i] = HMM_PFN_ERROR;
-                       }
-                       continue;
-               }
-
-               if (write_fault && !pte_write(pte))
-                       goto fault;
-
-               pfns[i] = hmm_pfn_t_from_pfn(pte_pfn(pte)) | flag;
-               pfns[i] |= pte_write(pte) ? HMM_PFN_WRITE : 0;
-               continue;
-
-fault:
-               pte_unmap(ptep);
-               /* Fault all pages in range */
-               return hmm_vma_walk_clear(start, end, walk);
        }
        pte_unmap(ptep - 1);
 
+       hmm_vma_walk->last = addr;
        return 0;
 }
 
+static void hmm_pfns_clear(struct hmm_range *range,
+                          uint64_t *pfns,
+                          unsigned long addr,
+                          unsigned long end)
+{
+       for (; addr < end; addr += PAGE_SIZE, pfns++)
+               *pfns = range->values[HMM_PFN_NONE];
+}
+
+static void hmm_pfns_special(struct hmm_range *range)
+{
+       unsigned long addr = range->start, i = 0;
+
+       for (; addr < range->end; addr += PAGE_SIZE, i++)
+               range->pfns[i] = range->values[HMM_PFN_SPECIAL];
+}
+
 /*
  * hmm_vma_get_pfns() - snapshot CPU page table for a range of virtual addresses
- * @vma: virtual memory area containing the virtual address range
- * @range: used to track snapshot validity
- * @start: range virtual start address (inclusive)
- * @end: range virtual end address (exclusive)
- * @entries: array of hmm_pfn_t: provided by the caller, filled in by function
- * Returns: -EINVAL if invalid argument, -ENOMEM out of memory, 0 success
+ * @range: range being snapshotted
+ * Returns: -EINVAL if invalid argument, -ENOMEM out of memory, -EPERM invalid
+ *          vma permission, 0 success
  *
  * This snapshots the CPU page table for a range of virtual addresses. Snapshot
  * validity is tracked by range struct. See hmm_vma_range_done() for further
@@ -491,26 +664,17 @@ fault:
  * NOT CALLING hmm_vma_range_done() IF FUNCTION RETURNS 0 WILL LEAD TO SERIOUS
  * MEMORY CORRUPTION ! YOU HAVE BEEN WARNED !
  */
-int hmm_vma_get_pfns(struct vm_area_struct *vma,
-                    struct hmm_range *range,
-                    unsigned long start,
-                    unsigned long end,
-                    hmm_pfn_t *pfns)
+int hmm_vma_get_pfns(struct hmm_range *range)
 {
+       struct vm_area_struct *vma = range->vma;
        struct hmm_vma_walk hmm_vma_walk;
        struct mm_walk mm_walk;
        struct hmm *hmm;
 
-       /* FIXME support hugetlb fs */
-       if (is_vm_hugetlb_page(vma) || (vma->vm_flags & VM_SPECIAL)) {
-               hmm_pfns_special(pfns, start, end);
-               return -EINVAL;
-       }
-
        /* Sanity check, this really should not happen ! */
-       if (start < vma->vm_start || start >= vma->vm_end)
+       if (range->start < vma->vm_start || range->start >= vma->vm_end)
                return -EINVAL;
-       if (end < vma->vm_start || end > vma->vm_end)
+       if (range->end < vma->vm_start || range->end > vma->vm_end)
                return -EINVAL;
 
        hmm = hmm_register(vma->vm_mm);
@@ -520,10 +684,24 @@ int hmm_vma_get_pfns(struct vm_area_struct *vma,
        if (!hmm->mmu_notifier.ops)
                return -EINVAL;
 
+       /* FIXME support hugetlb fs */
+       if (is_vm_hugetlb_page(vma) || (vma->vm_flags & VM_SPECIAL)) {
+               hmm_pfns_special(range);
+               return -EINVAL;
+       }
+
+       if (!(vma->vm_flags & VM_READ)) {
+               /*
+                * If the vma does not allow read access, then assume that it
+                * does not allow write access either. Architectures that allow
+                * write without read access are not supported by HMM, because
+                * operations such as atomic access would not work.
+                */
+               hmm_pfns_clear(range, range->pfns, range->start, range->end);
+               return -EPERM;
+       }
+
        /* Initialize range to track CPU page table update */
-       range->start = start;
-       range->pfns = pfns;
-       range->end = end;
        spin_lock(&hmm->lock);
        range->valid = true;
        list_add_rcu(&range->list, &hmm->ranges);
@@ -541,14 +719,13 @@ int hmm_vma_get_pfns(struct vm_area_struct *vma,
        mm_walk.pmd_entry = hmm_vma_walk_pmd;
        mm_walk.pte_hole = hmm_vma_walk_hole;
 
-       walk_page_range(start, end, &mm_walk);
+       walk_page_range(range->start, range->end, &mm_walk);
        return 0;
 }
 EXPORT_SYMBOL(hmm_vma_get_pfns);
 
 /*
  * hmm_vma_range_done() - stop tracking change to CPU page table over a range
- * @vma: virtual memory area containing the virtual address range
  * @range: range being tracked
  * Returns: false if range data has been invalidated, true otherwise
  *
@@ -568,10 +745,10 @@ EXPORT_SYMBOL(hmm_vma_get_pfns);
  *
  * There are two ways to use this :
  * again:
- *   hmm_vma_get_pfns(vma, range, start, end, pfns); or hmm_vma_fault(...);
+ *   hmm_vma_get_pfns(range); or hmm_vma_fault(...);
  *   trans = device_build_page_table_update_transaction(pfns);
  *   device_page_table_lock();
- *   if (!hmm_vma_range_done(vma, range)) {
+ *   if (!hmm_vma_range_done(range)) {
  *     device_page_table_unlock();
  *     goto again;
  *   }
@@ -579,13 +756,13 @@ EXPORT_SYMBOL(hmm_vma_get_pfns);
  *   device_page_table_unlock();
  *
  * Or:
- *   hmm_vma_get_pfns(vma, range, start, end, pfns); or hmm_vma_fault(...);
+ *   hmm_vma_get_pfns(range); or hmm_vma_fault(...);
  *   device_page_table_lock();
- *   hmm_vma_range_done(vma, range);
- *   device_update_page_table(pfns);
+ *   hmm_vma_range_done(range);
+ *   device_update_page_table(range->pfns);
  *   device_page_table_unlock();
  */
-bool hmm_vma_range_done(struct vm_area_struct *vma, struct hmm_range *range)
+bool hmm_vma_range_done(struct hmm_range *range)
 {
        unsigned long npages = (range->end - range->start) >> PAGE_SHIFT;
        struct hmm *hmm;
@@ -595,7 +772,7 @@ bool hmm_vma_range_done(struct vm_area_struct *vma, struct hmm_range *range)
                return false;
        }
 
-       hmm = hmm_register(vma->vm_mm);
+       hmm = hmm_register(range->vma->vm_mm);
        if (!hmm) {
                memset(range->pfns, 0, sizeof(*range->pfns) * npages);
                return false;
@@ -611,36 +788,34 @@ EXPORT_SYMBOL(hmm_vma_range_done);
 
 /*
  * hmm_vma_fault() - try to fault some address in a virtual address range
- * @vma: virtual memory area containing the virtual address range
- * @range: use to track pfns array content validity
- * @start: fault range virtual start address (inclusive)
- * @end: fault range virtual end address (exclusive)
- * @pfns: array of hmm_pfn_t, only entry with fault flag set will be faulted
- * @write: is it a write fault
+ * @range: range being faulted
  * @block: allow blocking on fault (if true it sleeps and do not drop mmap_sem)
  * Returns: 0 success, error otherwise (-EAGAIN means mmap_sem have been drop)
  *
  * This is similar to a regular CPU page fault except that it will not trigger
  * any memory migration if the memory being faulted is not accessible by CPUs.
  *
- * On error, for one virtual address in the range, the function will set the
- * hmm_pfn_t error flag for the corresponding pfn entry.
+ * On error, for one virtual address in the range, the function will mark the
+ * corresponding HMM pfn entry with an error flag.
  *
  * Expected use pattern:
  * retry:
  *   down_read(&mm->mmap_sem);
  *   // Find vma and address device wants to fault, initialize hmm_pfn_t
  *   // array accordingly
- *   ret = hmm_vma_fault(vma, start, end, pfns, allow_retry);
+ *   ret = hmm_vma_fault(range, block);
  *   switch (ret) {
  *   case -EAGAIN:
- *     hmm_vma_range_done(vma, range);
+ *     hmm_vma_range_done(range);
  *     // You might want to rate limit or yield to play nicely, you may
  *     // also commit any valid pfn in the array assuming that you are
  *     // getting true from hmm_vma_range_monitor_end()
  *     goto retry;
  *   case 0:
  *     break;
+ *   case -ENOMEM:
+ *   case -EINVAL:
+ *   case -EPERM:
  *   default:
  *     // Handle error !
  *     up_read(&mm->mmap_sem)
@@ -648,7 +823,7 @@ EXPORT_SYMBOL(hmm_vma_range_done);
  *   }
  *   // Take device driver lock that serialize device page table update
  *   driver_lock_device_page_table_update();
- *   hmm_vma_range_done(vma, range);
+ *   hmm_vma_range_done(range);
  *   // Commit pfns we got from hmm_vma_fault()
  *   driver_unlock_device_page_table_update();
  *   up_read(&mm->mmap_sem)
@@ -658,51 +833,54 @@ EXPORT_SYMBOL(hmm_vma_range_done);
  *
  * YOU HAVE BEEN WARNED !
  */
-int hmm_vma_fault(struct vm_area_struct *vma,
-                 struct hmm_range *range,
-                 unsigned long start,
-                 unsigned long end,
-                 hmm_pfn_t *pfns,
-                 bool write,
-                 bool block)
+int hmm_vma_fault(struct hmm_range *range, bool block)
 {
+       struct vm_area_struct *vma = range->vma;
+       unsigned long start = range->start;
        struct hmm_vma_walk hmm_vma_walk;
        struct mm_walk mm_walk;
        struct hmm *hmm;
        int ret;
 
        /* Sanity check, this really should not happen ! */
-       if (start < vma->vm_start || start >= vma->vm_end)
+       if (range->start < vma->vm_start || range->start >= vma->vm_end)
                return -EINVAL;
-       if (end < vma->vm_start || end > vma->vm_end)
+       if (range->end < vma->vm_start || range->end > vma->vm_end)
                return -EINVAL;
 
        hmm = hmm_register(vma->vm_mm);
        if (!hmm) {
-               hmm_pfns_clear(pfns, start, end);
+               hmm_pfns_clear(range, range->pfns, range->start, range->end);
                return -ENOMEM;
        }
        /* Caller must have registered a mirror using hmm_mirror_register() */
        if (!hmm->mmu_notifier.ops)
                return -EINVAL;
 
+       /* FIXME support hugetlb fs */
+       if (is_vm_hugetlb_page(vma) || (vma->vm_flags & VM_SPECIAL)) {
+               hmm_pfns_special(range);
+               return -EINVAL;
+       }
+
+       if (!(vma->vm_flags & VM_READ)) {
+               /*
+                * If the vma does not allow read access, then assume that it
+                * does not allow write access either. Architectures that allow
+                * write without read access are not supported by HMM, because
+                * operations such as atomic access would not work.
+                */
+               hmm_pfns_clear(range, range->pfns, range->start, range->end);
+               return -EPERM;
+       }
+
        /* Initialize range to track CPU page table update */
-       range->start = start;
-       range->pfns = pfns;
-       range->end = end;
        spin_lock(&hmm->lock);
        range->valid = true;
        list_add_rcu(&range->list, &hmm->ranges);
        spin_unlock(&hmm->lock);
 
-       /* FIXME support hugetlb fs */
-       if (is_vm_hugetlb_page(vma) || (vma->vm_flags & VM_SPECIAL)) {
-               hmm_pfns_special(pfns, start, end);
-               return 0;
-       }
-
        hmm_vma_walk.fault = true;
-       hmm_vma_walk.write = write;
        hmm_vma_walk.block = block;
        hmm_vma_walk.range = range;
        mm_walk.private = &hmm_vma_walk;
@@ -717,7 +895,7 @@ int hmm_vma_fault(struct vm_area_struct *vma,
        mm_walk.pte_hole = hmm_vma_walk_hole;
 
        do {
-               ret = walk_page_range(start, end, &mm_walk);
+               ret = walk_page_range(start, range->end, &mm_walk);
                start = hmm_vma_walk.last;
        } while (ret == -EAGAIN);
 
@@ -725,8 +903,9 @@ int hmm_vma_fault(struct vm_area_struct *vma,
                unsigned long i;
 
                i = (hmm_vma_walk.last - range->start) >> PAGE_SHIFT;
-               hmm_pfns_clear(&pfns[i], hmm_vma_walk.last, end);
-               hmm_vma_range_done(vma, range);
+               hmm_pfns_clear(range, &range->pfns[i], hmm_vma_walk.last,
+                              range->end);
+               hmm_vma_range_done(range);
        }
        return ret;
 }
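
Both entry points now take all their parameters through struct hmm_range, and write faults are requested per page by preloading the pfns array with the driver's own flag bits instead of passing a global write argument. The sketch below shows the fault path under those assumptions; it relies on the struct hmm_range fields added by this series (vma, start, end, pfns, flags, values, pfn_shift), reuses the hypothetical mydrv_* tables from the earlier sketch, and the device page table lock/commit helpers are placeholders.

/*
 * Illustration only: fault npages starting at addr and mirror the result.
 * Caller holds down_read(&vma->vm_mm->mmap_sem), as hmm_vma_fault() requires.
 */
static int mydrv_fault_range(struct vm_area_struct *vma, uint64_t *pfns,
			     unsigned long addr, unsigned long npages)
{
	struct hmm_range range = {
		.vma       = vma,
		.start     = addr,
		.end       = addr + (npages << PAGE_SHIFT),
		.pfns      = pfns,
		.flags     = mydrv_pfn_flags,
		.values    = mydrv_pfn_values,
		.pfn_shift = 12,	/* low bits are reserved for flags */
	};
	unsigned long i;
	int ret;

	/* Ask for a valid, writable CPU mapping for every page. */
	for (i = 0; i < npages; i++)
		pfns[i] = range.flags[HMM_PFN_VALID] |
			  range.flags[HMM_PFN_WRITE];

	ret = hmm_vma_fault(&range, true);
	if (ret)
		return ret;	/* on error the range is already untracked */

	mydrv_device_page_table_lock();		/* hypothetical driver lock */
	if (!hmm_vma_range_done(&range)) {
		/* CPU page table changed under us: drop the work and retry. */
		mydrv_device_page_table_unlock();
		return -EAGAIN;
	}
	/* ... commit range.pfns to the device page table here ... */
	mydrv_device_page_table_unlock();
	return 0;
}
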
@@ -845,13 +1024,6 @@ static void hmm_devmem_release(struct device *dev, void *data)
        hmm_devmem_radix_release(resource);
 }
 
-static struct hmm_devmem *hmm_devmem_find(resource_size_t phys)
-{
-       WARN_ON_ONCE(!rcu_read_lock_held());
-
-       return radix_tree_lookup(&hmm_devmem_radix, phys >> PA_SECTION_SHIFT);
-}
-
 static int hmm_devmem_pages_create(struct hmm_devmem *devmem)
 {
        resource_size_t key, align_start, align_size, align_end;
@@ -892,9 +1064,8 @@ static int hmm_devmem_pages_create(struct hmm_devmem *devmem)
        for (key = align_start; key <= align_end; key += PA_SECTION_SIZE) {
                struct hmm_devmem *dup;
 
-               rcu_read_lock();
-               dup = hmm_devmem_find(key);
-               rcu_read_unlock();
+               dup = radix_tree_lookup(&hmm_devmem_radix,
+                                       key >> PA_SECTION_SHIFT);
                if (dup) {
                        dev_err(device, "%s: collides with mapping for %s\n",
                                __func__, dev_name(dup->device));