aio: simplify - and fix - fget/fput for io_submit()
[sfrench/cifs-2.6.git] / fs / aio.c
index aaaaf4d12c7394fe1bb9836516b7b7fc1eb69105..82c08422b0f4f6d8e8fececfba45cfa8f1755446 100644 (file)
--- a/fs/aio.c
+++ b/fs/aio.c
@@ -167,9 +167,13 @@ struct kioctx {
        unsigned                id;
 };
 
+/*
+ * First field must be the file pointer in all the
+ * iocb unions! See also 'struct kiocb' in <linux/fs.h>
+ */
 struct fsync_iocb {
-       struct work_struct      work;
        struct file             *file;
+       struct work_struct      work;
        bool                    datasync;
 };
 
@@ -183,8 +187,15 @@ struct poll_iocb {
        struct work_struct      work;
 };
 
+/*
+ * NOTE! Each of the iocb union members has the file pointer
+ * as the first entry in their struct definition. So you can
+ * access the file pointer through any of the sub-structs,
+ * or directly as just 'ki_filp' in this struct.
+ */
 struct aio_kiocb {
        union {
+               struct file             *ki_filp;
                struct kiocb            rw;
                struct fsync_iocb       fsync;
                struct poll_iocb        poll;
@@ -1060,6 +1071,8 @@ static inline void iocb_put(struct aio_kiocb *iocb)
 {
        if (refcount_read(&iocb->ki_refcnt) == 0 ||
            refcount_dec_and_test(&iocb->ki_refcnt)) {
+               if (iocb->ki_filp)
+                       fput(iocb->ki_filp);
                percpu_ref_put(&iocb->ki_ctx->reqs);
                kmem_cache_free(kiocb_cachep, iocb);
        }
@@ -1424,7 +1437,6 @@ static void aio_complete_rw(struct kiocb *kiocb, long res, long res2)
                file_end_write(kiocb->ki_filp);
        }
 
-       fput(kiocb->ki_filp);
        aio_complete(iocb, res, res2);
 }
 
@@ -1432,9 +1444,6 @@ static int aio_prep_rw(struct kiocb *req, const struct iocb *iocb)
 {
        int ret;
 
-       req->ki_filp = fget(iocb->aio_fildes);
-       if (unlikely(!req->ki_filp))
-               return -EBADF;
        req->ki_complete = aio_complete_rw;
        req->private = NULL;
        req->ki_pos = iocb->aio_offset;
@@ -1451,7 +1460,7 @@ static int aio_prep_rw(struct kiocb *req, const struct iocb *iocb)
                ret = ioprio_check_cap(iocb->aio_reqprio);
                if (ret) {
                        pr_debug("aio ioprio check cap error: %d\n", ret);
-                       goto out_fput;
+                       return ret;
                }
 
                req->ki_ioprio = iocb->aio_reqprio;
@@ -1460,14 +1469,10 @@ static int aio_prep_rw(struct kiocb *req, const struct iocb *iocb)
 
        ret = kiocb_set_rw_flags(req, iocb->aio_rw_flags);
        if (unlikely(ret))
-               goto out_fput;
+               return ret;
 
        req->ki_flags &= ~IOCB_HIPRI; /* no one is going to poll for this I/O */
        return 0;
-
-out_fput:
-       fput(req->ki_filp);
-       return ret;
 }
 
 static int aio_setup_rw(int rw, const struct iocb *iocb, struct iovec **iovec,
@@ -1521,24 +1526,19 @@ static ssize_t aio_read(struct kiocb *req, const struct iocb *iocb,
        if (ret)
                return ret;
        file = req->ki_filp;
-
-       ret = -EBADF;
        if (unlikely(!(file->f_mode & FMODE_READ)))
-               goto out_fput;
+               return -EBADF;
        ret = -EINVAL;
        if (unlikely(!file->f_op->read_iter))
-               goto out_fput;
+               return -EINVAL;
 
        ret = aio_setup_rw(READ, iocb, &iovec, vectored, compat, &iter);
        if (ret)
-               goto out_fput;
+               return ret;
        ret = rw_verify_area(READ, file, &req->ki_pos, iov_iter_count(&iter));
        if (!ret)
                aio_rw_done(req, call_read_iter(file, req, &iter));
        kfree(iovec);
-out_fput:
-       if (unlikely(ret))
-               fput(file);
        return ret;
 }
 
@@ -1555,16 +1555,14 @@ static ssize_t aio_write(struct kiocb *req, const struct iocb *iocb,
                return ret;
        file = req->ki_filp;
 
-       ret = -EBADF;
        if (unlikely(!(file->f_mode & FMODE_WRITE)))
-               goto out_fput;
-       ret = -EINVAL;
+               return -EBADF;
        if (unlikely(!file->f_op->write_iter))
-               goto out_fput;
+               return -EINVAL;
 
        ret = aio_setup_rw(WRITE, iocb, &iovec, vectored, compat, &iter);
        if (ret)
-               goto out_fput;
+               return ret;
        ret = rw_verify_area(WRITE, file, &req->ki_pos, iov_iter_count(&iter));
        if (!ret) {
                /*
@@ -1582,9 +1580,6 @@ static ssize_t aio_write(struct kiocb *req, const struct iocb *iocb,
                aio_rw_done(req, call_write_iter(file, req, &iter));
        }
        kfree(iovec);
-out_fput:
-       if (unlikely(ret))
-               fput(file);
        return ret;
 }
 
@@ -1594,7 +1589,6 @@ static void aio_fsync_work(struct work_struct *work)
        int ret;
 
        ret = vfs_fsync(req->file, req->datasync);
-       fput(req->file);
        aio_complete(container_of(req, struct aio_kiocb, fsync), ret, 0);
 }
 
@@ -1605,13 +1599,8 @@ static int aio_fsync(struct fsync_iocb *req, const struct iocb *iocb,
                        iocb->aio_rw_flags))
                return -EINVAL;
 
-       req->file = fget(iocb->aio_fildes);
-       if (unlikely(!req->file))
-               return -EBADF;
-       if (unlikely(!req->file->f_op->fsync)) {
-               fput(req->file);
+       if (unlikely(!req->file->f_op->fsync))
                return -EINVAL;
-       }
 
        req->datasync = datasync;
        INIT_WORK(&req->work, aio_fsync_work);
@@ -1621,10 +1610,7 @@ static int aio_fsync(struct fsync_iocb *req, const struct iocb *iocb,
 
 static inline void aio_poll_complete(struct aio_kiocb *iocb, __poll_t mask)
 {
-       struct file *file = iocb->poll.file;
-
        aio_complete(iocb, mangle_poll(mask), 0);
-       fput(file);
 }
 
 static void aio_poll_complete_work(struct work_struct *work)
@@ -1743,9 +1729,6 @@ static ssize_t aio_poll(struct aio_kiocb *aiocb, const struct iocb *iocb)
 
        INIT_WORK(&req->work, aio_poll_complete_work);
        req->events = demangle_poll(iocb->aio_buf) | EPOLLERR | EPOLLHUP;
-       req->file = fget(iocb->aio_fildes);
-       if (unlikely(!req->file))
-               return -EBADF;
 
        req->head = NULL;
        req->woken = false;
@@ -1788,10 +1771,8 @@ static ssize_t aio_poll(struct aio_kiocb *aiocb, const struct iocb *iocb)
        spin_unlock_irq(&ctx->ctx_lock);
 
 out:
-       if (unlikely(apt.error)) {
-               fput(req->file);
+       if (unlikely(apt.error))
                return apt.error;
-       }
 
        if (mask)
                aio_poll_complete(aiocb, mask);
@@ -1829,6 +1810,11 @@ static int __io_submit_one(struct kioctx *ctx, const struct iocb *iocb,
        if (unlikely(!req))
                goto out_put_reqs_available;
 
+       req->ki_filp = fget(iocb->aio_fildes);
+       ret = -EBADF;
+       if (unlikely(!req->ki_filp))
+               goto out_put_req;
+
        if (iocb->aio_flags & IOCB_FLAG_RESFD) {
                /*
                 * If the IOCB_FLAG_RESFD flag of aio_flags is set, get an