Merge branch 'for-4.18/dax' into libnvdimm-for-next
[sfrench/cifs-2.6.git] / mm / gup.c
index 541904a7c60fd596b1f5bc432fc7932bc005fadf..3d8472d48a0b88366e6e60e3c17fee1fe0dae7d3 100644 (file)
--- a/mm/gup.c
+++ b/mm/gup.c
@@ -1459,32 +1459,48 @@ static int __gup_device_huge(unsigned long pfn, unsigned long addr,
        return 1;
 }
 
-static int __gup_device_huge_pmd(pmd_t pmd, unsigned long addr,
+static int __gup_device_huge_pmd(pmd_t orig, pmd_t *pmdp, unsigned long addr,
                unsigned long end, struct page **pages, int *nr)
 {
        unsigned long fault_pfn;
+       int nr_start = *nr;
+
+       fault_pfn = pmd_pfn(orig) + ((addr & ~PMD_MASK) >> PAGE_SHIFT);
+       if (!__gup_device_huge(fault_pfn, addr, end, pages, nr))
+               return 0;
 
-       fault_pfn = pmd_pfn(pmd) + ((addr & ~PMD_MASK) >> PAGE_SHIFT);
-       return __gup_device_huge(fault_pfn, addr, end, pages, nr);
+       if (unlikely(pmd_val(orig) != pmd_val(*pmdp))) {
+               undo_dev_pagemap(nr, nr_start, pages);
+               return 0;
+       }
+       return 1;
 }
 
-static int __gup_device_huge_pud(pud_t pud, unsigned long addr,
+static int __gup_device_huge_pud(pud_t orig, pud_t *pudp, unsigned long addr,
                unsigned long end, struct page **pages, int *nr)
 {
        unsigned long fault_pfn;
+       int nr_start = *nr;
+
+       fault_pfn = pud_pfn(orig) + ((addr & ~PUD_MASK) >> PAGE_SHIFT);
+       if (!__gup_device_huge(fault_pfn, addr, end, pages, nr))
+               return 0;
 
-       fault_pfn = pud_pfn(pud) + ((addr & ~PUD_MASK) >> PAGE_SHIFT);
-       return __gup_device_huge(fault_pfn, addr, end, pages, nr);
+       if (unlikely(pud_val(orig) != pud_val(*pudp))) {
+               undo_dev_pagemap(nr, nr_start, pages);
+               return 0;
+       }
+       return 1;
 }
 #else
-static int __gup_device_huge_pmd(pmd_t pmd, unsigned long addr,
+static int __gup_device_huge_pmd(pmd_t orig, pmd_t *pmdp, unsigned long addr,
                unsigned long end, struct page **pages, int *nr)
 {
        BUILD_BUG();
        return 0;
 }
 
-static int __gup_device_huge_pud(pud_t pud, unsigned long addr,
+static int __gup_device_huge_pud(pud_t pud, pud_t *pudp, unsigned long addr,
                unsigned long end, struct page **pages, int *nr)
 {
        BUILD_BUG();
@@ -1502,7 +1518,7 @@ static int gup_huge_pmd(pmd_t orig, pmd_t *pmdp, unsigned long addr,
                return 0;
 
        if (pmd_devmap(orig))
-               return __gup_device_huge_pmd(orig, addr, end, pages, nr);
+               return __gup_device_huge_pmd(orig, pmdp, addr, end, pages, nr);
 
        refs = 0;
        page = pmd_page(orig) + ((addr & ~PMD_MASK) >> PAGE_SHIFT);
@@ -1540,7 +1556,7 @@ static int gup_huge_pud(pud_t orig, pud_t *pudp, unsigned long addr,
                return 0;
 
        if (pud_devmap(orig))
-               return __gup_device_huge_pud(orig, addr, end, pages, nr);
+               return __gup_device_huge_pud(orig, pudp, addr, end, pages, nr);
 
        refs = 0;
        page = pud_page(orig) + ((addr & ~PUD_MASK) >> PAGE_SHIFT);