userfaultfd: non-cooperative: flush event_wqh at release time
diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
index 1d622f276e3a2c0fba23d979d7ae18cf6a899f30..06ea26b8c996f3cc7a9d6fd177260f89394fb325 100644
--- a/fs/userfaultfd.c
+++ b/fs/userfaultfd.c
@@ -81,7 +81,7 @@ struct userfaultfd_unmap_ctx {
 
 struct userfaultfd_wait_queue {
        struct uffd_msg msg;
-       wait_queue_t wq;
+       wait_queue_entry_t wq;
        struct userfaultfd_ctx *ctx;
        bool waken;
 };
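
This hunk tracks the tree-wide waitqueue rename: wait_queue_t became wait_queue_entry_t and its list node moved from ->task_list to ->entry. For reference, a sketch of the renamed type as it appears in <linux/wait.h> around this series:

    struct wait_queue_entry {
            unsigned int            flags;
            void                    *private;  /* typically the waiting task */
            wait_queue_func_t       func;      /* wake callback, here userfaultfd_wake_function() */
            struct list_head        entry;     /* formerly ->task_list */
    };
    typedef struct wait_queue_entry wait_queue_entry_t;
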
@@ -91,7 +91,7 @@ struct userfaultfd_wake_range {
        unsigned long len;
 };
 
-static int userfaultfd_wake_function(wait_queue_t *wq, unsigned mode,
+static int userfaultfd_wake_function(wait_queue_entry_t *wq, unsigned mode,
                                     int wake_flags, void *key)
 {
        struct userfaultfd_wake_range *range = key;
@@ -129,7 +129,7 @@ static int userfaultfd_wake_function(wait_queue_t *wq, unsigned mode,
                 * wouldn't be enough, the smp_mb__before_spinlock is
                 * enough to avoid an explicit smp_mb() here.
                 */
-               list_del_init(&wq->task_list);
+               list_del_init(&wq->entry);
 out:
        return ret;
 }
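
The key argument is whatever the waker passes down; for userfaultfd it is always a struct userfaultfd_wake_range, which lets the wake function skip waiters whose fault address falls outside the range. A minimal sketch of the caller side, mirroring __wake_userfault() further down in this file:

    struct userfaultfd_wake_range range = { .start = start, .len = len };

    /* len == 0 is this file's "wake everyone" convention */
    spin_lock(&ctx->fault_pending_wqh.lock);
    __wake_up_locked_key(&ctx->fault_pending_wqh, TASK_NORMAL, &range);
    spin_unlock(&ctx->fault_pending_wqh.lock);
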
@@ -214,6 +214,7 @@ static inline struct uffd_msg userfault_msg(unsigned long address,
  * hugepmd ranges.
  */
 static inline bool userfaultfd_huge_must_wait(struct userfaultfd_ctx *ctx,
+                                        struct vm_area_struct *vma,
                                         unsigned long address,
                                         unsigned long flags,
                                         unsigned long reason)
@@ -224,7 +225,7 @@ static inline bool userfaultfd_huge_must_wait(struct userfaultfd_ctx *ctx,
 
        VM_BUG_ON(!rwsem_is_locked(&mm->mmap_sem));
 
-       pte = huge_pte_offset(mm, address);
+       pte = huge_pte_offset(mm, address, vma_mmu_pagesize(vma));
        if (!pte)
                goto out;
 
@@ -243,6 +244,7 @@ out:
 }
 #else
 static inline bool userfaultfd_huge_must_wait(struct userfaultfd_ctx *ctx,
+                                        struct vm_area_struct *vma,
                                         unsigned long address,
                                         unsigned long flags,
                                         unsigned long reason)
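
The new vma parameter exists only to derive the huge page size: huge_pte_offset() grew a size argument in the hugetlb API so architectures supporting multiple huge page sizes can walk the correct page-table level. The updated prototype, roughly as in <linux/hugetlb.h> at this point:

    /* sz is the huge page size backing the mapping, e.g. vma_mmu_pagesize(vma) */
    pte_t *huge_pte_offset(struct mm_struct *mm, unsigned long addr,
                           unsigned long sz);
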
@@ -448,7 +450,8 @@ int handle_userfault(struct vm_fault *vmf, unsigned long reason)
                must_wait = userfaultfd_must_wait(ctx, vmf->address, vmf->flags,
                                                  reason);
        else
-               must_wait = userfaultfd_huge_must_wait(ctx, vmf->address,
+               must_wait = userfaultfd_huge_must_wait(ctx, vmf->vma,
+                                                      vmf->address,
                                                       vmf->flags, reason);
        up_read(&mm->mmap_sem);
 
@@ -522,13 +525,13 @@ int handle_userfault(struct vm_fault *vmf, unsigned long reason)
         * and it's fine not to block on the spinlock. The uwq on this
         * kernel stack can be released after the list_del_init.
         */
-       if (!list_empty_careful(&uwq.wq.task_list)) {
+       if (!list_empty_careful(&uwq.wq.entry)) {
                spin_lock(&ctx->fault_pending_wqh.lock);
                /*
                 * No need of list_del_init(), the uwq on the stack
                 * will be freed shortly anyway.
                 */
-               list_del(&uwq.wq.task_list);
+               list_del(&uwq.wq.entry);
                spin_unlock(&ctx->fault_pending_wqh.lock);
        }
 
@@ -851,6 +854,9 @@ wakeup:
        __wake_up_locked_key(&ctx->fault_wqh, TASK_NORMAL, &range);
        spin_unlock(&ctx->fault_pending_wqh.lock);
 
+       /* Flush pending events that may still be waiting on event_wqh */
+       wake_up_all(&ctx->event_wqh);
+
        wake_up_poll(&ctx->fd_wqh, POLLHUP);
        userfaultfd_ctx_put(ctx);
        return 0;
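
This wake_up_all() is the substance of the commit. Threads that queue fork/remap/madvise events sleep on event_wqh until the event is read; if userland closes the uffd without reading, nothing else would ever wake them. A simplified sketch of the waiter being unblocked here, condensed from userfaultfd_event_wait_completion() in this file with error paths trimmed:

    spin_lock(&ctx->event_wqh.lock);
    __add_wait_queue(&ctx->event_wqh, &ewq->wq);
    for (;;) {
            set_current_state(TASK_KILLABLE);
            if (ewq->msg.event == 0)        /* reader consumed the event */
                    break;
            if (READ_ONCE(ctx->released) || /* fd released: give up waiting */
                fatal_signal_pending(current))
                    break;
            spin_unlock(&ctx->event_wqh.lock);
            wake_up_poll(&ctx->fd_wqh, POLLIN);
            schedule();
            spin_lock(&ctx->event_wqh.lock);
    }
    __set_current_state(TASK_RUNNING);
    spin_unlock(&ctx->event_wqh.lock);

Without the flush added above, the READ_ONCE(ctx->released) check alone could not help a waiter that was already asleep, since nobody would call schedule() awake again.
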
@@ -860,7 +866,7 @@ wakeup:
 static inline struct userfaultfd_wait_queue *find_userfault_in(
                wait_queue_head_t *wqh)
 {
-       wait_queue_t *wq;
+       wait_queue_entry_t *wq;
        struct userfaultfd_wait_queue *uwq;
 
        VM_BUG_ON(!spin_is_locked(&wqh->lock));
@@ -869,7 +875,7 @@ static inline struct userfaultfd_wait_queue *find_userfault_in(
        if (!waitqueue_active(wqh))
                goto out;
        /* walk in reverse to provide FIFO behavior to read userfaults */
-       wq = list_last_entry(&wqh->task_list, typeof(*wq), task_list);
+       wq = list_last_entry(&wqh->head, typeof(*wq), entry);
        uwq = container_of(wq, struct userfaultfd_wait_queue, wq);
 out:
        return uwq;
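
The reverse walk yields FIFO order because waiters are inserted at the head of the list; the post-rename insertion helper is effectively this (sketch from <linux/wait.h>):

    static inline void __add_wait_queue(struct wait_queue_head *wq_head,
                                        struct wait_queue_entry *wq_entry)
    {
            list_add(&wq_entry->entry, &wq_head->head);
    }

So list_last_entry() on wq_head->head returns the oldest queued userfault.
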
@@ -1003,14 +1009,14 @@ static ssize_t userfaultfd_ctx_read(struct userfaultfd_ctx *ctx, int no_wait,
                         * changes __remove_wait_queue() to use
                         * list_del_init() in turn breaking the
                         * !list_empty_careful() check in
-                        * handle_userfault(). The uwq->wq.task_list
+                        * handle_userfault(). The uwq->wq.entry list
                         * must never be empty at any time during the
                         * refile, or the waitqueue could disappear
                         * from under us. The "wait_queue_head_t"
                         * parameter of __remove_wait_queue() is unused
                         * anyway.
                         */
-                       list_del(&uwq->wq.task_list);
+                       list_del(&uwq->wq.entry);
                        __add_wait_queue(&ctx->fault_wqh, &uwq->wq);
 
                        write_seqcount_end(&ctx->refile_seq);
@@ -1032,7 +1038,7 @@ static ssize_t userfaultfd_ctx_read(struct userfaultfd_ctx *ctx, int no_wait,
                                fork_nctx = (struct userfaultfd_ctx *)
                                        (unsigned long)
                                        uwq->msg.arg.reserved.reserved1;
-                               list_move(&uwq->wq.task_list, &fork_event);
+                               list_move(&uwq->wq.entry, &fork_event);
                                spin_unlock(&ctx->event_wqh.lock);
                                ret = 0;
                                break;
@@ -1069,8 +1075,8 @@ static ssize_t userfaultfd_ctx_read(struct userfaultfd_ctx *ctx, int no_wait,
                        if (!list_empty(&fork_event)) {
                                uwq = list_first_entry(&fork_event,
                                                       typeof(*uwq),
-                                                      wq.task_list);
-                               list_del(&uwq->wq.task_list);
+                                                      wq.entry);
+                               list_del(&uwq->wq.entry);
                                __add_wait_queue(&ctx->event_wqh, &uwq->wq);
                                userfaultfd_event_complete(ctx, uwq);
                        }
@@ -1114,11 +1120,6 @@ static ssize_t userfaultfd_read(struct file *file, char __user *buf,
 static void __wake_userfault(struct userfaultfd_ctx *ctx,
                             struct userfaultfd_wake_range *range)
 {
-       unsigned long start, end;
-
-       start = range->start;
-       end = range->start + range->len;
-
        spin_lock(&ctx->fault_pending_wqh.lock);
        /* wake all in the range and autoremove */
        if (waitqueue_active(&ctx->fault_pending_wqh))
@@ -1645,6 +1646,8 @@ static int userfaultfd_zeropage(struct userfaultfd_ctx *ctx,
                ret = mfill_zeropage(ctx->mm, uffdio_zeropage.range.start,
                                     uffdio_zeropage.range.len);
                mmput(ctx->mm);
+       } else {
+               return -ENOSPC;
        }
        if (unlikely(put_user(ret, &user_uffdio_zeropage->zeropage)))
                return -EFAULT;
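
The new else branch covers the case where the monitored mm is already gone, i.e. the mmget_not_zero() guarding this block failed, so the caller gets a hard -ENOSPC instead of a stale ret. From userland that surfaces roughly as below (hedged sketch; uffd, addr, and len are assumed to be an initialized descriptor and a registered range, and handle_target_exit() is a hypothetical placeholder):

    #include <sys/ioctl.h>
    #include <errno.h>
    #include <linux/userfaultfd.h>

    struct uffdio_zeropage zp = {
            .range = { .start = addr, .len = len },
            .mode  = 0,
    };

    if (ioctl(uffd, UFFDIO_ZEROPAGE, &zp) && errno == ENOSPC)
            handle_target_exit();   /* hypothetical: the target mm exited */
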
@@ -1747,17 +1750,17 @@ static long userfaultfd_ioctl(struct file *file, unsigned cmd,
 static void userfaultfd_show_fdinfo(struct seq_file *m, struct file *f)
 {
        struct userfaultfd_ctx *ctx = f->private_data;
-       wait_queue_t *wq;
+       wait_queue_entry_t *wq;
        struct userfaultfd_wait_queue *uwq;
        unsigned long pending = 0, total = 0;
 
        spin_lock(&ctx->fault_pending_wqh.lock);
-       list_for_each_entry(wq, &ctx->fault_pending_wqh.task_list, task_list) {
+       list_for_each_entry(wq, &ctx->fault_pending_wqh.head, entry) {
                uwq = container_of(wq, struct userfaultfd_wait_queue, wq);
                pending++;
                total++;
        }
-       list_for_each_entry(wq, &ctx->fault_wqh.task_list, task_list) {
+       list_for_each_entry(wq, &ctx->fault_wqh.head, entry) {
                uwq = container_of(wq, struct userfaultfd_wait_queue, wq);
                total++;
        }
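
For context, the pending/total counters accumulated in these two loops are emitted just after the hunk ends; the tail of userfaultfd_show_fdinfo() reads roughly:

    seq_printf(m, "pending:\t%lu\ntotal:\t%lu\nAPI:\t%Lx:%x:%Lx\n",
               pending, total, UFFD_API, ctx->features,
               UFFD_API_IOCTLS|UFFD_API_RANGE_IOCTLS);
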