io_uring: make IORING_POLL_ADD and IORING_POLL_REMOVE deferrable
authorJens Axboe <axboe@kernel.dk>
Wed, 18 Dec 2019 01:40:57 +0000 (18:40 -0700)
committerJens Axboe <axboe@kernel.dk>
Wed, 18 Dec 2019 02:57:27 +0000 (19:57 -0700)
If we defer these commands as part of a link, we have to make sure that
the SQE data has been read upfront. Integrate the poll add/remove into
the prep handling to make it safe for SQE reuse.

Signed-off-by: Jens Axboe <axboe@kernel.dk>
fs/io_uring.c

index b476bd304045ca82e47329f05d50ee1dbab9177e..b0411406c50a76dda2b3cc57afc729bca337023d 100644 (file)
@@ -289,7 +289,10 @@ struct io_ring_ctx {
  */
 struct io_poll_iocb {
        struct file                     *file;
-       struct wait_queue_head          *head;
+       union {
+               struct wait_queue_head  *head;
+               u64                     addr;
+       };
        __poll_t                        events;
        bool                            done;
        bool                            canceled;
@@ -2490,24 +2493,40 @@ static int io_poll_cancel(struct io_ring_ctx *ctx, __u64 sqe_addr)
        return -ENOENT;
 }
 
+static int io_poll_remove_prep(struct io_kiocb *req)
+{
+       const struct io_uring_sqe *sqe = req->sqe;
+
+       if (req->flags & REQ_F_PREPPED)
+               return 0;
+       if (unlikely(req->ctx->flags & IORING_SETUP_IOPOLL))
+               return -EINVAL;
+       if (sqe->ioprio || sqe->off || sqe->len || sqe->buf_index ||
+           sqe->poll_events)
+               return -EINVAL;
+
+       req->poll.addr = READ_ONCE(sqe->addr);
+       req->flags |= REQ_F_PREPPED;
+       return 0;
+}
+
 /*
  * Find a running poll command that matches one specified in sqe->addr,
  * and remove it if found.
  */
 static int io_poll_remove(struct io_kiocb *req)
 {
-       const struct io_uring_sqe *sqe = req->sqe;
        struct io_ring_ctx *ctx = req->ctx;
+       u64 addr;
        int ret;
 
-       if (unlikely(req->ctx->flags & IORING_SETUP_IOPOLL))
-               return -EINVAL;
-       if (sqe->ioprio || sqe->off || sqe->len || sqe->buf_index ||
-           sqe->poll_events)
-               return -EINVAL;
+       ret = io_poll_remove_prep(req);
+       if (ret)
+               return ret;
 
+       addr = req->poll.addr;
        spin_lock_irq(&ctx->completion_lock);
-       ret = io_poll_cancel(ctx, READ_ONCE(sqe->addr));
+       ret = io_poll_cancel(ctx, addr);
        spin_unlock_irq(&ctx->completion_lock);
 
        io_cqring_add_event(req, ret);
@@ -2642,16 +2661,14 @@ static void io_poll_req_insert(struct io_kiocb *req)
        hlist_add_head(&req->hash_node, list);
 }
 
-static int io_poll_add(struct io_kiocb *req, struct io_kiocb **nxt)
+static int io_poll_add_prep(struct io_kiocb *req)
 {
        const struct io_uring_sqe *sqe = req->sqe;
        struct io_poll_iocb *poll = &req->poll;
-       struct io_ring_ctx *ctx = req->ctx;
-       struct io_poll_table ipt;
-       bool cancel = false;
-       __poll_t mask;
        u16 events;
 
+       if (req->flags & REQ_F_PREPPED)
+               return 0;
        if (unlikely(req->ctx->flags & IORING_SETUP_IOPOLL))
                return -EINVAL;
        if (sqe->addr || sqe->ioprio || sqe->off || sqe->len || sqe->buf_index)
@@ -2659,9 +2676,26 @@ static int io_poll_add(struct io_kiocb *req, struct io_kiocb **nxt)
        if (!poll->file)
                return -EBADF;
 
-       INIT_IO_WORK(&req->work, io_poll_complete_work);
+       req->flags |= REQ_F_PREPPED;
        events = READ_ONCE(sqe->poll_events);
        poll->events = demangle_poll(events) | EPOLLERR | EPOLLHUP;
+       return 0;
+}
+
+static int io_poll_add(struct io_kiocb *req, struct io_kiocb **nxt)
+{
+       struct io_poll_iocb *poll = &req->poll;
+       struct io_ring_ctx *ctx = req->ctx;
+       struct io_poll_table ipt;
+       bool cancel = false;
+       __poll_t mask;
+       int ret;
+
+       ret = io_poll_add_prep(req);
+       if (ret)
+               return ret;
+
+       INIT_IO_WORK(&req->work, io_poll_complete_work);
        INIT_HLIST_NODE(&req->hash_node);
 
        poll->head = NULL;
@@ -3029,6 +3063,12 @@ static int io_req_defer_prep(struct io_kiocb *req)
                io_req_map_rw(req, ret, iovec, inline_vecs, &iter);
                ret = 0;
                break;
+       case IORING_OP_POLL_ADD:
+               ret = io_poll_add_prep(req);
+               break;
+       case IORING_OP_POLL_REMOVE:
+               ret = io_poll_remove_prep(req);
+               break;
        case IORING_OP_FSYNC:
                ret = io_prep_fsync(req);
                break;