dax: New fault locking
[sfrench/cifs-2.6.git] / fs / dax.c
index 75ba46d82a761cc8aced5e8ed2041a1d2c336e60..f43c3d806fb6e61ab11269fe5145787412af90d0 100644 (file)
--- a/fs/dax.c
+++ b/fs/dax.c
 #include <linux/pfn_t.h>
 #include <linux/sizes.h>
 
+/*
+ * We use the lowest available bit in an exceptional entry for locking, and
+ * the other two bits to determine the entry type. In total 3 special bits
+ * above RADIX_TREE_EXCEPTIONAL_SHIFT.
+ */
+#define RADIX_DAX_SHIFT        (RADIX_TREE_EXCEPTIONAL_SHIFT + 3)
+#define RADIX_DAX_PTE (1 << (RADIX_TREE_EXCEPTIONAL_SHIFT + 1))
+#define RADIX_DAX_PMD (1 << (RADIX_TREE_EXCEPTIONAL_SHIFT + 2))
+#define RADIX_DAX_TYPE_MASK (RADIX_DAX_PTE | RADIX_DAX_PMD)
+/* Extract the type bits (PTE vs PMD) from an exceptional entry */
+#define RADIX_DAX_TYPE(entry) ((unsigned long)entry & RADIX_DAX_TYPE_MASK)
+/* Extract the sector number encoded in an exceptional entry */
+#define RADIX_DAX_SECTOR(entry) (((unsigned long)entry >> RADIX_DAX_SHIFT))
+/* Build an (unlocked) exceptional entry for @sector; @pmd selects the type */
+#define RADIX_DAX_ENTRY(sector, pmd) ((void *)((unsigned long)sector << \
+               RADIX_DAX_SHIFT | (pmd ? RADIX_DAX_PMD : RADIX_DAX_PTE) | \
+               RADIX_TREE_EXCEPTIONAL_ENTRY))
+
+/* We choose 4096 entries - same as per-zone page wait tables */
+#define DAX_WAIT_TABLE_BITS 12
+#define DAX_WAIT_TABLE_ENTRIES (1 << DAX_WAIT_TABLE_BITS)
+
+/*
+ * Hashed wait queues used to wait for a locked DAX radix tree entry.
+ * File-local: no other translation unit references it, and the generic name
+ * must not leak into the kernel-wide namespace, so declare it static.
+ */
+static wait_queue_head_t wait_table[DAX_WAIT_TABLE_ENTRIES];
+
+/* Initialize every entry-lock wait queue once at boot time. */
+static int __init init_dax_wait_table(void)
+{
+       int idx;
+
+       for (idx = 0; idx < DAX_WAIT_TABLE_ENTRIES; idx++)
+               init_waitqueue_head(&wait_table[idx]);
+       return 0;
+}
+fs_initcall(init_dax_wait_table);
+
+/*
+ * Map a (mapping, index) pair to one of the shared wait queues used for
+ * waiting on locked DAX radix tree entries.
+ */
+static wait_queue_head_t *dax_entry_waitqueue(struct address_space *mapping,
+                                             pgoff_t index)
+{
+       unsigned long key = (unsigned long)mapping ^ index;
+
+       return &wait_table[hash_long(key, DAX_WAIT_TABLE_BITS)];
+}
+
 static long dax_map_atomic(struct block_device *bdev, struct blk_dax_ctl *dax)
 {
        struct request_queue *q = bdev->bd_queue;
@@ -78,50 +116,6 @@ struct page *read_dax_sector(struct block_device *bdev, sector_t n)
        return page;
 }
 
