userfaultfd: use mmgrab instead of open-coded increment of mm_count
[sfrench/cifs-2.6.git] / fs / userfaultfd.c
index 1c713fd5b3e67966c3d998979d2c30eb8e14ba07..ac9a4e65ca497ad3b673dc50893d13b44150c70f 100644 (file)
@@ -381,7 +381,7 @@ int handle_userfault(struct vm_fault *vmf, unsigned long reason)
         * in __get_user_pages if userfaultfd_release waits on the
         * caller of handle_userfault to release the mmap_sem.
         */
-       if (unlikely(ACCESS_ONCE(ctx->released))) {
+       if (unlikely(READ_ONCE(ctx->released))) {
                /*
                 * Don't return VM_FAULT_SIGBUS in this case, so a non
                 * cooperative manager can close the uffd after the
@@ -477,7 +477,7 @@ int handle_userfault(struct vm_fault *vmf, unsigned long reason)
                                                       vmf->flags, reason);
        up_read(&mm->mmap_sem);
 
-       if (likely(must_wait && !ACCESS_ONCE(ctx->released) &&
+       if (likely(must_wait && !READ_ONCE(ctx->released) &&
                   (return_to_userland ? !signal_pending(current) :
                    !fatal_signal_pending(current)))) {
                wake_up_poll(&ctx->fd_wqh, POLLIN);
@@ -586,7 +586,7 @@ static void userfaultfd_event_wait_completion(struct userfaultfd_ctx *ctx,
                set_current_state(TASK_KILLABLE);
                if (ewq->msg.event == 0)
                        break;
-               if (ACCESS_ONCE(ctx->released) ||
+               if (READ_ONCE(ctx->released) ||
                    fatal_signal_pending(current)) {
                        /*
                         * &ewq->wq may be queued in fork_event, but
@@ -668,7 +668,7 @@ int dup_userfaultfd(struct vm_area_struct *vma, struct list_head *fcs)
                ctx->features = octx->features;
                ctx->released = false;
                ctx->mm = vma->vm_mm;
-               atomic_inc(&ctx->mm->mm_count);
+               mmgrab(ctx->mm);
 
                userfaultfd_ctx_get(octx);
                fctx->orig = octx;
@@ -833,7 +833,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
        struct userfaultfd_wake_range range = { .len = 0, };
        unsigned long new_flags;
 
-       ACCESS_ONCE(ctx->released) = true;
+       WRITE_ONCE(ctx->released, true);
 
        if (!mmget_not_zero(mm))
                goto wakeup;