io_uring: move all prep state for IORING_OP_{SEND,RECV}_MGS to prep handler
authorJens Axboe <axboe@kernel.dk>
Fri, 20 Dec 2019 15:58:21 +0000 (08:58 -0700)
committerJens Axboe <axboe@kernel.dk>
Fri, 20 Dec 2019 16:55:23 +0000 (09:55 -0700)
Add struct io_sr_msg in our io_kiocb per-command union, and ensure that
the send/recvmsg prep handlers have grabbed what they need from the SQE
by the time prep is done.

Signed-off-by: Jens Axboe <axboe@kernel.dk>
fs/io_uring.c

index 2a173f54ec8ea42d7fe36fdfcbea4bf4c34ab7e3..89e5b19044cc7f267c1db16bbb010ed322b784a1 100644 (file)
@@ -345,6 +345,12 @@ struct io_connect {
        int                             addr_len;
 };
 
+struct io_sr_msg {
+       struct file                     *file;
+       struct user_msghdr __user       *msg;
+       int                             msg_flags;
+};
+
 struct io_async_connect {
        struct sockaddr_storage         address;
 };
@@ -389,6 +395,7 @@ struct io_kiocb {
                struct io_cancel        cancel;
                struct io_timeout       timeout;
                struct io_connect       connect;
+               struct io_sr_msg        sr_msg;
        };
 
        const struct io_uring_sqe       *sqe;
@@ -2164,15 +2171,15 @@ static int io_sendmsg_prep(struct io_kiocb *req, struct io_async_ctx *io)
 {
 #if defined(CONFIG_NET)
        const struct io_uring_sqe *sqe = req->sqe;
-       struct user_msghdr __user *msg;
-       unsigned flags;
+       struct io_sr_msg *sr = &req->sr_msg;
 
-       flags = READ_ONCE(sqe->msg_flags);
-       msg = u64_to_user_ptr(READ_ONCE(sqe->addr));
+       sr->msg_flags = READ_ONCE(sqe->msg_flags);
+       sr->msg = u64_to_user_ptr(READ_ONCE(sqe->addr));
        io->msg.iov = io->msg.fast_iov;
-       return sendmsg_copy_msghdr(&io->msg.msg, msg, flags, &io->msg.iov);
+       return sendmsg_copy_msghdr(&io->msg.msg, sr->msg, sr->msg_flags,
+                                       &io->msg.iov);
 #else
-       return 0;
+       return -EOPNOTSUPP;
 #endif
 }
 
@@ -2180,7 +2187,6 @@ static int io_sendmsg(struct io_kiocb *req, struct io_kiocb **nxt,
                      bool force_nonblock)
 {
 #if defined(CONFIG_NET)
-       const struct io_uring_sqe *sqe = req->sqe;
        struct io_async_msghdr *kmsg = NULL;
        struct socket *sock;
        int ret;
@@ -2194,12 +2200,6 @@ static int io_sendmsg(struct io_kiocb *req, struct io_kiocb **nxt,
                struct sockaddr_storage addr;
                unsigned flags;
 
-               flags = READ_ONCE(sqe->msg_flags);
-               if (flags & MSG_DONTWAIT)
-                       req->flags |= REQ_F_NOWAIT;
-               else if (force_nonblock)
-                       flags |= MSG_DONTWAIT;
-
                if (req->io) {
                        kmsg = &req->io->msg;
                        kmsg->msg.msg_name = &addr;
@@ -2215,6 +2215,12 @@ static int io_sendmsg(struct io_kiocb *req, struct io_kiocb **nxt,
                                goto out;
                }
 
+               flags = req->sr_msg.msg_flags;
+               if (flags & MSG_DONTWAIT)
+                       req->flags |= REQ_F_NOWAIT;
+               else if (force_nonblock)
+                       flags |= MSG_DONTWAIT;
+
                ret = __sys_sendmsg_sock(sock, &kmsg->msg, flags);
                if (force_nonblock && ret == -EAGAIN) {
                        if (req->io)
@@ -2245,17 +2251,15 @@ out:
 static int io_recvmsg_prep(struct io_kiocb *req, struct io_async_ctx *io)
 {
 #if defined(CONFIG_NET)
-       const struct io_uring_sqe *sqe = req->sqe;
-       struct user_msghdr __user *msg;
-       unsigned flags;
+       struct io_sr_msg *sr = &req->sr_msg;
 
-       flags = READ_ONCE(sqe->msg_flags);
-       msg = u64_to_user_ptr(READ_ONCE(sqe->addr));
+       sr->msg_flags = READ_ONCE(req->sqe->msg_flags);
+       sr->msg = u64_to_user_ptr(READ_ONCE(req->sqe->addr));
        io->msg.iov = io->msg.fast_iov;
-       return recvmsg_copy_msghdr(&io->msg.msg, msg, flags, &io->msg.uaddr,
-                                       &io->msg.iov);
+       return recvmsg_copy_msghdr(&io->msg.msg, sr->msg, sr->msg_flags,
+                                       &io->msg.uaddr, &io->msg.iov);
 #else
-       return 0;
+       return -EOPNOTSUPP;
 #endif
 }
 
@@ -2263,7 +2267,6 @@ static int io_recvmsg(struct io_kiocb *req, struct io_kiocb **nxt,
                      bool force_nonblock)
 {
 #if defined(CONFIG_NET)
-       const struct io_uring_sqe *sqe = req->sqe;
        struct io_async_msghdr *kmsg = NULL;
        struct socket *sock;
        int ret;
@@ -2273,18 +2276,10 @@ static int io_recvmsg(struct io_kiocb *req, struct io_kiocb **nxt,
 
        sock = sock_from_file(req->file, &ret);
        if (sock) {
-               struct user_msghdr __user *msg;
                struct io_async_ctx io;
                struct sockaddr_storage addr;
                unsigned flags;
 
-               flags = READ_ONCE(sqe->msg_flags);
-               if (flags & MSG_DONTWAIT)
-                       req->flags |= REQ_F_NOWAIT;
-               else if (force_nonblock)
-                       flags |= MSG_DONTWAIT;
-
-               msg = u64_to_user_ptr(READ_ONCE(sqe->addr));
                if (req->io) {
                        kmsg = &req->io->msg;
                        kmsg->msg.msg_name = &addr;
@@ -2300,7 +2295,14 @@ static int io_recvmsg(struct io_kiocb *req, struct io_kiocb **nxt,
                                goto out;
                }
 
-               ret = __sys_recvmsg_sock(sock, &kmsg->msg, msg, kmsg->uaddr, flags);
+               flags = req->sr_msg.msg_flags;
+               if (flags & MSG_DONTWAIT)
+                       req->flags |= REQ_F_NOWAIT;
+               else if (force_nonblock)
+                       flags |= MSG_DONTWAIT;
+
+               ret = __sys_recvmsg_sock(sock, &kmsg->msg, req->sr_msg.msg,
+                                               kmsg->uaddr, flags);
                if (force_nonblock && ret == -EAGAIN) {
                        if (req->io)
                                return -EAGAIN;