index 0fb270f0a0ef68f3264f9c21d108a98897a3cc0d..616e36ea6aaab6baf9fd3210aaa425c2398e491b 100644
--- a/fs/dax.c
+++ b/fs/dax.c
 #define CREATE_TRACE_POINTS
 #include <trace/events/fs_dax.h>
 
+static inline unsigned int pe_order(enum page_entry_size pe_size)
+{
+       if (pe_size == PE_SIZE_PTE)
+               return PAGE_SHIFT - PAGE_SHIFT;
+       if (pe_size == PE_SIZE_PMD)
+               return PMD_SHIFT - PAGE_SHIFT;
+       if (pe_size == PE_SIZE_PUD)
+               return PUD_SHIFT - PAGE_SHIFT;
+       return ~0;
+}
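
Editor's aside (not part of the patch): the order returned by pe_order() is exactly the shift that turns PAGE_SIZE back into the mapping size, which is how dax_finish_sync_fault() at the bottom of this file computes its fsync length. A hypothetical helper to make that explicit:

/*
 * Hypothetical helper, for illustration only: PAGE_SIZE << pe_order()
 * recovers PAGE_SIZE, PMD_SIZE or PUD_SIZE for the three valid sizes.
 */
static inline size_t pe_len(enum page_entry_size pe_size)
{
	return PAGE_SIZE << pe_order(pe_size);
}
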
+
 /* We choose 4096 entries - same as per-zone page wait tables */
 #define DAX_WAIT_TABLE_BITS 12
 #define DAX_WAIT_TABLE_ENTRIES (1 << DAX_WAIT_TABLE_BITS)
@@ -46,6 +57,9 @@
 #define PG_PMD_COLOUR  ((PMD_SIZE >> PAGE_SHIFT) - 1)
 #define PG_PMD_NR      (PMD_SIZE >> PAGE_SHIFT)
 
+/* The order of a PMD entry */
+#define PMD_ORDER      (PMD_SHIFT - PAGE_SHIFT)
+
 static wait_queue_head_t wait_table[DAX_WAIT_TABLE_ENTRIES];
 
 static int __init init_dax_wait_table(void)
@@ -59,63 +73,74 @@ static int __init init_dax_wait_table(void)
 fs_initcall(init_dax_wait_table);
 
 /*
- * We use lowest available bit in exceptional entry for locking, one bit for
- * the entry size (PMD) and two more to tell us if the entry is a zero page or
- * an empty entry that is just used for locking.  In total four special bits.
+ * DAX pagecache entries use XArray value entries so they can't be mistaken
+ * for pages.  We use one bit for locking, one bit for the entry size (PMD)
+ * and two more to tell us if the entry is a zero page or an empty entry that
+ * is just used for locking.  In total four special bits.
  *
  * If the PMD bit isn't set the entry has size PAGE_SIZE, and if the ZERO_PAGE
  * and EMPTY bits aren't set the entry is a normal DAX entry with a filesystem
  * block allocation.
  */
-#define RADIX_DAX_SHIFT                (RADIX_TREE_EXCEPTIONAL_SHIFT + 4)
-#define RADIX_DAX_ENTRY_LOCK   (1 << RADIX_TREE_EXCEPTIONAL_SHIFT)
-#define RADIX_DAX_PMD          (1 << (RADIX_TREE_EXCEPTIONAL_SHIFT + 1))
-#define RADIX_DAX_ZERO_PAGE    (1 << (RADIX_TREE_EXCEPTIONAL_SHIFT + 2))
-#define RADIX_DAX_EMPTY                (1 << (RADIX_TREE_EXCEPTIONAL_SHIFT + 3))
+#define DAX_SHIFT      (4)
+#define DAX_LOCKED     (1UL << 0)
+#define DAX_PMD                (1UL << 1)
+#define DAX_ZERO_PAGE  (1UL << 2)
+#define DAX_EMPTY      (1UL << 3)
 
-static unsigned long dax_radix_pfn(void *entry)
+static unsigned long dax_to_pfn(void *entry)
 {
-       return (unsigned long)entry >> RADIX_DAX_SHIFT;
+       return xa_to_value(entry) >> DAX_SHIFT;
 }
 
-static void *dax_radix_locked_entry(unsigned long pfn, unsigned long flags)
+static void *dax_make_entry(pfn_t pfn, unsigned long flags)
 {
-       return (void *)(RADIX_TREE_EXCEPTIONAL_ENTRY | flags |
-                       (pfn << RADIX_DAX_SHIFT) | RADIX_DAX_ENTRY_LOCK);
+       return xa_mk_value(flags | (pfn_t_to_pfn(pfn) << DAX_SHIFT));
 }
 
-static unsigned int dax_radix_order(void *entry)
+static void *dax_make_page_entry(struct page *page)
 {
-       if ((unsigned long)entry & RADIX_DAX_PMD)
-               return PMD_SHIFT - PAGE_SHIFT;
+       pfn_t pfn = page_to_pfn_t(page);
+       return dax_make_entry(pfn, PageHead(page) ? DAX_PMD : 0);
+}
+
+static bool dax_is_locked(void *entry)
+{
+       return xa_to_value(entry) & DAX_LOCKED;
+}
+
+static unsigned int dax_entry_order(void *entry)
+{
+       if (xa_to_value(entry) & DAX_PMD)
+               return PMD_ORDER;
        return 0;
 }
 
 static int dax_is_pmd_entry(void *entry)
 {
-       return (unsigned long)entry & RADIX_DAX_PMD;
+       return xa_to_value(entry) & DAX_PMD;
 }
 
 static int dax_is_pte_entry(void *entry)
 {
-       return !((unsigned long)entry & RADIX_DAX_PMD);
+       return !(xa_to_value(entry) & DAX_PMD);
 }
 
 static int dax_is_zero_entry(void *entry)
 {
-       return (unsigned long)entry & RADIX_DAX_ZERO_PAGE;
+       return xa_to_value(entry) & DAX_ZERO_PAGE;
 }
 
 static int dax_is_empty_entry(void *entry)
 {
-       return (unsigned long)entry & RADIX_DAX_EMPTY;
+       return xa_to_value(entry) & DAX_EMPTY;
 }
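
Editor's sketch (not in the patch): a worked round-trip through the encoding described above, using only the helpers just defined, to show how the pfn and the four flag bits coexist in one XArray value entry.

/*
 * Illustration only: encode a PMD-sized entry for an arbitrary pfn and
 * decode it again with the accessors above.
 */
static void __maybe_unused dax_entry_encoding_demo(void)
{
	void *entry = dax_make_entry(pfn_to_pfn_t(0x1234), DAX_PMD);

	WARN_ON(!xa_is_value(entry));		/* value entry, never mistaken for a page */
	WARN_ON(dax_to_pfn(entry) != 0x1234);	/* pfn lives above the four flag bits */
	WARN_ON(!dax_is_pmd_entry(entry));
	WARN_ON(dax_is_locked(entry));		/* DAX_LOCKED is only set via dax_lock_entry() */
}
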
 
 /*
- * DAX radix tree locking
+ * DAX page cache entry locking
  */
 struct exceptional_entry_key {
-       struct address_space *mapping;
+       struct xarray *xa;
        pgoff_t entry_start;
 };
 
@@ -124,10 +149,11 @@ struct wait_exceptional_entry_queue {
        struct exceptional_entry_key key;
 };
 
-static wait_queue_head_t *dax_entry_waitqueue(struct address_space *mapping,
-               pgoff_t index, void *entry, struct exceptional_entry_key *key)
+static wait_queue_head_t *dax_entry_waitqueue(struct xa_state *xas,
+               void *entry, struct exceptional_entry_key *key)
 {
        unsigned long hash;
+       unsigned long index = xas->xa_index;
 
        /*
         * If 'entry' is a PMD, align the 'index' that we use for the wait
@@ -136,22 +162,21 @@ static wait_queue_head_t *dax_entry_waitqueue(struct address_space *mapping,
         */
        if (dax_is_pmd_entry(entry))
                index &= ~PG_PMD_COLOUR;
-
-       key->mapping = mapping;
+       key->xa = xas->xa;
        key->entry_start = index;
 
-       hash = hash_long((unsigned long)mapping ^ index, DAX_WAIT_TABLE_BITS);
+       hash = hash_long((unsigned long)xas->xa ^ index, DAX_WAIT_TABLE_BITS);
        return wait_table + hash;
 }
 
-static int wake_exceptional_entry_func(wait_queue_entry_t *wait, unsigned int mode,
-                                      int sync, void *keyp)
+static int wake_exceptional_entry_func(wait_queue_entry_t *wait,
+               unsigned int mode, int sync, void *keyp)
 {
        struct exceptional_entry_key *key = keyp;
        struct wait_exceptional_entry_queue *ewait =
                container_of(wait, struct wait_exceptional_entry_queue, wait);
 
-       if (key->mapping != ewait->key.mapping ||
+       if (key->xa != ewait->key.xa ||
            key->entry_start != ewait->key.entry_start)
                return 0;
        return autoremove_wake_function(wait, mode, sync, NULL);
@@ -162,13 +187,12 @@ static int wake_exceptional_entry_func(wait_queue_entry_t *wait, unsigned int mo
  * The important information it's conveying is whether the entry at
  * this index used to be a PMD entry.
  */
-static void dax_wake_mapping_entry_waiter(struct address_space *mapping,
-               pgoff_t index, void *entry, bool wake_all)
+static void dax_wake_entry(struct xa_state *xas, void *entry, bool wake_all)
 {
        struct exceptional_entry_key key;
        wait_queue_head_t *wq;
 
-       wq = dax_entry_waitqueue(mapping, index, entry, &key);
+       wq = dax_entry_waitqueue(xas, entry, &key);
 
        /*
         * Checking for locked entry and prepare_to_wait_exclusive() happens
@@ -181,55 +205,16 @@ static void dax_wake_mapping_entry_waiter(struct address_space *mapping,
 }
 
 /*
- * Check whether the given slot is locked.  Must be called with the i_pages
- * lock held.
- */
-static inline int slot_locked(struct address_space *mapping, void **slot)
-{
-       unsigned long entry = (unsigned long)
-               radix_tree_deref_slot_protected(slot, &mapping->i_pages.xa_lock);
-       return entry & RADIX_DAX_ENTRY_LOCK;
-}
-
-/*
- * Mark the given slot as locked.  Must be called with the i_pages lock held.
- */
-static inline void *lock_slot(struct address_space *mapping, void **slot)
-{
-       unsigned long entry = (unsigned long)
-               radix_tree_deref_slot_protected(slot, &mapping->i_pages.xa_lock);
-
-       entry |= RADIX_DAX_ENTRY_LOCK;
-       radix_tree_replace_slot(&mapping->i_pages, slot, (void *)entry);
-       return (void *)entry;
-}
-
-/*
- * Mark the given slot as unlocked.  Must be called with the i_pages lock held.
- */
-static inline void *unlock_slot(struct address_space *mapping, void **slot)
-{
-       unsigned long entry = (unsigned long)
-               radix_tree_deref_slot_protected(slot, &mapping->i_pages.xa_lock);
-
-       entry &= ~(unsigned long)RADIX_DAX_ENTRY_LOCK;
-       radix_tree_replace_slot(&mapping->i_pages, slot, (void *)entry);
-       return (void *)entry;
-}
-
-/*
- * Lookup entry in radix tree, wait for it to become unlocked if it is
- * exceptional entry and return it. The caller must call
- * put_unlocked_mapping_entry() when he decided not to lock the entry or
- * put_locked_mapping_entry() when he locked the entry and now wants to
- * unlock it.
+ * Look up entry in page cache, wait for it to become unlocked if it
+ * is a DAX entry and return it.  The caller must subsequently call
+ * put_unlocked_entry() if it did not lock the entry or dax_unlock_entry()
+ * if it did.
  *
  * Must be called with the i_pages lock held.
  */
-static void *__get_unlocked_mapping_entry(struct address_space *mapping,
-               pgoff_t index, void ***slotp, bool (*wait_fn)(void))
+static void *get_unlocked_entry(struct xa_state *xas)
 {
-       void *entry, **slot;
+       void *entry;
        struct wait_exceptional_entry_queue ewait;
        wait_queue_head_t *wq;
 
@@ -237,80 +222,54 @@ static void *__get_unlocked_mapping_entry(struct address_space *mapping,
        ewait.wait.func = wake_exceptional_entry_func;
 
        for (;;) {
-               bool revalidate;
-
-               entry = __radix_tree_lookup(&mapping->i_pages, index, NULL,
-                                         &slot);
-               if (!entry ||
-                   WARN_ON_ONCE(!radix_tree_exceptional_entry(entry)) ||
-                   !slot_locked(mapping, slot)) {
-                       if (slotp)
-                               *slotp = slot;
+               entry = xas_load(xas);
+               if (!entry || xa_is_internal(entry) ||
+                               WARN_ON_ONCE(!xa_is_value(entry)) ||
+                               !dax_is_locked(entry))
                        return entry;
-               }
 
-               wq = dax_entry_waitqueue(mapping, index, entry, &ewait.key);
+               wq = dax_entry_waitqueue(xas, entry, &ewait.key);
                prepare_to_wait_exclusive(wq, &ewait.wait,
                                          TASK_UNINTERRUPTIBLE);
-               xa_unlock_irq(&mapping->i_pages);
-               revalidate = wait_fn();
+               xas_unlock_irq(xas);
+               xas_reset(xas);
+               schedule();
                finish_wait(wq, &ewait.wait);
-               xa_lock_irq(&mapping->i_pages);
-               if (revalidate)
-                       return ERR_PTR(-EAGAIN);
+               xas_lock_irq(xas);
        }
 }
 
-static bool entry_wait(void)
-{
-       schedule();
-       /*
-        * Never return an ERR_PTR() from
-        * __get_unlocked_mapping_entry(), just keep looping.
-        */
-       return false;
-}
-
-static void *get_unlocked_mapping_entry(struct address_space *mapping,
-               pgoff_t index, void ***slotp)
+static void put_unlocked_entry(struct xa_state *xas, void *entry)
 {
-       return __get_unlocked_mapping_entry(mapping, index, slotp, entry_wait);
-}
-
-static void unlock_mapping_entry(struct address_space *mapping, pgoff_t index)
-{
-       void *entry, **slot;
-
-       xa_lock_irq(&mapping->i_pages);
-       entry = __radix_tree_lookup(&mapping->i_pages, index, NULL, &slot);
-       if (WARN_ON_ONCE(!entry || !radix_tree_exceptional_entry(entry) ||
-                        !slot_locked(mapping, slot))) {
-               xa_unlock_irq(&mapping->i_pages);
-               return;
-       }
-       unlock_slot(mapping, slot);
-       xa_unlock_irq(&mapping->i_pages);
-       dax_wake_mapping_entry_waiter(mapping, index, entry, false);
+       /* If we were the only waiter woken, wake the next one */
+       if (entry)
+               dax_wake_entry(xas, entry, false);
 }
 
-static void put_locked_mapping_entry(struct address_space *mapping,
-               pgoff_t index)
+/*
+ * We used the xa_state to get the entry, but then we locked the entry and
+ * dropped the xa_lock, so we know the xa_state is stale and must be reset
+ * before use.
+ */
+static void dax_unlock_entry(struct xa_state *xas, void *entry)
 {
-       unlock_mapping_entry(mapping, index);
+       void *old;
+
+       xas_reset(xas);
+       xas_lock_irq(xas);
+       old = xas_store(xas, entry);
+       xas_unlock_irq(xas);
+       BUG_ON(!dax_is_locked(old));
+       dax_wake_entry(xas, entry, false);
 }
 
 /*
- * Called when we are done with radix tree entry we looked up via
- * get_unlocked_mapping_entry() and which we didn't lock in the end.
+ * Return: The entry stored at this location before it was locked.
  */
-static void put_unlocked_mapping_entry(struct address_space *mapping,
-                                      pgoff_t index, void *entry)
+static void *dax_lock_entry(struct xa_state *xas, void *entry)
 {
-       if (!entry)
-               return;
-
-       /* We have to wake up next waiter for the radix tree entry lock */
-       dax_wake_mapping_entry_waiter(mapping, index, entry, false);
+       unsigned long v = xa_to_value(entry);
+       return xas_store(xas, xa_mk_value(v | DAX_LOCKED));
 }
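
Editor's sketch of the caller protocol described in the comment above get_unlocked_entry(), modelled on __dax_invalidate_entry() further down in this file; not a new API.

/*
 * Illustration only: look an entry up, wait for DAX_LOCKED to clear,
 * then either back off (put_unlocked_entry) or take the lock ourselves
 * (dax_lock_entry) and release it later (dax_unlock_entry).
 */
static void __maybe_unused dax_entry_lock_demo(struct address_space *mapping,
		pgoff_t index)
{
	XA_STATE(xas, &mapping->i_pages, index);
	void *entry;

	xas_lock_irq(&xas);
	entry = get_unlocked_entry(&xas);
	if (!entry || !xa_is_value(entry)) {
		put_unlocked_entry(&xas, entry);	/* decided not to lock it */
		xas_unlock_irq(&xas);
		return;
	}
	dax_lock_entry(&xas, entry);		/* sets DAX_LOCKED in the array */
	xas_unlock_irq(&xas);

	/* ... work on the entry with the xa_lock dropped ... */

	dax_unlock_entry(&xas, entry);		/* restores the unlocked value, wakes waiters */
}
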
 
 static unsigned long dax_entry_size(void *entry)
@@ -325,9 +284,9 @@ static unsigned long dax_entry_size(void *entry)
                return PAGE_SIZE;
 }
 
-static unsigned long dax_radix_end_pfn(void *entry)
+static unsigned long dax_end_pfn(void *entry)
 {
-       return dax_radix_pfn(entry) + dax_entry_size(entry) / PAGE_SIZE;
+       return dax_to_pfn(entry) + dax_entry_size(entry) / PAGE_SIZE;
 }
 
 /*
@@ -335,8 +294,8 @@ static unsigned long dax_radix_end_pfn(void *entry)
  * 'empty' and 'zero' entries.
  */
 #define for_each_mapped_pfn(entry, pfn) \
-       for (pfn = dax_radix_pfn(entry); \
-                       pfn < dax_radix_end_pfn(entry); pfn++)
+       for (pfn = dax_to_pfn(entry); \
+                       pfn < dax_end_pfn(entry); pfn++)
 
 /*
  * TODO: for reflink+dax we need a way to associate a single page with
@@ -393,33 +352,16 @@ static struct page *dax_busy_page(void *entry)
        return NULL;
 }
 
-static bool entry_wait_revalidate(void)
-{
-       rcu_read_unlock();
-       schedule();
-       rcu_read_lock();
-
-       /*
-        * Tell __get_unlocked_mapping_entry() to take a break, we need
-        * to revalidate page->mapping after dropping locks
-        */
-       return true;
-}
-
 bool dax_lock_mapping_entry(struct page *page)
 {
-       pgoff_t index;
-       struct inode *inode;
-       bool did_lock = false;
-       void *entry = NULL, **slot;
-       struct address_space *mapping;
+       XA_STATE(xas, NULL, 0);
+       void *entry;
 
-       rcu_read_lock();
        for (;;) {
-               mapping = READ_ONCE(page->mapping);
+               struct address_space *mapping = READ_ONCE(page->mapping);
 
                if (!dax_mapping(mapping))
-                       break;
+                       return false;
 
                /*
                 * In the device-dax case there's no need to lock, a
@@ -428,98 +370,94 @@ bool dax_lock_mapping_entry(struct page *page)
                 * otherwise we would not have a valid pfn_to_page()
                 * translation.
                 */
-               inode = mapping->host;
-               if (S_ISCHR(inode->i_mode)) {
-                       did_lock = true;
-                       break;
-               }
+               if (S_ISCHR(mapping->host->i_mode))
+                       return true;
 
-               xa_lock_irq(&mapping->i_pages);
+               xas.xa = &mapping->i_pages;
+               xas_lock_irq(&xas);
                if (mapping != page->mapping) {
-                       xa_unlock_irq(&mapping->i_pages);
+                       xas_unlock_irq(&xas);
                        continue;
                }
-               index = page->index;
-
-               entry = __get_unlocked_mapping_entry(mapping, index, &slot,
-                               entry_wait_revalidate);
-               if (!entry) {
-                       xa_unlock_irq(&mapping->i_pages);
-                       break;
-               } else if (IS_ERR(entry)) {
-                       xa_unlock_irq(&mapping->i_pages);
-                       WARN_ON_ONCE(PTR_ERR(entry) != -EAGAIN);
-                       continue;
+               xas_set(&xas, page->index);
+               entry = xas_load(&xas);
+               if (dax_is_locked(entry)) {
+                       entry = get_unlocked_entry(&xas);
+                       /* Did the page move while we slept? */
+                       if (dax_to_pfn(entry) != page_to_pfn(page)) {
+                               xas_unlock_irq(&xas);
+                               continue;
+                       }
                }
-               lock_slot(mapping, slot);
-               did_lock = true;
-               xa_unlock_irq(&mapping->i_pages);
-               break;
+               dax_lock_entry(&xas, entry);
+               xas_unlock_irq(&xas);
+               return true;
        }
-       rcu_read_unlock();
-
-       return did_lock;
 }
 
 void dax_unlock_mapping_entry(struct page *page)
 {
        struct address_space *mapping = page->mapping;
-       struct inode *inode = mapping->host;
+       XA_STATE(xas, &mapping->i_pages, page->index);
 
-       if (S_ISCHR(inode->i_mode))
+       if (S_ISCHR(mapping->host->i_mode))
                return;
 
-       unlock_mapping_entry(mapping, page->index);
+       dax_unlock_entry(&xas, dax_make_page_entry(page));
 }
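
Editor's sketch of the expected consumer of this pair: the memory-failure path pins the page cache entry around poison handling. The caller shape is an assumption based on mm/memory-failure.c of this era, not part of this patch.

/*
 * Illustration only (assumed caller shape): hold the DAX entry lock so
 * truncate cannot tear down page->mapping while the poisoned page is
 * being processed.
 */
static void __maybe_unused dax_memory_failure_demo(struct page *page)
{
	if (!dax_lock_mapping_entry(page))
		return;		/* no longer a DAX mapping, nothing to pin */

	/* ... unmap the page and signal the processes using it ... */

	dax_unlock_mapping_entry(page);
}
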
 
 /*
- * Find radix tree entry at given index. If it points to an exceptional entry,
- * return it with the radix tree entry locked. If the radix tree doesn't
- * contain given index, create an empty exceptional entry for the index and
- * return with it locked.
+ * Find page cache entry at given index. If it is a DAX entry, return it
+ * with the entry locked. If the page cache doesn't contain an entry at
+ * that index, add a locked empty entry.
  *
- * When requesting an entry with size RADIX_DAX_PMD, grab_mapping_entry() will
- * either return that locked entry or will return an error.  This error will
- * happen if there are any 4k entries within the 2MiB range that we are
- * requesting.
+ * When requesting an entry with size DAX_PMD, grab_mapping_entry() will
+ * either return that locked entry or will return VM_FAULT_FALLBACK.
+ * This will happen if there are any PTE entries within the PMD range
+ * that we are requesting.
  *
- * We always favor 4k entries over 2MiB entries. There isn't a flow where we
- * evict 4k entries in order to 'upgrade' them to a 2MiB entry.  A 2MiB
- * insertion will fail if it finds any 4k entries already in the tree, and a
- * 4k insertion will cause an existing 2MiB entry to be unmapped and
- * downgraded to 4k entries.  This happens for both 2MiB huge zero pages as
- * well as 2MiB empty entries.
+ * We always favor PTE entries over PMD entries. There isn't a flow where we
+ * evict PTE entries in order to 'upgrade' them to a PMD entry.  A PMD
+ * insertion will fail if it finds any PTE entries already in the tree, and a
+ * PTE insertion will cause an existing PMD entry to be unmapped and
+ * downgraded to PTE entries.  This happens for both PMD zero pages as
+ * well as PMD empty entries.
  *
- * The exception to this downgrade path is for 2MiB DAX PMD entries that have
- * real storage backing them.  We will leave these real 2MiB DAX entries in
- * the tree, and PTE writes will simply dirty the entire 2MiB DAX entry.
+ * The exception to this downgrade path is for PMD entries that have
+ * real storage backing them.  We will leave these real PMD entries in
+ * the tree, and PTE writes will simply dirty the entire PMD entry.
  *
  * Note: Unlike filemap_fault() we don't honor FAULT_FLAG_RETRY flags. For
  * persistent memory the benefit is doubtful. We can add that later if we can
  * show it helps.
+ *
+ * On error, this function does not return an ERR_PTR.  Instead it returns
+ * a VM_FAULT code, encoded as an xarray internal entry.  The ERR_PTR values
+ * overlap with xarray value entries.
  */
-static void *grab_mapping_entry(struct address_space *mapping, pgoff_t index,
-               unsigned long size_flag)
+static void *grab_mapping_entry(struct xa_state *xas,
+               struct address_space *mapping, unsigned long size_flag)
 {
-       bool pmd_downgrade = false; /* splitting 2MiB entry into 4k entries? */
-       void *entry, **slot;
-
-restart:
-       xa_lock_irq(&mapping->i_pages);
-       entry = get_unlocked_mapping_entry(mapping, index, &slot);
+       unsigned long index = xas->xa_index;
+       bool pmd_downgrade = false; /* splitting PMD entry into PTE entries? */
+       void *entry;
 
-       if (WARN_ON_ONCE(entry && !radix_tree_exceptional_entry(entry))) {
-               entry = ERR_PTR(-EIO);
-               goto out_unlock;
-       }
+retry:
+       xas_lock_irq(xas);
+       entry = get_unlocked_entry(xas);
+       if (xa_is_internal(entry))
+               goto fallback;
 
        if (entry) {
-               if (size_flag & RADIX_DAX_PMD) {
+               if (WARN_ON_ONCE(!xa_is_value(entry))) {
+                       xas_set_err(xas, EIO);
+                       goto out_unlock;
+               }
+
+               if (size_flag & DAX_PMD) {
                        if (dax_is_pte_entry(entry)) {
-                               put_unlocked_mapping_entry(mapping, index,
-                                               entry);
-                               entry = ERR_PTR(-EEXIST);
-                               goto out_unlock;
+                               put_unlocked_entry(xas, entry);
+                               goto fallback;
                        }
                } else { /* trying to grab a PTE entry */
                        if (dax_is_pmd_entry(entry) &&
@@ -530,87 +468,57 @@ restart:
                }
        }
 
-       /* No entry for given index? Make sure radix tree is big enough. */
-       if (!entry || pmd_downgrade) {
-               int err;
-
-               if (pmd_downgrade) {
-                       /*
-                        * Make sure 'entry' remains valid while we drop
-                        * the i_pages lock.
-                        */
-                       entry = lock_slot(mapping, slot);
-               }
+       if (pmd_downgrade) {
+               /*
+                * Make sure 'entry' remains valid while we drop
+                * the i_pages lock.
+                */
+               dax_lock_entry(xas, entry);
 
-               xa_unlock_irq(&mapping->i_pages);
                /*
                 * Besides huge zero pages the only other thing that gets
                 * downgraded are empty entries which don't need to be
                 * unmapped.
                 */
-               if (pmd_downgrade && dax_is_zero_entry(entry))
-                       unmap_mapping_pages(mapping, index & ~PG_PMD_COLOUR,
-                                                       PG_PMD_NR, false);
-
-               err = radix_tree_preload(
-                               mapping_gfp_mask(mapping) & ~__GFP_HIGHMEM);
-               if (err) {
-                       if (pmd_downgrade)
-                               put_locked_mapping_entry(mapping, index);
-                       return ERR_PTR(err);
-               }
-               xa_lock_irq(&mapping->i_pages);
-
-               if (!entry) {
-                       /*
-                        * We needed to drop the i_pages lock while calling
-                        * radix_tree_preload() and we didn't have an entry to
-                        * lock.  See if another thread inserted an entry at
-                        * our index during this time.
-                        */
-                       entry = __radix_tree_lookup(&mapping->i_pages, index,
-                                       NULL, &slot);
-                       if (entry) {
-                               radix_tree_preload_end();
-                               xa_unlock_irq(&mapping->i_pages);
-                               goto restart;
-                       }
+               if (dax_is_zero_entry(entry)) {
+                       xas_unlock_irq(xas);
+                       unmap_mapping_pages(mapping,
+                                       xas->xa_index & ~PG_PMD_COLOUR,
+                                       PG_PMD_NR, false);
+                       xas_reset(xas);
+                       xas_lock_irq(xas);
                }
 
-               if (pmd_downgrade) {
-                       dax_disassociate_entry(entry, mapping, false);
-                       radix_tree_delete(&mapping->i_pages, index);
-                       mapping->nrexceptional--;
-                       dax_wake_mapping_entry_waiter(mapping, index, entry,
-                                       true);
-               }
+               dax_disassociate_entry(entry, mapping, false);
+               xas_store(xas, NULL);   /* undo the PMD join */
+               dax_wake_entry(xas, entry, true);
+               mapping->nrexceptional--;
+               entry = NULL;
+               xas_set(xas, index);
+       }
 
-               entry = dax_radix_locked_entry(0, size_flag | RADIX_DAX_EMPTY);
-
-               err = __radix_tree_insert(&mapping->i_pages, index,
-                               dax_radix_order(entry), entry);
-               radix_tree_preload_end();
-               if (err) {
-                       xa_unlock_irq(&mapping->i_pages);
-                       /*
-                        * Our insertion of a DAX entry failed, most likely
-                        * because we were inserting a PMD entry and it
-                        * collided with a PTE sized entry at a different
-                        * index in the PMD range.  We haven't inserted
-                        * anything into the radix tree and have no waiters to
-                        * wake.
-                        */
-                       return ERR_PTR(err);
-               }
-               /* Good, we have inserted empty locked entry into the tree. */
+       if (entry) {
+               dax_lock_entry(xas, entry);
+       } else {
+               entry = dax_make_entry(pfn_to_pfn_t(0), size_flag | DAX_EMPTY);
+               dax_lock_entry(xas, entry);
+               if (xas_error(xas))
+                       goto out_unlock;
                mapping->nrexceptional++;
-               xa_unlock_irq(&mapping->i_pages);
-               return entry;
        }
-       entry = lock_slot(mapping, slot);
- out_unlock:
-       xa_unlock_irq(&mapping->i_pages);
+
+out_unlock:
+       xas_unlock_irq(xas);
+       if (xas_nomem(xas, mapping_gfp_mask(mapping) & ~__GFP_HIGHMEM))
+               goto retry;
+       if (xas->xa_node == XA_ERROR(-ENOMEM))
+               return xa_mk_internal(VM_FAULT_OOM);
+       if (xas_error(xas))
+               return xa_mk_internal(VM_FAULT_SIGBUS);
        return entry;
+fallback:
+       xas_unlock_irq(xas);
+       return xa_mk_internal(VM_FAULT_FALLBACK);
 }
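
Editor's sketch of the caller-side decode for the xarray-internal error convention documented above grab_mapping_entry(); it mirrors what dax_iomap_pte_fault() and dax_iomap_pmd_fault() do later in this diff.

/*
 * Illustration only: an internal entry carries the VM_FAULT code
 * directly, anything else is a locked DAX entry the caller now owns.
 */
static vm_fault_t __maybe_unused dax_grab_demo(struct xa_state *xas,
		struct address_space *mapping)
{
	void *entry = grab_mapping_entry(xas, mapping, 0);

	if (xa_is_internal(entry))
		return xa_to_internal(entry);	/* VM_FAULT_OOM/SIGBUS/FALLBACK */

	/* ... fault handling with the entry locked ... */

	dax_unlock_entry(xas, entry);
	return VM_FAULT_NOPAGE;
}
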
 
 /**
@@ -630,11 +538,10 @@ restart:
  */
 struct page *dax_layout_busy_page(struct address_space *mapping)
 {
-       pgoff_t indices[PAGEVEC_SIZE];
+       XA_STATE(xas, &mapping->i_pages, 0);
+       void *entry;
+       unsigned int scanned = 0;
        struct page *page = NULL;
-       struct pagevec pvec;
-       pgoff_t index, end;
-       unsigned i;
 
        /*
         * In the 'limited' case get_user_pages() for dax is disabled.
@@ -645,13 +552,9 @@ struct page *dax_layout_busy_page(struct address_space *mapping)
        if (!dax_mapping(mapping) || !mapping_mapped(mapping))
                return NULL;
 
-       pagevec_init(&pvec);
-       index = 0;
-       end = -1;
-
        /*
         * If we race get_user_pages_fast() here either we'll see the
-        * elevated page count in the pagevec_lookup and wait, or
+        * elevated page count in the iteration and wait, or
         * get_user_pages_fast() will see that the page it took a reference
         * against is no longer mapped in the page tables and bail to the
         * get_user_pages() slow path.  The slow path is protected by
@@ -663,94 +566,68 @@ struct page *dax_layout_busy_page(struct address_space *mapping)
         */
        unmap_mapping_range(mapping, 0, 0, 1);
 
-       while (index < end && pagevec_lookup_entries(&pvec, mapping, index,
-                               min(end - index, (pgoff_t)PAGEVEC_SIZE),
-                               indices)) {
-               pgoff_t nr_pages = 1;
-
-               for (i = 0; i < pagevec_count(&pvec); i++) {
-                       struct page *pvec_ent = pvec.pages[i];
-                       void *entry;
-
-                       index = indices[i];
-                       if (index >= end)
-                               break;
-
-                       if (WARN_ON_ONCE(
-                            !radix_tree_exceptional_entry(pvec_ent)))
-                               continue;
-
-                       xa_lock_irq(&mapping->i_pages);
-                       entry = get_unlocked_mapping_entry(mapping, index, NULL);
-                       if (entry) {
-                               page = dax_busy_page(entry);
-                               /*
-                                * Account for multi-order entries at
-                                * the end of the pagevec.
-                                */
-                               if (i + 1 >= pagevec_count(&pvec))
-                                       nr_pages = 1UL << dax_radix_order(entry);
-                       }
-                       put_unlocked_mapping_entry(mapping, index, entry);
-                       xa_unlock_irq(&mapping->i_pages);
-                       if (page)
-                               break;
-               }
-
-               /*
-                * We don't expect normal struct page entries to exist in our
-                * tree, but we keep these pagevec calls so that this code is
-                * consistent with the common pattern for handling pagevecs
-                * throughout the kernel.
-                */
-               pagevec_remove_exceptionals(&pvec);
-               pagevec_release(&pvec);
-               index += nr_pages;
-
+       xas_lock_irq(&xas);
+       xas_for_each(&xas, entry, ULONG_MAX) {
+               if (WARN_ON_ONCE(!xa_is_value(entry)))
+                       continue;
+               if (unlikely(dax_is_locked(entry)))
+                       entry = get_unlocked_entry(&xas);
+               if (entry)
+                       page = dax_busy_page(entry);
+               put_unlocked_entry(&xas, entry);
                if (page)
                        break;
+               if (++scanned % XA_CHECK_SCHED)
+                       continue;
+
+               xas_pause(&xas);
+               xas_unlock_irq(&xas);
+               cond_resched();
+               xas_lock_irq(&xas);
        }
+       xas_unlock_irq(&xas);
        return page;
 }
 EXPORT_SYMBOL_GPL(dax_layout_busy_page);
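
Editor's sketch of how a filesystem is expected to consume dax_layout_busy_page(): loop until no page in the mapping holds an elevated reference, dropping locks to wait in between. Loosely modelled on the XFS break-layouts logic of this era; the waiting details are elided and assumed.

/*
 * Illustration only (assumed caller shape): poll for a busy page and
 * wait for its reference count to drop before letting truncate or
 * hole-punch proceed.
 */
static int __maybe_unused dax_break_layouts_demo(struct inode *inode)
{
	struct page *page;

	while ((page = dax_layout_busy_page(inode->i_mapping)) != NULL) {
		/* drop fs locks, wait for the page refcount to hit one, retry */
	}
	return 0;
}
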
 
-static int __dax_invalidate_mapping_entry(struct address_space *mapping,
+static int __dax_invalidate_entry(struct address_space *mapping,
                                          pgoff_t index, bool trunc)
 {
+       XA_STATE(xas, &mapping->i_pages, index);
        int ret = 0;
        void *entry;
-       struct radix_tree_root *pages = &mapping->i_pages;
 
-       xa_lock_irq(pages);
-       entry = get_unlocked_mapping_entry(mapping, index, NULL);
-       if (!entry || WARN_ON_ONCE(!radix_tree_exceptional_entry(entry)))
+       xas_lock_irq(&xas);
+       entry = get_unlocked_entry(&xas);
+       if (!entry || WARN_ON_ONCE(!xa_is_value(entry)))
                goto out;
        if (!trunc &&
-           (radix_tree_tag_get(pages, index, PAGECACHE_TAG_DIRTY) ||
-            radix_tree_tag_get(pages, index, PAGECACHE_TAG_TOWRITE)))
+           (xas_get_mark(&xas, PAGECACHE_TAG_DIRTY) ||
+            xas_get_mark(&xas, PAGECACHE_TAG_TOWRITE)))
                goto out;
        dax_disassociate_entry(entry, mapping, trunc);
-       radix_tree_delete(pages, index);
+       xas_store(&xas, NULL);
        mapping->nrexceptional--;
        ret = 1;
 out:
-       put_unlocked_mapping_entry(mapping, index, entry);
-       xa_unlock_irq(pages);
+       put_unlocked_entry(&xas, entry);
+       xas_unlock_irq(&xas);
        return ret;
 }
+
 /*
- * Delete exceptional DAX entry at @index from @mapping. Wait for radix tree
- * entry to get unlocked before deleting it.
+ * Delete DAX entry at @index from @mapping.  Wait for it
+ * to be unlocked before deleting it.
  */
 int dax_delete_mapping_entry(struct address_space *mapping, pgoff_t index)
 {
-       int ret = __dax_invalidate_mapping_entry(mapping, index, true);
+       int ret = __dax_invalidate_entry(mapping, index, true);
 
        /*
         * This gets called from truncate / punch_hole path. As such, the caller
         * must hold locks protecting against concurrent modifications of the
-        * radix tree (usually fs-private i_mmap_sem for writing). Since the
-        * caller has seen exceptional entry for this index, we better find it
+        * page cache (usually fs-private i_mmap_sem for writing). Since the
+        * caller has seen a DAX entry for this index, we better find it
         * at that index as well...
         */
        WARN_ON_ONCE(!ret);
@@ -758,12 +635,12 @@ int dax_delete_mapping_entry(struct address_space *mapping, pgoff_t index)
 }
 
 /*
- * Invalidate exceptional DAX entry if it is clean.
+ * Invalidate DAX entry if it is clean.
  */
 int dax_invalidate_mapping_entry_sync(struct address_space *mapping,
                                      pgoff_t index)
 {
-       return __dax_invalidate_mapping_entry(mapping, index, false);
+       return __dax_invalidate_entry(mapping, index, false);
 }
 
 static int copy_user_dax(struct block_device *bdev, struct dax_device *dax_dev,
@@ -799,30 +676,27 @@ static int copy_user_dax(struct block_device *bdev, struct dax_device *dax_dev,
  * already in the tree, we will skip the insertion and just dirty the PMD as
  * appropriate.
  */
-static void *dax_insert_mapping_entry(struct address_space *mapping,
-                                     struct vm_fault *vmf,
-                                     void *entry, pfn_t pfn_t,
-                                     unsigned long flags, bool dirty)
+static void *dax_insert_entry(struct xa_state *xas,
+               struct address_space *mapping, struct vm_fault *vmf,
+               void *entry, pfn_t pfn, unsigned long flags, bool dirty)
 {
-       struct radix_tree_root *pages = &mapping->i_pages;
-       unsigned long pfn = pfn_t_to_pfn(pfn_t);
-       pgoff_t index = vmf->pgoff;
-       void *new_entry;
+       void *new_entry = dax_make_entry(pfn, flags);
 
        if (dirty)
                __mark_inode_dirty(mapping->host, I_DIRTY_PAGES);
 
-       if (dax_is_zero_entry(entry) && !(flags & RADIX_DAX_ZERO_PAGE)) {
+       if (dax_is_zero_entry(entry) && !(flags & DAX_ZERO_PAGE)) {
+               unsigned long index = xas->xa_index;
                /* we are replacing a zero page with block mapping */
                if (dax_is_pmd_entry(entry))
                        unmap_mapping_pages(mapping, index & ~PG_PMD_COLOUR,
-                                                       PG_PMD_NR, false);
+                                       PG_PMD_NR, false);
                else /* pte entry */
-                       unmap_mapping_pages(mapping, vmf->pgoff, 1, false);
+                       unmap_mapping_pages(mapping, index, 1, false);
        }
 
-       xa_lock_irq(pages);
-       new_entry = dax_radix_locked_entry(pfn, flags);
+       xas_reset(xas);
+       xas_lock_irq(xas);
        if (dax_entry_size(entry) != dax_entry_size(new_entry)) {
                dax_disassociate_entry(entry, mapping, false);
                dax_associate_entry(new_entry, mapping, vmf->vma, vmf->address);
@@ -830,33 +704,30 @@ static void *dax_insert_mapping_entry(struct address_space *mapping,
 
        if (dax_is_zero_entry(entry) || dax_is_empty_entry(entry)) {
                /*
-                * Only swap our new entry into the radix tree if the current
+                * Only swap our new entry into the page cache if the current
                 * entry is a zero page or an empty entry.  If a normal PTE or
-                * PMD entry is already in the tree, we leave it alone.  This
+                * PMD entry is already in the cache, we leave it alone.  This
                 * means that if we are trying to insert a PTE and the
                 * existing entry is a PMD, we will just leave the PMD in the
                 * tree and dirty it if necessary.
                 */
-               struct radix_tree_node *node;
-               void **slot;
-               void *ret;
-
-               ret = __radix_tree_lookup(pages, index, &node, &slot);
-               WARN_ON_ONCE(ret != entry);
-               __radix_tree_replace(pages, node, slot,
-                                    new_entry, NULL);
+               void *old = dax_lock_entry(xas, new_entry);
+               WARN_ON_ONCE(old != xa_mk_value(xa_to_value(entry) |
+                                       DAX_LOCKED));
                entry = new_entry;
+       } else {
+               xas_load(xas);  /* Walk the xa_state */
        }
 
        if (dirty)
-               radix_tree_tag_set(pages, index, PAGECACHE_TAG_DIRTY);
+               xas_set_mark(xas, PAGECACHE_TAG_DIRTY);
 
-       xa_unlock_irq(pages);
+       xas_unlock_irq(xas);
        return entry;
 }
 
-static inline unsigned long
-pgoff_address(pgoff_t pgoff, struct vm_area_struct *vma)
+static inline
+unsigned long pgoff_address(pgoff_t pgoff, struct vm_area_struct *vma)
 {
        unsigned long address;
 
@@ -866,8 +737,8 @@ pgoff_address(pgoff_t pgoff, struct vm_area_struct *vma)
 }
 
 /* Walk all mappings of a given index of a file and writeprotect them */
-static void dax_mapping_entry_mkclean(struct address_space *mapping,
-                                     pgoff_t index, unsigned long pfn)
+static void dax_entry_mkclean(struct address_space *mapping, pgoff_t index,
+               unsigned long pfn)
 {
        struct vm_area_struct *vma;
        pte_t pte, *ptep = NULL;
@@ -937,11 +808,9 @@ unlock_pte:
        i_mmap_unlock_read(mapping);
 }
 
-static int dax_writeback_one(struct dax_device *dax_dev,
-               struct address_space *mapping, pgoff_t index, void *entry)
+static int dax_writeback_one(struct xa_state *xas, struct dax_device *dax_dev,
+               struct address_space *mapping, void *entry)
 {
-       struct radix_tree_root *pages = &mapping->i_pages;
-       void *entry2, **slot;
        unsigned long pfn;
        long ret = 0;
        size_t size;
@@ -950,32 +819,38 @@ static int dax_writeback_one(struct dax_device *dax_dev,
         * A page got tagged dirty in DAX mapping? Something is seriously
         * wrong.
         */
-       if (WARN_ON(!radix_tree_exceptional_entry(entry)))
+       if (WARN_ON(!xa_is_value(entry)))
                return -EIO;
 
-       xa_lock_irq(pages);
-       entry2 = get_unlocked_mapping_entry(mapping, index, &slot);
-       /* Entry got punched out / reallocated? */
-       if (!entry2 || WARN_ON_ONCE(!radix_tree_exceptional_entry(entry2)))
-               goto put_unlocked;
-       /*
-        * Entry got reallocated elsewhere? No need to writeback. We have to
-        * compare pfns as we must not bail out due to difference in lockbit
-        * or entry type.
-        */
-       if (dax_radix_pfn(entry2) != dax_radix_pfn(entry))
-               goto put_unlocked;
-       if (WARN_ON_ONCE(dax_is_empty_entry(entry) ||
-                               dax_is_zero_entry(entry))) {
-               ret = -EIO;
-               goto put_unlocked;
+       if (unlikely(dax_is_locked(entry))) {
+               void *old_entry = entry;
+
+               entry = get_unlocked_entry(xas);
+
+               /* Entry got punched out / reallocated? */
+               if (!entry || WARN_ON_ONCE(!xa_is_value(entry)))
+                       goto put_unlocked;
+               /*
+                * Entry got reallocated elsewhere? No need to writeback.
+                * We have to compare pfns as we must not bail out due to
+                * difference in lockbit or entry type.
+                */
+               if (dax_to_pfn(old_entry) != dax_to_pfn(entry))
+                       goto put_unlocked;
+               if (WARN_ON_ONCE(dax_is_empty_entry(entry) ||
+                                       dax_is_zero_entry(entry))) {
+                       ret = -EIO;
+                       goto put_unlocked;
+               }
+
+               /* Another fsync thread may have already done this entry */
+               if (!xas_get_mark(xas, PAGECACHE_TAG_TOWRITE))
+                       goto put_unlocked;
        }
 
-       /* Another fsync thread may have already written back this entry */
-       if (!radix_tree_tag_get(pages, index, PAGECACHE_TAG_TOWRITE))
-               goto put_unlocked;
        /* Lock the entry to serialize with page faults */
-       entry = lock_slot(mapping, slot);
+       dax_lock_entry(xas, entry);
+
        /*
         * We can clear the tag now but we have to be careful so that concurrent
         * dax_writeback_one() calls for the same index cannot finish before we
@@ -983,8 +858,8 @@ static int dax_writeback_one(struct dax_device *dax_dev,
         * at the entry only under the i_pages lock and once they do that
         * they will see the entry locked and wait for it to unlock.
         */
-       radix_tree_tag_clear(pages, index, PAGECACHE_TAG_TOWRITE);
-       xa_unlock_irq(pages);
+       xas_clear_mark(xas, PAGECACHE_TAG_TOWRITE);
+       xas_unlock_irq(xas);
 
        /*
         * Even if dax_writeback_mapping_range() was given a wbc->range_start
@@ -993,10 +868,10 @@ static int dax_writeback_one(struct dax_device *dax_dev,
         * This allows us to flush for PMD_SIZE and not have to worry about
         * partial PMD writebacks.
         */
-       pfn = dax_radix_pfn(entry);
-       size = PAGE_SIZE << dax_radix_order(entry);
+       pfn = dax_to_pfn(entry);
+       size = PAGE_SIZE << dax_entry_order(entry);
 
-       dax_mapping_entry_mkclean(mapping, index, pfn);
+       dax_entry_mkclean(mapping, xas->xa_index, pfn);
        dax_flush(dax_dev, page_address(pfn_to_page(pfn)), size);
        /*
         * After we have flushed the cache, we can clear the dirty tag. There
@@ -1004,16 +879,18 @@ static int dax_writeback_one(struct dax_device *dax_dev,
         * the pfn mappings are writeprotected and fault waits for mapping
         * entry lock.
         */
-       xa_lock_irq(pages);
-       radix_tree_tag_clear(pages, index, PAGECACHE_TAG_DIRTY);
-       xa_unlock_irq(pages);
-       trace_dax_writeback_one(mapping->host, index, size >> PAGE_SHIFT);
-       put_locked_mapping_entry(mapping, index);
+       xas_reset(xas);
+       xas_lock_irq(xas);
+       xas_store(xas, entry);
+       xas_clear_mark(xas, PAGECACHE_TAG_DIRTY);
+       dax_wake_entry(xas, entry, false);
+
+       trace_dax_writeback_one(mapping->host, xas->xa_index,
+                       size >> PAGE_SHIFT);
        return ret;
 
  put_unlocked:
-       put_unlocked_mapping_entry(mapping, index, entry2);
-       xa_unlock_irq(pages);
+       put_unlocked_entry(xas, entry);
        return ret;
 }
 
@@ -1025,13 +902,13 @@ static int dax_writeback_one(struct dax_device *dax_dev,
 int dax_writeback_mapping_range(struct address_space *mapping,
                struct block_device *bdev, struct writeback_control *wbc)
 {
+       XA_STATE(xas, &mapping->i_pages, wbc->range_start >> PAGE_SHIFT);
        struct inode *inode = mapping->host;
-       pgoff_t start_index, end_index;
-       pgoff_t indices[PAGEVEC_SIZE];
+       pgoff_t end_index = wbc->range_end >> PAGE_SHIFT;
        struct dax_device *dax_dev;
-       struct pagevec pvec;
-       bool done = false;
-       int i, ret = 0;
+       void *entry;
+       int ret = 0;
+       unsigned int scanned = 0;
 
        if (WARN_ON_ONCE(inode->i_blkbits != PAGE_SHIFT))
                return -EIO;
@@ -1043,41 +920,29 @@ int dax_writeback_mapping_range(struct address_space *mapping,
        if (!dax_dev)
                return -EIO;
 
-       start_index = wbc->range_start >> PAGE_SHIFT;
-       end_index = wbc->range_end >> PAGE_SHIFT;
-
-       trace_dax_writeback_range(inode, start_index, end_index);
+       trace_dax_writeback_range(inode, xas.xa_index, end_index);
 
-       tag_pages_for_writeback(mapping, start_index, end_index);
+       tag_pages_for_writeback(mapping, xas.xa_index, end_index);
 
-       pagevec_init(&pvec);
-       while (!done) {
-               pvec.nr = find_get_entries_tag(mapping, start_index,
-                               PAGECACHE_TAG_TOWRITE, PAGEVEC_SIZE,
-                               pvec.pages, indices);
-
-               if (pvec.nr == 0)
+       xas_lock_irq(&xas);
+       xas_for_each_marked(&xas, entry, end_index, PAGECACHE_TAG_TOWRITE) {
+               ret = dax_writeback_one(&xas, dax_dev, mapping, entry);
+               if (ret < 0) {
+                       mapping_set_error(mapping, ret);
                        break;
-
-               for (i = 0; i < pvec.nr; i++) {
-                       if (indices[i] > end_index) {
-                               done = true;
-                               break;
-                       }
-
-                       ret = dax_writeback_one(dax_dev, mapping, indices[i],
-                                       pvec.pages[i]);
-                       if (ret < 0) {
-                               mapping_set_error(mapping, ret);
-                               goto out;
-                       }
                }
-               start_index = indices[pvec.nr - 1] + 1;
+               if (++scanned % XA_CHECK_SCHED)
+                       continue;
+
+               xas_pause(&xas);
+               xas_unlock_irq(&xas);
+               cond_resched();
+               xas_lock_irq(&xas);
        }
-out:
+       xas_unlock_irq(&xas);
        put_dax(dax_dev);
-       trace_dax_writeback_range_done(inode, start_index, end_index);
-       return (ret < 0 ? ret : 0);
+       trace_dax_writeback_range_done(inode, xas.xa_index, end_index);
+       return ret;
 }
 EXPORT_SYMBOL_GPL(dax_writeback_mapping_range);
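
Editor's sketch of the typical caller: a DAX-aware ->writepages funnels straight into dax_writeback_mapping_range() with the backing block device. Loosely modelled on the ext4/xfs callers of this export; the hook name is hypothetical.

/*
 * Illustration only (hypothetical ->writepages hook): flush all dirty
 * DAX entries of the mapping for an fsync/msync request.
 */
static int __maybe_unused demo_dax_writepages(struct address_space *mapping,
		struct writeback_control *wbc)
{
	return dax_writeback_mapping_range(mapping,
			mapping->host->i_sb->s_bdev, wbc);
}
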
 
@@ -1125,16 +990,18 @@ out:
  * If this page is ever written to we will re-fault and change the mapping to
  * point to real DAX storage instead.
  */
-static vm_fault_t dax_load_hole(struct address_space *mapping, void *entry,
-                        struct vm_fault *vmf)
+static vm_fault_t dax_load_hole(struct xa_state *xas,
+               struct address_space *mapping, void **entry,
+               struct vm_fault *vmf)
 {
        struct inode *inode = mapping->host;
        unsigned long vaddr = vmf->address;
        pfn_t pfn = pfn_to_pfn_t(my_zero_pfn(vaddr));
        vm_fault_t ret;
 
-       dax_insert_mapping_entry(mapping, vmf, entry, pfn, RADIX_DAX_ZERO_PAGE,
-                       false);
+       *entry = dax_insert_entry(xas, mapping, vmf, *entry, pfn,
+                       DAX_ZERO_PAGE, false);
+
        ret = vmf_insert_mixed(vmf->vma, vaddr, pfn);
        trace_dax_load_hole(inode, vmf, ret);
        return ret;
@@ -1342,6 +1209,7 @@ static vm_fault_t dax_iomap_pte_fault(struct vm_fault *vmf, pfn_t *pfnp,
 {
        struct vm_area_struct *vma = vmf->vma;
        struct address_space *mapping = vma->vm_file->f_mapping;
+       XA_STATE(xas, &mapping->i_pages, vmf->pgoff);
        struct inode *inode = mapping->host;
        unsigned long vaddr = vmf->address;
        loff_t pos = (loff_t)vmf->pgoff << PAGE_SHIFT;
@@ -1368,9 +1236,9 @@ static vm_fault_t dax_iomap_pte_fault(struct vm_fault *vmf, pfn_t *pfnp,
        if (write && !vmf->cow_page)
                flags |= IOMAP_WRITE;
 
-       entry = grab_mapping_entry(mapping, vmf->pgoff, 0);
-       if (IS_ERR(entry)) {
-               ret = dax_fault_return(PTR_ERR(entry));
+       entry = grab_mapping_entry(&xas, mapping, 0);
+       if (xa_is_internal(entry)) {
+               ret = xa_to_internal(entry);
                goto out;
        }
 
@@ -1443,7 +1311,7 @@ static vm_fault_t dax_iomap_pte_fault(struct vm_fault *vmf, pfn_t *pfnp,
                if (error < 0)
                        goto error_finish_iomap;
 
-               entry = dax_insert_mapping_entry(mapping, vmf, entry, pfn,
+               entry = dax_insert_entry(&xas, mapping, vmf, entry, pfn,
                                                 0, write && !sync);
 
                /*
@@ -1471,7 +1339,7 @@ static vm_fault_t dax_iomap_pte_fault(struct vm_fault *vmf, pfn_t *pfnp,
        case IOMAP_UNWRITTEN:
        case IOMAP_HOLE:
                if (!write) {
-                       ret = dax_load_hole(mapping, entry, vmf);
+                       ret = dax_load_hole(&xas, mapping, &entry, vmf);
                        goto finish_iomap;
                }
                /*FALLTHRU*/
@@ -1498,21 +1366,20 @@ static vm_fault_t dax_iomap_pte_fault(struct vm_fault *vmf, pfn_t *pfnp,
                ops->iomap_end(inode, pos, PAGE_SIZE, copied, flags, &iomap);
        }
  unlock_entry:
-       put_locked_mapping_entry(mapping, vmf->pgoff);
+       dax_unlock_entry(&xas, entry);
  out:
        trace_dax_pte_fault_done(inode, vmf, ret);
        return ret | major;
 }
 
 #ifdef CONFIG_FS_DAX_PMD
-static vm_fault_t dax_pmd_load_hole(struct vm_fault *vmf, struct iomap *iomap,
-               void *entry)
+static vm_fault_t dax_pmd_load_hole(struct xa_state *xas, struct vm_fault *vmf,
+               struct iomap *iomap, void **entry)
 {
        struct address_space *mapping = vmf->vma->vm_file->f_mapping;
        unsigned long pmd_addr = vmf->address & PMD_MASK;
        struct inode *inode = mapping->host;
        struct page *zero_page;
-       void *ret = NULL;
        spinlock_t *ptl;
        pmd_t pmd_entry;
        pfn_t pfn;
@@ -1523,8 +1390,8 @@ static vm_fault_t dax_pmd_load_hole(struct vm_fault *vmf, struct iomap *iomap,
                goto fallback;
 
        pfn = page_to_pfn_t(zero_page);
-       ret = dax_insert_mapping_entry(mapping, vmf, entry, pfn,
-                       RADIX_DAX_PMD | RADIX_DAX_ZERO_PAGE, false);
+       *entry = dax_insert_entry(xas, mapping, vmf, *entry, pfn,
+                       DAX_PMD | DAX_ZERO_PAGE, false);
 
        ptl = pmd_lock(vmf->vma->vm_mm, vmf->pmd);
        if (!pmd_none(*(vmf->pmd))) {
@@ -1536,11 +1403,11 @@ static vm_fault_t dax_pmd_load_hole(struct vm_fault *vmf, struct iomap *iomap,
        pmd_entry = pmd_mkhuge(pmd_entry);
        set_pmd_at(vmf->vma->vm_mm, pmd_addr, vmf->pmd, pmd_entry);
        spin_unlock(ptl);
-       trace_dax_pmd_load_hole(inode, vmf, zero_page, ret);
+       trace_dax_pmd_load_hole(inode, vmf, zero_page, *entry);
        return VM_FAULT_NOPAGE;
 
 fallback:
-       trace_dax_pmd_load_hole_fallback(inode, vmf, zero_page, ret);
+       trace_dax_pmd_load_hole_fallback(inode, vmf, zero_page, *entry);
        return VM_FAULT_FALLBACK;
 }
 
@@ -1549,6 +1416,7 @@ static vm_fault_t dax_iomap_pmd_fault(struct vm_fault *vmf, pfn_t *pfnp,
 {
        struct vm_area_struct *vma = vmf->vma;
        struct address_space *mapping = vma->vm_file->f_mapping;
+       XA_STATE_ORDER(xas, &mapping->i_pages, vmf->pgoff, PMD_ORDER);
        unsigned long pmd_addr = vmf->address & PMD_MASK;
        bool write = vmf->flags & FAULT_FLAG_WRITE;
        bool sync;
@@ -1556,7 +1424,7 @@ static vm_fault_t dax_iomap_pmd_fault(struct vm_fault *vmf, pfn_t *pfnp,
        struct inode *inode = mapping->host;
        vm_fault_t result = VM_FAULT_FALLBACK;
        struct iomap iomap = { 0 };
-       pgoff_t max_pgoff, pgoff;
+       pgoff_t max_pgoff;
        void *entry;
        loff_t pos;
        int error;
@@ -1567,7 +1435,6 @@ static vm_fault_t dax_iomap_pmd_fault(struct vm_fault *vmf, pfn_t *pfnp,
         * supposed to hold locks serializing us with truncate / punch hole so
         * this is a reliable test.
         */
-       pgoff = linear_page_index(vma, pmd_addr);
        max_pgoff = DIV_ROUND_UP(i_size_read(inode), PAGE_SIZE);
 
        trace_dax_pmd_fault(inode, vmf, max_pgoff, 0);
@@ -1576,7 +1443,7 @@ static vm_fault_t dax_iomap_pmd_fault(struct vm_fault *vmf, pfn_t *pfnp,
         * Make sure that the faulting address's PMD offset (color) matches
         * the PMD offset from the start of the file.  This is necessary so
         * that a PMD range in the page table overlaps exactly with a PMD
-        * range in the radix tree.
+        * range in the page cache.
         */
        if ((vmf->pgoff & PG_PMD_COLOUR) !=
            ((vmf->address >> PAGE_SHIFT) & PG_PMD_COLOUR))
@@ -1592,24 +1459,26 @@ static vm_fault_t dax_iomap_pmd_fault(struct vm_fault *vmf, pfn_t *pfnp,
        if ((pmd_addr + PMD_SIZE) > vma->vm_end)
                goto fallback;
 
-       if (pgoff >= max_pgoff) {
+       if (xas.xa_index >= max_pgoff) {
                result = VM_FAULT_SIGBUS;
                goto out;
        }
 
        /* If the PMD would extend beyond the file size */
-       if ((pgoff | PG_PMD_COLOUR) >= max_pgoff)
+       if ((xas.xa_index | PG_PMD_COLOUR) >= max_pgoff)
                goto fallback;
 
        /*
-        * grab_mapping_entry() will make sure we get a 2MiB empty entry, a
-        * 2MiB zero page entry or a DAX PMD.  If it can't (because a 4k page
-        * is already in the tree, for instance), it will return -EEXIST and
-        * we just fall back to 4k entries.
+        * grab_mapping_entry() will make sure we get an empty PMD entry,
+        * a zero PMD entry or a DAX PMD.  If it can't (because a PTE
+        * entry is already in the array, for instance), it will return
+        * VM_FAULT_FALLBACK.
         */
-       entry = grab_mapping_entry(mapping, pgoff, RADIX_DAX_PMD);
-       if (IS_ERR(entry))
+       entry = grab_mapping_entry(&xas, mapping, DAX_PMD);
+       if (xa_is_internal(entry)) {
+               result = xa_to_internal(entry);
                goto fallback;
+       }
 
        /*
         * It is possible, particularly with mixed reads & writes to private
@@ -1628,7 +1497,7 @@ static vm_fault_t dax_iomap_pmd_fault(struct vm_fault *vmf, pfn_t *pfnp,
         * setting up a mapping, so really we're using iomap_begin() as a way
         * to look up our filesystem block.
         */
-       pos = (loff_t)pgoff << PAGE_SHIFT;
+       pos = (loff_t)xas.xa_index << PAGE_SHIFT;
        error = ops->iomap_begin(inode, pos, PMD_SIZE, iomap_flags, &iomap);
        if (error)
                goto unlock_entry;
@@ -1644,8 +1513,8 @@ static vm_fault_t dax_iomap_pmd_fault(struct vm_fault *vmf, pfn_t *pfnp,
                if (error < 0)
                        goto finish_iomap;
 
-               entry = dax_insert_mapping_entry(mapping, vmf, entry, pfn,
-                                               RADIX_DAX_PMD, write && !sync);
+               entry = dax_insert_entry(&xas, mapping, vmf, entry, pfn,
+                                               DAX_PMD, write && !sync);
 
                /*
                 * If we are doing synchronous page fault and inode needs fsync,
@@ -1669,7 +1538,7 @@ static vm_fault_t dax_iomap_pmd_fault(struct vm_fault *vmf, pfn_t *pfnp,
        case IOMAP_HOLE:
                if (WARN_ON_ONCE(write))
                        break;
-               result = dax_pmd_load_hole(vmf, &iomap, entry);
+               result = dax_pmd_load_hole(&xas, vmf, &iomap, &entry);
                break;
        default:
                WARN_ON_ONCE(1);
@@ -1692,7 +1561,7 @@ static vm_fault_t dax_iomap_pmd_fault(struct vm_fault *vmf, pfn_t *pfnp,
                                &iomap);
        }
  unlock_entry:
-       put_locked_mapping_entry(mapping, pgoff);
+       dax_unlock_entry(&xas, entry);
  fallback:
        if (result == VM_FAULT_FALLBACK) {
                split_huge_pmd(vma, vmf->pmd, vmf->address);
@@ -1737,54 +1606,49 @@ vm_fault_t dax_iomap_fault(struct vm_fault *vmf, enum page_entry_size pe_size,
 }
 EXPORT_SYMBOL_GPL(dax_iomap_fault);
 
-/**
+/*
  * dax_insert_pfn_mkwrite - insert PTE or PMD entry into page tables
  * @vmf: The description of the fault
- * @pe_size: Size of entry to be inserted
  * @pfn: PFN to insert
+ * @order: Order of entry to insert.
  *
- * This function inserts writeable PTE or PMD entry into page tables for mmaped
- * DAX file.  It takes care of marking corresponding radix tree entry as dirty
- * as well.
+ * This function inserts a writeable PTE or PMD entry into the page tables
+ * for an mmaped DAX file.  It also marks the page cache entry as dirty.
  */
-static vm_fault_t dax_insert_pfn_mkwrite(struct vm_fault *vmf,
-                                 enum page_entry_size pe_size,
-                                 pfn_t pfn)
+static vm_fault_t
+dax_insert_pfn_mkwrite(struct vm_fault *vmf, pfn_t pfn, unsigned int order)
 {
        struct address_space *mapping = vmf->vma->vm_file->f_mapping;
-       void *entry, **slot;
-       pgoff_t index = vmf->pgoff;
+       XA_STATE_ORDER(xas, &mapping->i_pages, vmf->pgoff, order);
+       void *entry;
        vm_fault_t ret;
 
-       xa_lock_irq(&mapping->i_pages);
-       entry = get_unlocked_mapping_entry(mapping, index, &slot);
+       xas_lock_irq(&xas);
+       entry = get_unlocked_entry(&xas);
        /* Did we race with someone splitting entry or so? */
        if (!entry ||
-           (pe_size == PE_SIZE_PTE && !dax_is_pte_entry(entry)) ||
-           (pe_size == PE_SIZE_PMD && !dax_is_pmd_entry(entry))) {
-               put_unlocked_mapping_entry(mapping, index, entry);
-               xa_unlock_irq(&mapping->i_pages);
+           (order == 0 && !dax_is_pte_entry(entry)) ||
+           (order == PMD_ORDER && (xa_is_internal(entry) ||
+                                   !dax_is_pmd_entry(entry)))) {
+               put_unlocked_entry(&xas, entry);
+               xas_unlock_irq(&xas);
                trace_dax_insert_pfn_mkwrite_no_entry(mapping->host, vmf,
                                                      VM_FAULT_NOPAGE);
                return VM_FAULT_NOPAGE;
        }
-       radix_tree_tag_set(&mapping->i_pages, index, PAGECACHE_TAG_DIRTY);
-       entry = lock_slot(mapping, slot);
-       xa_unlock_irq(&mapping->i_pages);
-       switch (pe_size) {
-       case PE_SIZE_PTE:
+       xas_set_mark(&xas, PAGECACHE_TAG_DIRTY);
+       dax_lock_entry(&xas, entry);
+       xas_unlock_irq(&xas);
+       if (order == 0)
                ret = vmf_insert_mixed_mkwrite(vmf->vma, vmf->address, pfn);
-               break;
 #ifdef CONFIG_FS_DAX_PMD
-       case PE_SIZE_PMD:
+       else if (order == PMD_ORDER)
                ret = vmf_insert_pfn_pmd(vmf->vma, vmf->address, vmf->pmd,
                        pfn, true);
-               break;
 #endif
-       default:
+       else
                ret = VM_FAULT_FALLBACK;
-       }
-       put_locked_mapping_entry(mapping, index);
+       dax_unlock_entry(&xas, entry);
        trace_dax_insert_pfn_mkwrite(mapping->host, vmf, ret);
        return ret;
 }
@@ -1804,17 +1668,12 @@ vm_fault_t dax_finish_sync_fault(struct vm_fault *vmf,
 {
        int err;
        loff_t start = ((loff_t)vmf->pgoff) << PAGE_SHIFT;
-       size_t len = 0;
+       unsigned int order = pe_order(pe_size);
+       size_t len = PAGE_SIZE << order;
 
-       if (pe_size == PE_SIZE_PTE)
-               len = PAGE_SIZE;
-       else if (pe_size == PE_SIZE_PMD)
-               len = PMD_SIZE;
-       else
-               WARN_ON_ONCE(1);
        err = vfs_fsync_range(vmf->vma->vm_file, start, start + len - 1, 1);
        if (err)
                return VM_FAULT_SIGBUS;
-       return dax_insert_pfn_mkwrite(vmf, pe_size, pfn);
+       return dax_insert_pfn_mkwrite(vmf, pfn, order);
 }
 EXPORT_SYMBOL_GPL(dax_finish_sync_fault);
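
Editor's sketch of the synchronous-fault sequence this export exists for, loosely modelled on the ext4/xfs huge_fault handlers of this era (the iomap_ops and the VM_FAULT_NEEDDSYNC handshake come from outside this patch and are assumptions here).

/*
 * Illustration only (assumed caller shape): when a write fault needs a
 * metadata sync, dax_finish_sync_fault() fsyncs the faulted range and
 * then installs the writeable PTE/PMD.
 */
static vm_fault_t __maybe_unused demo_dax_huge_fault(struct vm_fault *vmf,
		enum page_entry_size pe_size, const struct iomap_ops *ops)
{
	pfn_t pfn;
	vm_fault_t ret;

	ret = dax_iomap_fault(vmf, pe_size, &pfn, NULL, ops);
	if (ret & VM_FAULT_NEEDDSYNC)
		ret = dax_finish_sync_fault(vmf, pe_size, pfn);
	return ret;
}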