IB/uverbs: Allow all DESTROY commands to succeed after disassociate
[sfrench/cifs-2.6.git] / fs / userfaultfd.c
index cec550c8468f484a3f14d6a6b3e5dcc2a09ea70f..123bf7d516fc1f475cb89edb8aade4c2ad556f51 100644 (file)
@@ -62,6 +62,8 @@ struct userfaultfd_ctx {
        enum userfaultfd_state state;
        /* released */
        bool released;
+       /* memory mappings are changing because of non-cooperative event */
+       bool mmap_changing;
        /* mm with one ore more vmas attached to this userfaultfd_ctx */
        struct mm_struct *mm;
 };
@@ -641,6 +643,7 @@ static void userfaultfd_event_wait_completion(struct userfaultfd_ctx *ctx,
         * already released.
         */
 out:
+       WRITE_ONCE(ctx->mmap_changing, false);
        userfaultfd_ctx_put(ctx);
 }
 
@@ -686,10 +689,12 @@ int dup_userfaultfd(struct vm_area_struct *vma, struct list_head *fcs)
                ctx->state = UFFD_STATE_RUNNING;
                ctx->features = octx->features;
                ctx->released = false;
+               ctx->mmap_changing = false;
                ctx->mm = vma->vm_mm;
                mmgrab(ctx->mm);
 
                userfaultfd_ctx_get(octx);
+               WRITE_ONCE(octx->mmap_changing, true);
                fctx->orig = octx;
                fctx->new = ctx;
                list_add_tail(&fctx->list, fcs);
@@ -732,6 +737,7 @@ void mremap_userfaultfd_prep(struct vm_area_struct *vma,
        if (ctx && (ctx->features & UFFD_FEATURE_EVENT_REMAP)) {
                vm_ctx->ctx = ctx;
                userfaultfd_ctx_get(ctx);
+               WRITE_ONCE(ctx->mmap_changing, true);
        }
 }
 
@@ -772,6 +778,7 @@ bool userfaultfd_remove(struct vm_area_struct *vma,
                return true;
 
        userfaultfd_ctx_get(ctx);
+       WRITE_ONCE(ctx->mmap_changing, true);
        up_read(&mm->mmap_sem);
 
        msg_init(&ewq.msg);
@@ -815,6 +822,7 @@ int userfaultfd_unmap_prep(struct vm_area_struct *vma,
                        return -ENOMEM;
 
                userfaultfd_ctx_get(ctx);
+               WRITE_ONCE(ctx->mmap_changing, true);
                unmap_ctx->ctx = ctx;
                unmap_ctx->start = start;
                unmap_ctx->end = end;
@@ -1653,6 +1661,10 @@ static int userfaultfd_copy(struct userfaultfd_ctx *ctx,
 
        user_uffdio_copy = (struct uffdio_copy __user *) arg;
 
+       ret = -EAGAIN;
+       if (READ_ONCE(ctx->mmap_changing))
+               goto out;
+
        ret = -EFAULT;
        if (copy_from_user(&uffdio_copy, user_uffdio_copy,
                           /* don't copy "copy" last field */
@@ -1674,7 +1686,7 @@ static int userfaultfd_copy(struct userfaultfd_ctx *ctx,
                goto out;
        if (mmget_not_zero(ctx->mm)) {
                ret = mcopy_atomic(ctx->mm, uffdio_copy.dst, uffdio_copy.src,
-                                  uffdio_copy.len);
+                                  uffdio_copy.len, &ctx->mmap_changing);
                mmput(ctx->mm);
        } else {
                return -ESRCH;
@@ -1705,6 +1717,10 @@ static int userfaultfd_zeropage(struct userfaultfd_ctx *ctx,
 
        user_uffdio_zeropage = (struct uffdio_zeropage __user *) arg;
 
+       ret = -EAGAIN;
+       if (READ_ONCE(ctx->mmap_changing))
+               goto out;
+
        ret = -EFAULT;
        if (copy_from_user(&uffdio_zeropage, user_uffdio_zeropage,
                           /* don't copy "zeropage" last field */
@@ -1721,7 +1737,8 @@ static int userfaultfd_zeropage(struct userfaultfd_ctx *ctx,
 
        if (mmget_not_zero(ctx->mm)) {
                ret = mfill_zeropage(ctx->mm, uffdio_zeropage.range.start,
-                                    uffdio_zeropage.range.len);
+                                    uffdio_zeropage.range.len,
+                                    &ctx->mmap_changing);
                mmput(ctx->mm);
        } else {
                return -ESRCH;
@@ -1900,6 +1917,7 @@ SYSCALL_DEFINE1(userfaultfd, int, flags)
        ctx->features = 0;
        ctx->state = UFFD_STATE_WAIT_API;
        ctx->released = false;
+       ctx->mmap_changing = false;
        ctx->mm = current->mm;
        /* prevent the mm struct to be freed */
        mmgrab(ctx->mm);