Merge tag 'docs-4.10-2' of git://git.lwn.net/linux
[sfrench/cifs-2.6.git] / crypto / algif_aead.c
index a0d8377729a4fd5ac33cfb70fa57d6340222b81c..f849311e9fd4c94e57d81ba97279ec5fb0cb0ded 100644 (file)
@@ -81,7 +81,11 @@ static inline bool aead_sufficient_data(struct aead_ctx *ctx)
 {
        unsigned as = crypto_aead_authsize(crypto_aead_reqtfm(&ctx->aead_req));
 
-       return ctx->used >= ctx->aead_assoclen + as;
+       /*
+        * The minimum amount of memory needed for an AEAD cipher is
+        * the AAD and in case of decryption the tag.
+        */
+       return ctx->used >= ctx->aead_assoclen + (ctx->enc ? 0 : as);
 }
 
 static void aead_reset_ctx(struct aead_ctx *ctx)
@@ -132,28 +136,27 @@ static void aead_wmem_wakeup(struct sock *sk)
 
 static int aead_wait_for_data(struct sock *sk, unsigned flags)
 {
+       DEFINE_WAIT_FUNC(wait, woken_wake_function);
        struct alg_sock *ask = alg_sk(sk);
        struct aead_ctx *ctx = ask->private;
        long timeout;
-       DEFINE_WAIT(wait);
        int err = -ERESTARTSYS;
 
        if (flags & MSG_DONTWAIT)
                return -EAGAIN;
 
        sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
-
+       add_wait_queue(sk_sleep(sk), &wait);
        for (;;) {
                if (signal_pending(current))
                        break;
-               prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
                timeout = MAX_SCHEDULE_TIMEOUT;
-               if (sk_wait_event(sk, &timeout, !ctx->more)) {
+               if (sk_wait_event(sk, &timeout, !ctx->more, &wait)) {
                        err = 0;
                        break;
                }
        }
-       finish_wait(sk_sleep(sk), &wait);
+       remove_wait_queue(sk_sleep(sk), &wait);
 
        sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
 
@@ -416,7 +419,7 @@ static int aead_recvmsg_async(struct socket *sock, struct msghdr *msg,
        unsigned int i, reqlen = GET_REQ_SIZE(tfm);
        int err = -ENOMEM;
        unsigned long used;
-       size_t outlen;
+       size_t outlen = 0;
        size_t usedpages = 0;
 
        lock_sock(sk);
@@ -426,12 +429,15 @@ static int aead_recvmsg_async(struct socket *sock, struct msghdr *msg,
                        goto unlock;
        }
 
-       used = ctx->used;
-       outlen = used;
-
        if (!aead_sufficient_data(ctx))
                goto unlock;
 
+       used = ctx->used;
+       if (ctx->enc)
+               outlen = used + as;
+       else
+               outlen = used - as;
+
        req = sock_kmalloc(sk, reqlen, GFP_KERNEL);
        if (unlikely(!req))
                goto unlock;
@@ -445,15 +451,16 @@ static int aead_recvmsg_async(struct socket *sock, struct msghdr *msg,
        aead_request_set_ad(req, ctx->aead_assoclen);
        aead_request_set_callback(req, CRYPTO_TFM_REQ_MAY_BACKLOG,
                                  aead_async_cb, sk);
-       used -= ctx->aead_assoclen + (ctx->enc ? as : 0);
+       used -= ctx->aead_assoclen;
 
        /* take over all tx sgls from ctx */
-       areq->tsgl = sock_kmalloc(sk, sizeof(*areq->tsgl) * sgl->cur,
+       areq->tsgl = sock_kmalloc(sk,
+                                 sizeof(*areq->tsgl) * max_t(u32, sgl->cur, 1),
                                  GFP_KERNEL);
        if (unlikely(!areq->tsgl))
                goto free;
 
-       sg_init_table(areq->tsgl, sgl->cur);
+       sg_init_table(areq->tsgl, max_t(u32, sgl->cur, 1));
        for (i = 0; i < sgl->cur; i++)
                sg_set_page(&areq->tsgl[i], sg_page(&sgl->sg[i]),
                            sgl->sg[i].length, sgl->sg[i].offset);
@@ -461,7 +468,7 @@ static int aead_recvmsg_async(struct socket *sock, struct msghdr *msg,
        areq->tsgls = sgl->cur;
 
        /* create rx sgls */
-       while (iov_iter_count(&msg->msg_iter)) {
+       while (outlen > usedpages && iov_iter_count(&msg->msg_iter)) {
                size_t seglen = min_t(size_t, iov_iter_count(&msg->msg_iter),
                                      (outlen - usedpages));
 
@@ -491,16 +498,14 @@ static int aead_recvmsg_async(struct socket *sock, struct msghdr *msg,
 
                last_rsgl = rsgl;
 
-               /* we do not need more iovecs as we have sufficient memory */
-               if (outlen <= usedpages)
-                       break;
-
                iov_iter_advance(&msg->msg_iter, err);
        }
-       err = -EINVAL;
+
        /* ensure output buffer is sufficiently large */
-       if (usedpages < outlen)
-               goto free;
+       if (usedpages < outlen) {
+               err = -EINVAL;
+               goto unlock;
+       }
 
        aead_request_set_crypt(req, areq->tsgl, areq->first_rsgl.sgl.sg, used,
                               areq->iv);
@@ -561,6 +566,7 @@ static int aead_recvmsg_sync(struct socket *sock, struct msghdr *msg, int flags)
                        goto unlock;
        }
 
+       /* data length provided by caller via sendmsg/sendpage */
        used = ctx->used;
 
        /*
@@ -575,16 +581,27 @@ static int aead_recvmsg_sync(struct socket *sock, struct msghdr *msg, int flags)
        if (!aead_sufficient_data(ctx))
                goto unlock;
 
-       outlen = used;
+       /*
+        * Calculate the minimum output buffer size holding the result of the
+        * cipher operation. When encrypting data, the receiving buffer is
+        * larger by the tag length compared to the input buffer as the
+        * encryption operation generates the tag. For decryption, the input
+        * buffer provides the tag which is consumed resulting in only the
+        * plaintext without a buffer for the tag returned to the caller.
+        */
+       if (ctx->enc)
+               outlen = used + as;
+       else
+               outlen = used - as;
 
        /*
         * The cipher operation input data is reduced by the associated data
         * length as this data is processed separately later on.
         */
-       used -= ctx->aead_assoclen + (ctx->enc ? as : 0);
+       used -= ctx->aead_assoclen;
 
        /* convert iovecs of output buffers into scatterlists */
-       while (iov_iter_count(&msg->msg_iter)) {
+       while (outlen > usedpages && iov_iter_count(&msg->msg_iter)) {
                size_t seglen = min_t(size_t, iov_iter_count(&msg->msg_iter),
                                      (outlen - usedpages));
 
@@ -611,16 +628,14 @@ static int aead_recvmsg_sync(struct socket *sock, struct msghdr *msg, int flags)
 
                last_rsgl = rsgl;
 
-               /* we do not need more iovecs as we have sufficient memory */
-               if (outlen <= usedpages)
-                       break;
                iov_iter_advance(&msg->msg_iter, err);
        }
 
-       err = -EINVAL;
        /* ensure output buffer is sufficiently large */
-       if (usedpages < outlen)
+       if (usedpages < outlen) {
+               err = -EINVAL;
                goto unlock;
+       }
 
        sg_mark_end(sgl->sg + sgl->cur - 1);
        aead_request_set_crypt(&ctx->aead_req, sgl->sg, ctx->first_rsgl.sgl.sg,