-/*
- * dax_clear_sectors() is called from within transaction context from XFS,
- * and hence this means the stack from this point must follow GFP_NOFS
- * semantics for all operations.
- */
-int dax_clear_sectors(struct block_device *bdev, sector_t _sector, long _size)
-{
-       struct blk_dax_ctl dax = {
-               .sector = _sector,
-               .size = _size,
-       };
-
-       might_sleep();
-       do {
-               long count, sz;
-
-               count = dax_map_atomic(bdev, &dax);
-               if (count < 0)
-                       return count;
-               sz = min_t(long, count, SZ_128K);
-               clear_pmem(dax.addr, sz);
-               dax.size -= sz;
-               dax.sector += sz / 512;
-               dax_unmap_atomic(bdev, &dax);
-               cond_resched();
-       } while (dax.size);
-
-       wmb_pmem();
-       return 0;
-}
-EXPORT_SYMBOL_GPL(dax_clear_sectors);
-
-/* the clear_pmem() calls are ordered by a wmb_pmem() in the caller */
-static void dax_new_buf(void __pmem *addr, unsigned size, unsigned first,
-               loff_t pos, loff_t end)
-{
-       loff_t final = end - pos + first; /* The final byte of the buffer */
-
-       if (first > 0)
-               clear_pmem(addr, first);
-       if (final < size)
-               clear_pmem(addr + final, size - final);
-}
-
 static bool buffer_written(struct buffer_head *bh)
 {
        return buffer_mapped(bh) && !buffer_unwritten(bh);
@@ -160,6 +154,9 @@ static ssize_t dax_io(struct inode *inode, struct iov_iter *iter,
        struct blk_dax_ctl dax = {
                .addr = (void __pmem *) ERR_PTR(-EIO),
        };
+       unsigned blkbits = inode->i_blkbits;
+       sector_t file_blks = (i_size_read(inode) + (1 << blkbits) - 1)
+                                                               >> blkbits;
 
        if (rw == READ)
                end = min(end, i_size_read(inode));
@@ -167,7 +164,6 @@ static ssize_t dax_io(struct inode *inode, struct iov_iter *iter,
        while (pos < end) {
                size_t len;
                if (pos == max) {
-                       unsigned blkbits = inode->i_blkbits;
                        long page = pos >> PAGE_SHIFT;
                        sector_t block = page << (PAGE_SHIFT - blkbits);
                        unsigned first = pos - (block << blkbits);
@@ -183,6 +179,13 @@ static ssize_t dax_io(struct inode *inode, struct iov_iter *iter,
                                        bh->b_size = 1 << blkbits;
                                bh_max = pos - first + bh->b_size;
                                bdev = bh->b_bdev;
+                               /*
+                                * We allow uninitialized buffers for writes
+                                * beyond EOF as those cannot race with faults
+                                */
+                               WARN_ON_ONCE(
+                                       (buffer_new(bh) && block < file_blks) ||
+                                       (rw == WRITE && buffer_unwritten(bh)));
                        } else {
                                unsigned done = bh->b_size -
                                                (bh_max - (pos - first));
@@ -202,11 +205,6 @@ static ssize_t dax_io(struct inode *inode, struct iov_iter *iter,
                                        rc = map_len;
                                        break;
                                }
-                               if (buffer_unwritten(bh) || buffer_new(bh)) {
-                                       dax_new_buf(dax.addr, map_len, first,
-                                                       pos, end);
-                                       need_wmb = true;
-                               }
                                dax.addr += first;
                                size = map_len - first;
                        }
@@ -267,15 +265,8 @@ ssize_t dax_do_io(struct kiocb *iocb, struct inode *inode,
        memset(&bh, 0, sizeof(bh));
        bh.b_bdev = inode->i_sb->s_bdev;
 
-       if ((flags & DIO_LOCKING) && iov_iter_rw(iter) == READ) {
-               struct address_space *mapping = inode->i_mapping;
+       if ((flags & DIO_LOCKING) && iov_iter_rw(iter) == READ)
                inode_lock(inode);
-               retval = filemap_write_and_wait_range(mapping, pos, end - 1);
-               if (retval) {
-                       inode_unlock(inode);
-                       goto out;
-               }
-       }
 
        /* Protects against truncate */
        if (!(flags & DIO_SKIP_DIO_COUNT))
@@ -296,11 +287,267 @@ ssize_t dax_do_io(struct kiocb *iocb, struct inode *inode,
 
        if (!(flags & DIO_SKIP_DIO_COUNT))
                inode_dio_end(inode);
- out:
        return retval;
 }
 EXPORT_SYMBOL_GPL(dax_do_io);
 
+/*
+ * DAX radix tree locking
+ */
+
+/* Identifies which radix tree entry a waiter is waiting for. */
+struct exceptional_entry_key {
+       struct address_space *mapping;
+       unsigned long index;
+};
+
+/* Wait queue entry embedding the key so wakeups can be filtered per entry. */
+struct wait_exceptional_entry_queue {
+       wait_queue_t wait;
+       struct exceptional_entry_key key;
+};
+
+/*
+ * Wake function for the hashed wait queues: several (mapping, index) pairs
+ * hash to the same queue, so only wake a waiter whose key matches the entry
+ * being unlocked.
+ */
+static int wake_exceptional_entry_func(wait_queue_t *wait, unsigned int mode,
+                                      int sync, void *keyp)
+{
+       struct exceptional_entry_key *key = keyp;
+       struct wait_exceptional_entry_queue *ewait =
+               container_of(wait, struct wait_exceptional_entry_queue, wait);
+
+       if (key->mapping != ewait->key.mapping ||
+           key->index != ewait->key.index)
+               return 0;
+       return autoremove_wake_function(wait, mode, sync, NULL);
+}
+
+/*
+ * Test the lock bit of the entry stored in @slot. The function must be
+ * called with mapping->tree_lock held.
+ */
+static inline int slot_locked(struct address_space *mapping, void **slot)
+{
+       void *entry = radix_tree_deref_slot_protected(slot,
+                                                     &mapping->tree_lock);
+
+       return (unsigned long)entry & RADIX_DAX_ENTRY_LOCK;
+}
+
+/*
+ * Set the lock bit in the entry stored in @slot and return the new entry
+ * value. The function must be called with mapping->tree_lock held.
+ */
+static inline void *lock_slot(struct address_space *mapping, void **slot)
+{
+       unsigned long locked = (unsigned long)
+               radix_tree_deref_slot_protected(slot, &mapping->tree_lock) |
+               RADIX_DAX_ENTRY_LOCK;
+
+       radix_tree_replace_slot(slot, (void *)locked);
+       return (void *)locked;
+}
+
+/*
+ * Clear the lock bit in the entry stored in @slot and return the new entry
+ * value. The function must be called with mapping->tree_lock held.
+ */
+static inline void *unlock_slot(struct address_space *mapping, void **slot)
+{
+       unsigned long unlocked = (unsigned long)
+               radix_tree_deref_slot_protected(slot, &mapping->tree_lock) &
+               ~(unsigned long)RADIX_DAX_ENTRY_LOCK;
+
+       radix_tree_replace_slot(slot, (void *)unlocked);
+       return (void *)unlocked;
+}
+
+/*
+ * Lookup entry in radix tree, wait for it to become unlocked if it is
+ * exceptional entry and return it. The caller must call
+ * put_unlocked_mapping_entry() when he decided not to lock the entry or
+ * put_locked_mapping_entry() when he locked the entry and now wants to
+ * unlock it.
+ *
+ * Returns NULL (no entry), a page pointer, or an unlocked exceptional
+ * entry; when @slotp is non-NULL it is filled with the slot the entry was
+ * found in.
+ *
+ * The function must be called with mapping->tree_lock held.
+ */
+static void *get_unlocked_mapping_entry(struct address_space *mapping,
+                                       pgoff_t index, void ***slotp)
+{
+       void *ret, **slot;
+       struct wait_exceptional_entry_queue ewait;
+       wait_queue_head_t *wq = dax_entry_waitqueue(mapping, index);
+
+       init_wait(&ewait.wait);
+       ewait.wait.func = wake_exceptional_entry_func;
+       ewait.key.mapping = mapping;
+       ewait.key.index = index;
+
+       for (;;) {
+               ret = __radix_tree_lookup(&mapping->page_tree, index, NULL,
+                                         &slot);
+               /* Done if there's no entry, it's a page, or it's unlocked */
+               if (!ret || !radix_tree_exceptional_entry(ret) ||
+                   !slot_locked(mapping, slot)) {
+                       if (slotp)
+                               *slotp = slot;
+                       return ret;
+               }
+               /*
+                * Exclusive wait pairs with the single wakeup in
+                * unlock_mapping_entry(); drop tree_lock while sleeping and
+                * redo the lookup from scratch afterwards.
+                */
+               prepare_to_wait_exclusive(wq, &ewait.wait,
+                                         TASK_UNINTERRUPTIBLE);
+               spin_unlock_irq(&mapping->tree_lock);
+               schedule();
+               finish_wait(wq, &ewait.wait);
+               spin_lock_irq(&mapping->tree_lock);
+       }
+}
+
+/*
+ * Find radix tree entry at given index. If it points to a page, return with
+ * the page locked. If it points to the exceptional entry, return with the
+ * radix tree entry locked. If the radix tree doesn't contain given index,
+ * create empty exceptional entry for the index and return with it locked.
+ *
+ * Returns an ERR_PTR() on allocation or insertion failure.
+ *
+ * Note: Unlike filemap_fault() we don't honor FAULT_FLAG_RETRY flags. For
+ * persistent memory the benefit is doubtful. We can add that later if we can
+ * show it helps.
+ */
+static void *grab_mapping_entry(struct address_space *mapping, pgoff_t index)
+{
+       void *ret, **slot;
+
+restart:
+       spin_lock_irq(&mapping->tree_lock);
+       ret = get_unlocked_mapping_entry(mapping, index, &slot);
+       /* No entry for given index? Make sure radix tree is big enough. */
+       if (!ret) {
+               int err;
+
+               /* radix_tree_preload() may sleep, so drop the lock first */
+               spin_unlock_irq(&mapping->tree_lock);
+               err = radix_tree_preload(
+                               mapping_gfp_mask(mapping) & ~__GFP_HIGHMEM);
+               if (err)
+                       return ERR_PTR(err);
+               ret = (void *)(RADIX_TREE_EXCEPTIONAL_ENTRY |
+                              RADIX_DAX_ENTRY_LOCK);
+               spin_lock_irq(&mapping->tree_lock);
+               err = radix_tree_insert(&mapping->page_tree, index, ret);
+               radix_tree_preload_end();
+               if (err) {
+                       spin_unlock_irq(&mapping->tree_lock);
+                       /* Someone already created the entry? Redo lookup. */
+                       if (err == -EEXIST)
+                               goto restart;
+                       return ERR_PTR(err);
+               }
+               /* Good, we have inserted empty locked entry into the tree. */
+               mapping->nrexceptional++;
+               spin_unlock_irq(&mapping->tree_lock);
+               return ret;
+       }
+       /* Normal page in radix tree? */
+       if (!radix_tree_exceptional_entry(ret)) {
+               struct page *page = ret;
+
+               /* Pin the page across the unlock before locking it */
+               get_page(page);
+               spin_unlock_irq(&mapping->tree_lock);
+               lock_page(page);
+               /* Page got truncated? Retry... */
+               if (unlikely(page->mapping != mapping)) {
+                       unlock_page(page);
+                       put_page(page);
+                       goto restart;
+               }
+               return page;
+       }
+       ret = lock_slot(mapping, slot);
+       spin_unlock_irq(&mapping->tree_lock);
+       return ret;
+}
+
+/*
+ * Wake tasks waiting in get_unlocked_mapping_entry() for the entry at
+ * @index. Pass @wake_all = true when the entry is being removed (all
+ * waiters must redo their lookup), false when merely unlocking it (a
+ * single exclusive waiter takes the lock over).
+ */
+void dax_wake_mapping_entry_waiter(struct address_space *mapping,
+                                  pgoff_t index, bool wake_all)
+{
+       wait_queue_head_t *wq = dax_entry_waitqueue(mapping, index);
+
+       /*
+        * Checking for locked entry and prepare_to_wait_exclusive() happens
+        * under mapping->tree_lock, ditto for entry handling in our callers.
+        * So at this point all tasks that could have seen our entry locked
+        * must be in the waitqueue and the following check will see them.
+        */
+       if (waitqueue_active(wq)) {
+               struct exceptional_entry_key key;
+
+               key.mapping = mapping;
+               key.index = index;
+               __wake_up(wq, TASK_NORMAL, wake_all ? 0 : 1, &key);
+       }
+}
+
+/*
+ * Clear the lock bit of the exceptional entry at @index and wake up one
+ * waiter to take the lock over.
+ */
+static void unlock_mapping_entry(struct address_space *mapping, pgoff_t index)
+{
+       void *ret, **slot;
+
+       spin_lock_irq(&mapping->tree_lock);
+       ret = __radix_tree_lookup(&mapping->page_tree, index, NULL, &slot);
+       /* The entry must exist, be exceptional, and be locked by us */
+       if (WARN_ON_ONCE(!ret || !radix_tree_exceptional_entry(ret) ||
+                        !slot_locked(mapping, slot))) {
+               spin_unlock_irq(&mapping->tree_lock);
+               return;
+       }
+       unlock_slot(mapping, slot);
+       spin_unlock_irq(&mapping->tree_lock);
+       dax_wake_mapping_entry_waiter(mapping, index, false);
+}
+
+/*
+ * Release whatever grab_mapping_entry() returned locked: either a locked
+ * page (unlock and drop the reference) or a locked radix tree entry.
+ */
+static void put_locked_mapping_entry(struct address_space *mapping,
+                                    pgoff_t index, void *entry)
+{
+       if (radix_tree_exceptional_entry(entry)) {
+               unlock_mapping_entry(mapping, index);
+       } else {
+               struct page *page = entry;
+
+               unlock_page(page);
+               put_page(page);
+       }
+}
+
+/*
+ * Called when we are done with radix tree entry we looked up via
+ * get_unlocked_mapping_entry() and which we didn't lock in the end.
+ */
+static void put_unlocked_mapping_entry(struct address_space *mapping,
+                                      pgoff_t index, void *entry)
+{
+       /*
+        * Pages need no action here; for exceptional entries we must wake
+        * the next waiter for the radix tree entry lock since our exclusive
+        * wait consumed a wakeup we are not going to use.
+        */
+       if (radix_tree_exceptional_entry(entry))
+               dax_wake_mapping_entry_waiter(mapping, index, false);
+}
+
+/*
+ * Delete exceptional DAX entry at @index from @mapping. Wait for radix tree
+ * entry to get unlocked before deleting it.
+ *
+ * Returns 1 when an entry was deleted, 0 when no exceptional entry was
+ * found at @index.
+ */
+int dax_delete_mapping_entry(struct address_space *mapping, pgoff_t index)
+{
+       void *entry;
+
+       spin_lock_irq(&mapping->tree_lock);
+       entry = get_unlocked_mapping_entry(mapping, index, NULL);
+       /*
+        * This gets called from truncate / punch_hole path. As such, the caller
+        * must hold locks protecting against concurrent modifications of the
+        * radix tree (usually fs-private i_mmap_sem for writing). Since the
+        * caller has seen exceptional entry for this index, we better find it
+        * at that index as well...
+        */
+       if (WARN_ON_ONCE(!entry || !radix_tree_exceptional_entry(entry))) {
+               spin_unlock_irq(&mapping->tree_lock);
+               return 0;
+       }
+       radix_tree_delete(&mapping->page_tree, index);
+       mapping->nrexceptional--;
+       spin_unlock_irq(&mapping->tree_lock);
+       /* The entry is gone: wake all waiters so they redo their lookup */
+       dax_wake_mapping_entry_waiter(mapping, index, true);
+
+       return 1;
+}
+
 /*
  * The user has performed a load from a hole in the file.  Allocating
  * a new page in the file would cause excessive storage usage for
@@ -309,24 +556,24 @@ EXPORT_SYMBOL_GPL(dax_do_io);
  * otherwise it will simply fall out of the page cache under memory
  * pressure without ever having been dirtied.
  */
-static int dax_load_hole(struct address_space *mapping, struct page *page,
-                                                       struct vm_fault *vmf)
+static int dax_load_hole(struct address_space *mapping, void *entry,
+                        struct vm_fault *vmf)
 {
-       unsigned long size;
-       struct inode *inode = mapping->host;
-       if (!page)
-               page = find_or_create_page(mapping, vmf->pgoff,
-                                               GFP_KERNEL | __GFP_ZERO);
-       if (!page)
-               return VM_FAULT_OOM;
-       /* Recheck i_size under page lock to avoid truncate race */
-       size = (i_size_read(inode) + PAGE_SIZE - 1) >> PAGE_SHIFT;
-       if (vmf->pgoff >= size) {
-               unlock_page(page);
-               put_page(page);
-               return VM_FAULT_SIGBUS;
+       struct page *page;
+
+       /* Hole page already exists? Return it...  */
+       if (!radix_tree_exceptional_entry(entry)) {
+               vmf->page = entry;
+               return VM_FAULT_LOCKED;
        }
 
+       /* This will replace locked radix tree entry with a hole page */
+       page = find_or_create_page(mapping, vmf->pgoff,
+                                  vmf->gfp_mask | __GFP_ZERO);
+       if (!page) {
+               put_locked_mapping_entry(mapping, vmf->pgoff, entry);
+               return VM_FAULT_OOM;
+       }
        vmf->page = page;
        return VM_FAULT_LOCKED;
 }
@@ -350,77 +597,72 @@ static int copy_user_bh(struct page *to, struct inode *inode,
        return 0;
 }
 
-#define NO_SECTOR -1
 #define DAX_PMD_INDEX(page_index) (page_index & (PMD_MASK >> PAGE_SHIFT))
 
-static int dax_radix_entry(struct address_space *mapping, pgoff_t index,
-               sector_t sector, bool pmd_entry, bool dirty)
+static void *dax_insert_mapping_entry(struct address_space *mapping,
+                                     struct vm_fault *vmf,
+                                     void *entry, sector_t sector)
 {
        struct radix_tree_root *page_tree = &mapping->page_tree;
-       pgoff_t pmd_index = DAX_PMD_INDEX(index);
-       int type, error = 0;
-       void *entry;
+       int error = 0;
+       bool hole_fill = false;
+       void *new_entry;
+       pgoff_t index = vmf->pgoff;
 
-       WARN_ON_ONCE(pmd_entry && !dirty);
-       if (dirty)
+       if (vmf->flags & FAULT_FLAG_WRITE)
                __mark_inode_dirty(mapping->host, I_DIRTY_PAGES);
 
-       spin_lock_irq(&mapping->tree_lock);
-
-       entry = radix_tree_lookup(page_tree, pmd_index);
-       if (entry && RADIX_DAX_TYPE(entry) == RADIX_DAX_PMD) {
-               index = pmd_index;
-               goto dirty;
+       /* Replacing hole page with block mapping? */
+       if (!radix_tree_exceptional_entry(entry)) {
+               hole_fill = true;
+               /*
+                * Unmap the page now before we remove it from page cache below.
+                * The page is locked so it cannot be faulted in again.
+                */
+               unmap_mapping_range(mapping, vmf->pgoff << PAGE_SHIFT,
+                                   PAGE_SIZE, 0);
+               error = radix_tree_preload(vmf->gfp_mask & ~__GFP_HIGHMEM);
+               if (error)
+                       return ERR_PTR(error);
        }
 
-       entry = radix_tree_lookup(page_tree, index);
-       if (entry) {
-               type = RADIX_DAX_TYPE(entry);
-               if (WARN_ON_ONCE(type != RADIX_DAX_PTE &&
-                                       type != RADIX_DAX_PMD)) {
-                       error = -EIO;
+       spin_lock_irq(&mapping->tree_lock);
+       new_entry = (void *)((unsigned long)RADIX_DAX_ENTRY(sector, false) |
+                      RADIX_DAX_ENTRY_LOCK);
+       if (hole_fill) {
+               __delete_from_page_cache(entry, NULL);
+               /* Drop pagecache reference */
+               put_page(entry);
+               error = radix_tree_insert(page_tree, index, new_entry);
+               if (error) {
+                       new_entry = ERR_PTR(error);
                        goto unlock;
                }
+               mapping->nrexceptional++;
+       } else {
+               void **slot;
+               void *ret;
 
-               if (!pmd_entry || type == RADIX_DAX_PMD)
-                       goto dirty;
-
-               /*
-                * We only insert dirty PMD entries into the radix tree.  This
-                * means we don't need to worry about removing a dirty PTE
-                * entry and inserting a clean PMD entry, thus reducing the
-                * range we would flush with a follow-up fsync/msync call.
-                */
-               radix_tree_delete(&mapping->page_tree, index);
-               mapping->nrexceptional--;
-       }
-
-       if (sector == NO_SECTOR) {
-               /*
-                * This can happen during correct operation if our pfn_mkwrite
-                * fault raced against a hole punch operation.  If this
-                * happens the pte that was hole punched will have been
-                * unmapped and the radix tree entry will have been removed by
-                * the time we are called, but the call will still happen.  We
-                * will return all the way up to wp_pfn_shared(), where the
-                * pte_same() check will fail, eventually causing page fault
-                * to be retried by the CPU.
-                */
-               goto unlock;
+               ret = __radix_tree_lookup(page_tree, index, NULL, &slot);
+               WARN_ON_ONCE(ret != entry);
+               radix_tree_replace_slot(slot, new_entry);
        }
-
-       error = radix_tree_insert(page_tree, index,
-                       RADIX_DAX_ENTRY(sector, pmd_entry));
-       if (error)
-               goto unlock;
-
-       mapping->nrexceptional++;
- dirty:
-       if (dirty)
+       if (vmf->flags & FAULT_FLAG_WRITE)
                radix_tree_tag_set(page_tree, index, PAGECACHE_TAG_DIRTY);
  unlock:
        spin_unlock_irq(&mapping->tree_lock);
-       return error;
+       if (hole_fill) {
+               radix_tree_preload_end();
+               /*
+                * We don't need hole page anymore, it has been replaced with
+                * locked radix tree entry now.
+                */
+               if (mapping->a_ops->freepage)
+                       mapping->a_ops->freepage(entry);
+               unlock_page(entry);
+               put_page(entry);
+       }
+       return new_entry;
 }
 
 static int dax_writeback_one(struct block_device *bdev,
@@ -546,55 +788,38 @@ int dax_writeback_mapping_range(struct address_space *mapping,
 }
 EXPORT_SYMBOL_GPL(dax_writeback_mapping_range);
 
-static int dax_insert_mapping(struct inode *inode, struct buffer_head *bh,
+static int dax_insert_mapping(struct address_space *mapping,
+                       struct buffer_head *bh, void **entryp,
                        struct vm_area_struct *vma, struct vm_fault *vmf)
 {
        unsigned long vaddr = (unsigned long)vmf->virtual_address;
-       struct address_space *mapping = inode->i_mapping;
        struct block_device *bdev = bh->b_bdev;
        struct blk_dax_ctl dax = {
-               .sector = to_sector(bh, inode),
+               .sector = to_sector(bh, mapping->host),
                .size = bh->b_size,
        };
-       pgoff_t size;
        int error;
+       void *ret;
+       void *entry = *entryp;
 
        i_mmap_lock_read(mapping);
 
-       /*
-        * Check truncate didn't happen while we were allocating a block.
-        * If it did, this block may or may not be still allocated to the
-        * file.  We can't tell the filesystem to free it because we can't
-        * take i_mutex here.  In the worst case, the file still has blocks
-        * allocated past the end of the file.
-        */
-       size = (i_size_read(inode) + PAGE_SIZE - 1) >> PAGE_SHIFT;
-       if (unlikely(vmf->pgoff >= size)) {
-               error = -EIO;
-               goto out;
-       }
-
        if (dax_map_atomic(bdev, &dax) < 0) {
                error = PTR_ERR(dax.addr);
                goto out;
        }
-
-       if (buffer_unwritten(bh) || buffer_new(bh)) {
-               clear_pmem(dax.addr, PAGE_SIZE);
-               wmb_pmem();
-       }
        dax_unmap_atomic(bdev, &dax);
 
-       error = dax_radix_entry(mapping, vmf->pgoff, dax.sector, false,
-                       vmf->flags & FAULT_FLAG_WRITE);
-       if (error)
+       ret = dax_insert_mapping_entry(mapping, vmf, entry, dax.sector);
+       if (IS_ERR(ret)) {
+               error = PTR_ERR(ret);
                goto out;
+       }
+       *entryp = ret;
 
        error = vm_insert_mixed(vma, vaddr, dax.pfn);
-
  out:
        i_mmap_unlock_read(mapping);
-
        return error;
 }
 
@@ -603,24 +828,18 @@ static int dax_insert_mapping(struct inode *inode, struct buffer_head *bh,
  * @vma: The virtual memory area where the fault occurred
  * @vmf: The description of the fault
  * @get_block: The filesystem method used to translate file offsets to blocks
- * @complete_unwritten: The filesystem method used to convert unwritten blocks
- *     to written so the data written to them is exposed. This is required for
- *     required by write faults for filesystems that will return unwritten
- *     extent mappings from @get_block, but it is optional for reads as
- *     dax_insert_mapping() will always zero unwritten blocks. If the fs does
- *     not support unwritten extents, the it should pass NULL.
  *
  * When a page fault occurs, filesystems may call this helper in their
  * fault handler for DAX files. __dax_fault() assumes the caller has done all
  * the necessary locking for the page fault to proceed successfully.
  */
 int __dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
-                       get_block_t get_block, dax_iodone_t complete_unwritten)
+                       get_block_t get_block)
 {
        struct file *file = vma->vm_file;
        struct address_space *mapping = file->f_mapping;
        struct inode *inode = mapping->host;
-       struct page *page;
+       void *entry;
        struct buffer_head bh;
        unsigned long vaddr = (unsigned long)vmf->virtual_address;
        unsigned blkbits = inode->i_blkbits;
@@ -629,6 +848,11 @@ int __dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
        int error;
        int major = 0;
 
+       /*
+        * Check whether offset isn't beyond end of file now. Caller is supposed
+        * to hold locks serializing us with truncate / punch hole so this is
+        * a reliable test.
+        */
        size = (i_size_read(inode) + PAGE_SIZE - 1) >> PAGE_SHIFT;
        if (vmf->pgoff >= size)
                return VM_FAULT_SIGBUS;
@@ -638,49 +862,17 @@ int __dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
        bh.b_bdev = inode->i_sb->s_bdev;
        bh.b_size = PAGE_SIZE;
 
- repeat:
-       page = find_get_page(mapping, vmf->pgoff);
-       if (page) {
-               if (!lock_page_or_retry(page, vma->vm_mm, vmf->flags)) {
-                       put_page(page);
-                       return VM_FAULT_RETRY;
-               }
-               if (unlikely(page->mapping != mapping)) {
-                       unlock_page(page);
-                       put_page(page);
-                       goto repeat;
-               }
-               size = (i_size_read(inode) + PAGE_SIZE - 1) >> PAGE_SHIFT;
-               if (unlikely(vmf->pgoff >= size)) {
-                       /*
-                        * We have a struct page covering a hole in the file
-                        * from a read fault and we've raced with a truncate
-                        */
-                       error = -EIO;
-                       goto unlock_page;
-               }
+       entry = grab_mapping_entry(mapping, vmf->pgoff);
+       if (IS_ERR(entry)) {
+               error = PTR_ERR(entry);
+               goto out;
        }
 
        error = get_block(inode, block, &bh, 0);
        if (!error && (bh.b_size < PAGE_SIZE))
                error = -EIO;           /* fs corruption? */
        if (error)
-               goto unlock_page;
-
-       if (!buffer_mapped(&bh) && !buffer_unwritten(&bh) && !vmf->cow_page) {
-               if (vmf->flags & FAULT_FLAG_WRITE) {
-                       error = get_block(inode, block, &bh, 1);
-                       count_vm_event(PGMAJFAULT);
-                       mem_cgroup_count_vm_event(vma->vm_mm, PGMAJFAULT);
-                       major = VM_FAULT_MAJOR;
-                       if (!error && (bh.b_size < PAGE_SIZE))
-                               error = -EIO;
-                       if (error)
-                               goto unlock_page;
-               } else {
-                       return dax_load_hole(mapping, page, vmf);
-               }
-       }
+               goto unlock_entry;
 
        if (vmf->cow_page) {
                struct page *new_page = vmf->cow_page;
@@ -689,53 +881,37 @@ int __dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
                else
                        clear_user_highpage(new_page, vaddr);
                if (error)
-                       goto unlock_page;
-               vmf->page = page;
-               if (!page) {
+                       goto unlock_entry;
+               if (!radix_tree_exceptional_entry(entry)) {
+                       vmf->page = entry;
+               } else {
+                       unlock_mapping_entry(mapping, vmf->pgoff);
                        i_mmap_lock_read(mapping);
-                       /* Check we didn't race with truncate */
-                       size = (i_size_read(inode) + PAGE_SIZE - 1) >>
-                                                               PAGE_SHIFT;
-                       if (vmf->pgoff >= size) {
-                               i_mmap_unlock_read(mapping);
-                               error = -EIO;
-                               goto out;
-                       }
+                       vmf->page = NULL;
                }
                return VM_FAULT_LOCKED;
        }
 
-       /* Check we didn't race with a read fault installing a new page */
-       if (!page && major)
-               page = find_lock_page(mapping, vmf->pgoff);
-
-       if (page) {
-               unmap_mapping_range(mapping, vmf->pgoff << PAGE_SHIFT,
-                                                       PAGE_SIZE, 0);
-               delete_from_page_cache(page);
-               unlock_page(page);
-               put_page(page);
-               page = NULL;
-       }
-
-       /*
-        * If we successfully insert the new mapping over an unwritten extent,
-        * we need to ensure we convert the unwritten extent. If there is an
-        * error inserting the mapping, the filesystem needs to leave it as
-        * unwritten to prevent exposure of the stale underlying data to
-        * userspace, but we still need to call the completion function so
-        * the private resources on the mapping buffer can be released. We
-        * indicate what the callback should do via the uptodate variable, same
-        * as for normal BH based IO completions.
-        */
-       error = dax_insert_mapping(inode, &bh, vma, vmf);
-       if (buffer_unwritten(&bh)) {
-               if (complete_unwritten)
-                       complete_unwritten(&bh, !error);
-               else
-                       WARN_ON_ONCE(!(vmf->flags & FAULT_FLAG_WRITE));
+       if (!buffer_mapped(&bh)) {
+               if (vmf->flags & FAULT_FLAG_WRITE) {
+                       error = get_block(inode, block, &bh, 1);
+                       count_vm_event(PGMAJFAULT);
+                       mem_cgroup_count_vm_event(vma->vm_mm, PGMAJFAULT);
+                       major = VM_FAULT_MAJOR;
+                       if (!error && (bh.b_size < PAGE_SIZE))
+                               error = -EIO;
+                       if (error)
+                               goto unlock_entry;
+               } else {
+                       return dax_load_hole(mapping, entry, vmf);
+               }
        }
 
+       /* Filesystem should not return unwritten buffers to us! */
+       WARN_ON_ONCE(buffer_unwritten(&bh) || buffer_new(&bh));
+       error = dax_insert_mapping(mapping, &bh, &entry, vma, vmf);
+ unlock_entry:
+       put_locked_mapping_entry(mapping, vmf->pgoff, entry);
  out:
        if (error == -ENOMEM)
                return VM_FAULT_OOM | major;
@@ -743,13 +919,6 @@ int __dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
        if ((error < 0) && (error != -EBUSY))
                return VM_FAULT_SIGBUS | major;
        return VM_FAULT_NOPAGE | major;
-
- unlock_page:
-       if (page) {
-               unlock_page(page);
-               put_page(page);
-       }
-       goto out;
 }
 EXPORT_SYMBOL(__dax_fault);
 
@@ -763,7 +932,7 @@ EXPORT_SYMBOL(__dax_fault);
  * fault handler for DAX files.
  */
 int dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
-             get_block_t get_block, dax_iodone_t complete_unwritten)
+             get_block_t get_block)
 {
        int result;
        struct super_block *sb = file_inode(vma->vm_file)->i_sb;
@@ -772,7 +941,7 @@ int dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
                sb_start_pagefault(sb);
                file_update_time(vma->vm_file);
        }
-       result = __dax_fault(vma, vmf, get_block, complete_unwritten);
+       result = __dax_fault(vma, vmf, get_block);
        if (vmf->flags & FAULT_FLAG_WRITE)
                sb_end_pagefault(sb);
 
@@ -780,7 +949,7 @@ int dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
 }
 EXPORT_SYMBOL_GPL(dax_fault);
 
-#ifdef CONFIG_TRANSPARENT_HUGEPAGE
+#if defined(CONFIG_TRANSPARENT_HUGEPAGE)
 /*
  * The 'colour' (ie low bits) within a PMD of a page offset.  This comes up
  * more often than one might expect in the below function.
@@ -806,8 +975,7 @@ static void __dax_dbg(struct buffer_head *bh, unsigned long address,
 #define dax_pmd_dbg(bh, address, reason)       __dax_dbg(bh, address, reason, "dax_pmd")
 
 int __dax_pmd_fault(struct vm_area_struct *vma, unsigned long address,
-               pmd_t *pmd, unsigned int flags, get_block_t get_block,
-               dax_iodone_t complete_unwritten)
+               pmd_t *pmd, unsigned int flags, get_block_t get_block)
 {
        struct file *file = vma->vm_file;
        struct address_space *mapping = file->f_mapping;
@@ -819,7 +987,7 @@ int __dax_pmd_fault(struct vm_area_struct *vma, unsigned long address,
        struct block_device *bdev;
        pgoff_t size, pgoff;
        sector_t block;
-       int error, result = 0;
+       int result = 0;
        bool alloc = false;
 
        /* dax pmd mappings require pfn_t_devmap() */
@@ -866,6 +1034,7 @@ int __dax_pmd_fault(struct vm_area_struct *vma, unsigned long address,
                if (get_block(inode, block, &bh, 1) != 0)
                        return VM_FAULT_SIGBUS;
                alloc = true;
+               WARN_ON_ONCE(buffer_unwritten(&bh) || buffer_new(&bh));
        }
 
        bdev = bh.b_bdev;
@@ -893,24 +1062,7 @@ int __dax_pmd_fault(struct vm_area_struct *vma, unsigned long address,
 
        i_mmap_lock_read(mapping);
 
-       /*
-        * If a truncate happened while we were allocating blocks, we may
-        * leave blocks allocated to the file that are beyond EOF.  We can't
-        * take i_mutex here, so just leave them hanging; they'll be freed
-        * when the file is deleted.
-        */
-       size = (i_size_read(inode) + PAGE_SIZE - 1) >> PAGE_SHIFT;
-       if (pgoff >= size) {
-               result = VM_FAULT_SIGBUS;
-               goto out;
-       }
-       if ((pgoff | PG_PMD_COLOUR) >= size) {
-               dax_pmd_dbg(&bh, address,
-                               "offset + huge page size > file size");
-               goto fallback;
-       }
-
-       if (!write && !buffer_mapped(&bh) && buffer_uptodate(&bh)) {
+       if (!write && !buffer_mapped(&bh)) {
                spinlock_t *ptl;
                pmd_t entry;
                struct page *zero_page = get_huge_zero_page();
@@ -945,8 +1097,8 @@ int __dax_pmd_fault(struct vm_area_struct *vma, unsigned long address,
                long length = dax_map_atomic(bdev, &dax);
 
                if (length < 0) {
-                       result = VM_FAULT_SIGBUS;
-                       goto out;
+                       dax_pmd_dbg(&bh, address, "dax-error fallback");
+                       goto fallback;
                }
                if (length < PMD_SIZE) {
                        dax_pmd_dbg(&bh, address, "dax-length too small");
@@ -964,14 +1116,6 @@ int __dax_pmd_fault(struct vm_area_struct *vma, unsigned long address,
                        dax_pmd_dbg(&bh, address, "pfn not in memmap");
                        goto fallback;
                }
-
-               if (buffer_unwritten(&bh) || buffer_new(&bh)) {
-                       clear_pmem(dax.addr, PMD_SIZE);
-                       wmb_pmem();
-                       count_vm_event(PGMAJFAULT);
-                       mem_cgroup_count_vm_event(vma->vm_mm, PGMAJFAULT);
-                       result |= VM_FAULT_MAJOR;
-               }
                dax_unmap_atomic(bdev, &dax);
 
                /*
@@ -990,13 +1134,10 @@ int __dax_pmd_fault(struct vm_area_struct *vma, unsigned long address,
                 * the write to insert a dirty entry.
                 */
                if (write) {
-                       error = dax_radix_entry(mapping, pgoff, dax.sector,
-                                       true, true);
-                       if (error) {
-                               dax_pmd_dbg(&bh, address,
-                                               "PMD radix insertion failed");
-                               goto fallback;
-                       }
+                       /*
+                        * We should insert radix-tree entry and dirty it here.
+                        * For now this is broken...
+                        */
                }
 
                dev_dbg(part_to_dev(bdev->bd_part),
@@ -1011,9 +1152,6 @@ int __dax_pmd_fault(struct vm_area_struct *vma, unsigned long address,
  out:
        i_mmap_unlock_read(mapping);
 
-       if (buffer_unwritten(&bh))
-               complete_unwritten(&bh, !(result & VM_FAULT_ERROR));
-
        return result;
 
  fallback:
@@ -1033,8 +1171,7 @@ EXPORT_SYMBOL_GPL(__dax_pmd_fault);
  * pmd_fault handler for DAX files.
  */
 int dax_pmd_fault(struct vm_area_struct *vma, unsigned long address,
-                       pmd_t *pmd, unsigned int flags, get_block_t get_block,
-                       dax_iodone_t complete_unwritten)
+                       pmd_t *pmd, unsigned int flags, get_block_t get_block)
 {
        int result;
        struct super_block *sb = file_inode(vma->vm_file)->i_sb;
@@ -1043,8 +1180,7 @@ int dax_pmd_fault(struct vm_area_struct *vma, unsigned long address,
                sb_start_pagefault(sb);
                file_update_time(vma->vm_file);
        }
-       result = __dax_pmd_fault(vma, address, pmd, flags, get_block,
-                               complete_unwritten);
+       result = __dax_pmd_fault(vma, address, pmd, flags, get_block);
        if (flags & FAULT_FLAG_WRITE)
                sb_end_pagefault(sb);
 
@@ -1061,27 +1197,59 @@ EXPORT_SYMBOL_GPL(dax_pmd_fault);
 int dax_pfn_mkwrite(struct vm_area_struct *vma, struct vm_fault *vmf)
 {
        struct file *file = vma->vm_file;
-       int error;
-
-       /*
-        * We pass NO_SECTOR to dax_radix_entry() because we expect that a
-        * RADIX_DAX_PTE entry already exists in the radix tree from a
-        * previous call to __dax_fault().  We just want to look up that PTE
-        * entry using vmf->pgoff and make sure the dirty tag is set.  This
-        * saves us from having to make a call to get_block() here to look
-        * up the sector.
-        */
-       error = dax_radix_entry(file->f_mapping, vmf->pgoff, NO_SECTOR, false,
-                       true);
+       struct address_space *mapping = file->f_mapping;
+       void *entry;
+       pgoff_t index = vmf->pgoff;
 
-       if (error == -ENOMEM)
-               return VM_FAULT_OOM;
-       if (error)
-               return VM_FAULT_SIGBUS;
+       spin_lock_irq(&mapping->tree_lock);
+       entry = get_unlocked_mapping_entry(mapping, index, NULL);
+       if (!entry || !radix_tree_exceptional_entry(entry))
+               goto out;
+       radix_tree_tag_set(&mapping->page_tree, index, PAGECACHE_TAG_DIRTY);
+       put_unlocked_mapping_entry(mapping, index, entry);
+out:
+       spin_unlock_irq(&mapping->tree_lock);
        return VM_FAULT_NOPAGE;
 }
 EXPORT_SYMBOL_GPL(dax_pfn_mkwrite);
 
+static bool dax_range_is_aligned(struct block_device *bdev,
+                                unsigned int offset, unsigned int length)
+{
+       unsigned short sector_size = bdev_logical_block_size(bdev);
+
+       if (!IS_ALIGNED(offset, sector_size))
+               return false;
+       if (!IS_ALIGNED(length, sector_size))
+               return false;
+
+       return true;
+}
+
+int __dax_zero_page_range(struct block_device *bdev, sector_t sector,
+               unsigned int offset, unsigned int length)
+{
+       struct blk_dax_ctl dax = {
+               .sector         = sector,
+               .size           = PAGE_SIZE,
+       };
+
+       if (dax_range_is_aligned(bdev, offset, length)) {
+               sector_t start_sector = dax.sector + (offset >> 9);
+
+               return blkdev_issue_zeroout(bdev, start_sector,
+                               length >> 9, GFP_NOFS, true);
+       } else {
+               if (dax_map_atomic(bdev, &dax) < 0)
+                       return PTR_ERR(dax.addr);
+               clear_pmem(dax.addr + offset, length);
+               wmb_pmem();
+               dax_unmap_atomic(bdev, &dax);
+       }
+       return 0;
+}
+EXPORT_SYMBOL_GPL(__dax_zero_page_range);
+
 /**
  * dax_zero_page_range - zero a range within a page of a DAX file
  * @inode: The file being truncated
@@ -1093,12 +1261,6 @@ EXPORT_SYMBOL_GPL(dax_pfn_mkwrite);
  * page in a DAX file.  This is intended for hole-punch operations.  If
  * you are truncating a file, the helper function dax_truncate_page() may be
  * more convenient.
- *
- * We work in terms of PAGE_SIZE here for commonality with
- * block_truncate_page(), but we could go down to PAGE_SIZE if the filesystem
- * took care of disposing of the unnecessary blocks.  Even if the filesystem
- * block size is smaller than PAGE_SIZE, we have to zero the rest of the page
- * since the file might be mmapped.
  */
 int dax_zero_page_range(struct inode *inode, loff_t from, unsigned length,
                                                        get_block_t get_block)
@@ -1117,23 +1279,11 @@ int dax_zero_page_range(struct inode *inode, loff_t from, unsigned length,
        bh.b_bdev = inode->i_sb->s_bdev;
        bh.b_size = PAGE_SIZE;
        err = get_block(inode, index, &bh, 0);
-       if (err < 0)
+       if (err < 0 || !buffer_written(&bh))
                return err;
-       if (buffer_written(&bh)) {
-               struct block_device *bdev = bh.b_bdev;
-               struct blk_dax_ctl dax = {
-                       .sector = to_sector(&bh, inode),
-                       .size = PAGE_SIZE,
-               };
 
-               if (dax_map_atomic(bdev, &dax) < 0)
-                       return PTR_ERR(dax.addr);
-               clear_pmem(dax.addr + offset, length);
-               wmb_pmem();
-               dax_unmap_atomic(bdev, &dax);
-       }
-
-       return 0;
+       return __dax_zero_page_range(bh.b_bdev, to_sector(&bh, inode),
+                       offset, length);
 }
 EXPORT_SYMBOL_GPL(dax_zero_page_range);
 
@@ -1145,12 +1295,6 @@ EXPORT_SYMBOL_GPL(dax_zero_page_range);
  *
  * Similar to block_truncate_page(), this function can be called by a
  * filesystem when it is truncating a DAX file to handle the partial page.
- *
- * We work in terms of PAGE_SIZE here for commonality with
- * block_truncate_page(), but we could go down to PAGE_SIZE if the filesystem
- * took care of disposing of the unnecessary blocks.  Even if the filesystem
- * block size is smaller than PAGE_SIZE, we have to zero the rest of the page
- * since the file might be mmapped.
  */
 int dax_truncate_page(struct inode *inode, loff_t from, get_block_t get_block)
 {