userfaultfd: non-cooperative: flush event_wqh at release time
[sfrench/cifs-2.6.git] / fs / userfaultfd.c
index 6148ccd6cccf28f6c21be96bab078b65fcaafe9f..06ea26b8c996f3cc7a9d6fd177260f89394fb325 100644 (file)
@@ -214,6 +214,7 @@ static inline struct uffd_msg userfault_msg(unsigned long address,
  * hugepmd ranges.
  */
 static inline bool userfaultfd_huge_must_wait(struct userfaultfd_ctx *ctx,
+                                        struct vm_area_struct *vma,
                                         unsigned long address,
                                         unsigned long flags,
                                         unsigned long reason)
@@ -224,7 +225,7 @@ static inline bool userfaultfd_huge_must_wait(struct userfaultfd_ctx *ctx,
 
        VM_BUG_ON(!rwsem_is_locked(&mm->mmap_sem));
 
-       pte = huge_pte_offset(mm, address);
+       pte = huge_pte_offset(mm, address, vma_mmu_pagesize(vma));
        if (!pte)
                goto out;
 
@@ -243,6 +244,7 @@ out:
 }
 #else
 static inline bool userfaultfd_huge_must_wait(struct userfaultfd_ctx *ctx,
+                                        struct vm_area_struct *vma,
                                         unsigned long address,
                                         unsigned long flags,
                                         unsigned long reason)
@@ -448,7 +450,8 @@ int handle_userfault(struct vm_fault *vmf, unsigned long reason)
                must_wait = userfaultfd_must_wait(ctx, vmf->address, vmf->flags,
                                                  reason);
        else
-               must_wait = userfaultfd_huge_must_wait(ctx, vmf->address,
+               must_wait = userfaultfd_huge_must_wait(ctx, vmf->vma,
+                                                      vmf->address,
                                                       vmf->flags, reason);
        up_read(&mm->mmap_sem);
 
@@ -851,6 +854,9 @@ wakeup:
        __wake_up_locked_key(&ctx->fault_wqh, TASK_NORMAL, &range);
        spin_unlock(&ctx->fault_pending_wqh.lock);
 
+       /* Flush pending events that may still wait on event_wqh */
+       wake_up_all(&ctx->event_wqh);
+
        wake_up_poll(&ctx->fd_wqh, POLLHUP);
        userfaultfd_ctx_put(ctx);
        return 0;
@@ -1114,11 +1120,6 @@ static ssize_t userfaultfd_read(struct file *file, char __user *buf,
 static void __wake_userfault(struct userfaultfd_ctx *ctx,
                             struct userfaultfd_wake_range *range)
 {
-       unsigned long start, end;
-
-       start = range->start;
-       end = range->start + range->len;
-
        spin_lock(&ctx->fault_pending_wqh.lock);
        /* wake all in the range and autoremove */
        if (waitqueue_active(&ctx->fault_pending_wqh))
@@ -1645,6 +1646,8 @@ static int userfaultfd_zeropage(struct userfaultfd_ctx *ctx,
                ret = mfill_zeropage(ctx->mm, uffdio_zeropage.range.start,
                                     uffdio_zeropage.range.len);
                mmput(ctx->mm);
+       } else {
+               return -ENOSPC;
        }
        if (unlikely(put_user(ret, &user_uffdio_zeropage->zeropage)))
                return -EFAULT;