diff --git a/arch/arm64/crypto/aes-ce-ccm-glue.c b/arch/arm64/crypto/aes-ce-ccm-glue.c
index 25cd3808ecbe6757551964147a7c636ead0fc199..ce9b28e3c7d63462c02827b177dc22633edb5474 100644
--- a/arch/arm64/crypto/aes-ce-ccm-glue.c
+++ b/arch/arm64/crypto/aes-ce-ccm-glue.c
@@ -1,8 +1,11 @@
 // SPDX-License-Identifier: GPL-2.0-only
 /*
- * aes-ccm-glue.c - AES-CCM transform for ARMv8 with Crypto Extensions
+ * aes-ce-ccm-glue.c - AES-CCM transform for ARMv8 with Crypto Extensions
  *
- * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
+ * Copyright (C) 2013 - 2017 Linaro Ltd.
+ * Copyright (C) 2024 Google LLC
+ *
+ * Author: Ard Biesheuvel <ardb@kernel.org>
  */
 
 #include <asm/neon.h>
@@ -15,6 +18,8 @@
 
 #include "aes-ce-setkey.h"
 
+MODULE_IMPORT_NS(CRYPTO_INTERNAL);
+
 static int num_rounds(struct crypto_aes_ctx *ctx)
 {
        /*
@@ -27,19 +32,17 @@ static int num_rounds(struct crypto_aes_ctx *ctx)
        return 6 + ctx->key_length / 4;
 }
 
-asmlinkage u32 ce_aes_ccm_auth_data(u8 mac[], u8 const in[], u32 abytes,
-                                   u32 macp, u32 const rk[], u32 rounds);
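+/*
+ * The asm helper may stop early so that the NEON unit can be yielded; it
+ * returns the number of blocks left unprocessed. enc_before and enc_after
+ * select whether the MAC is encrypted before and/or after the update.
+ */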
+asmlinkage u32 ce_aes_mac_update(u8 const in[], u32 const rk[], int rounds,
+                                int blocks, u8 dg[], int enc_before,
+                                int enc_after);
 
 asmlinkage void ce_aes_ccm_encrypt(u8 out[], u8 const in[], u32 cbytes,
                                   u32 const rk[], u32 rounds, u8 mac[],
-                                  u8 ctr[]);
+                                  u8 ctr[], u8 const final_iv[]);
 
 asmlinkage void ce_aes_ccm_decrypt(u8 out[], u8 const in[], u32 cbytes,
                                   u32 const rk[], u32 rounds, u8 mac[],
-                                  u8 ctr[]);
-
-asmlinkage void ce_aes_ccm_final(u8 mac[], u8 const ctr[], u32 const rk[],
-                                u32 rounds);
+                                  u8 ctr[], u8 const final_iv[]);
 
 static int ccm_setkey(struct crypto_aead *tfm, const u8 *in_key,
                      unsigned int key_len)
@@ -94,6 +97,41 @@ static int ccm_init_mac(struct aead_request *req, u8 maciv[], u32 msglen)
        return 0;
 }
 
+static u32 ce_aes_ccm_auth_data(u8 mac[], u8 const in[], u32 abytes,
+                               u32 macp, u32 const rk[], u32 rounds)
+{
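+       /*
+        * Nonzero if the data ends in a partial block, in which case the MAC
+        * is encrypted after the last full block so the tail bytes can be
+        * xor'ed into a fresh MAC state.
+        */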
+       int enc_after = (macp + abytes) % AES_BLOCK_SIZE;
+
+       do {
+               u32 blocks = abytes / AES_BLOCK_SIZE;
+
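+               /*
+                * Use the asm helper when a full block is pending encryption
+                * or when whole blocks can be absorbed into a fresh MAC.
+                */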
+               if (macp == AES_BLOCK_SIZE || (!macp && blocks > 0)) {
+                       u32 rem = ce_aes_mac_update(in, rk, rounds, blocks, mac,
+                                                   macp, enc_after);
+                       u32 adv = (blocks - rem) * AES_BLOCK_SIZE;
+
+                       macp = enc_after ? 0 : AES_BLOCK_SIZE;
+                       in += adv;
+                       abytes -= adv;
+
+                       if (unlikely(rem)) {
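+                               /* asm code yielded early: allow preemption */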
+                               kernel_neon_end();
+                               kernel_neon_begin();
+                               macp = 0;
+                       }
+               } else {
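+                       /* xor the input into the partially filled MAC block */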
+                       u32 l = min(AES_BLOCK_SIZE - macp, abytes);
+
+                       crypto_xor(&mac[macp], in, l);
+                       in += l;
+                       macp += l;
+                       abytes -= l;
+               }
+       } while (abytes > 0);
+
+       return macp;
+}
+
 static void ccm_calculate_auth_mac(struct aead_request *req, u8 mac[])
 {
        struct crypto_aead *aead = crypto_aead_reqtfm(req);
@@ -101,7 +139,7 @@ static void ccm_calculate_auth_mac(struct aead_request *req, u8 mac[])
        struct __packed { __be16 l; __be32 h; u16 len; } ltag;
        struct scatter_walk walk;
        u32 len = req->assoclen;
-       u32 macp = 0;
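+       /* mac[] already holds B0 from ccm_init_mac(), still to be encrypted */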
+       u32 macp = AES_BLOCK_SIZE;
 
        /* prepend the AAD with a length tag */
        if (len < 0xff00) {
@@ -125,16 +163,11 @@ static void ccm_calculate_auth_mac(struct aead_request *req, u8 mac[])
                        scatterwalk_start(&walk, sg_next(walk.sg));
                        n = scatterwalk_clamp(&walk, len);
                }
-               n = min_t(u32, n, SZ_4K); /* yield NEON at least every 4k */
                p = scatterwalk_map(&walk);
 
                macp = ce_aes_ccm_auth_data(mac, p, n, macp, ctx->key_enc,
                                            num_rounds(ctx));
 
-               if (len / SZ_4K > (len - n) / SZ_4K) {
-                       kernel_neon_end();
-                       kernel_neon_begin();
-               }
                len -= n;
 
                scatterwalk_unmap(p);
@@ -149,7 +182,7 @@ static int ccm_encrypt(struct aead_request *req)
        struct crypto_aes_ctx *ctx = crypto_aead_ctx(aead);
        struct skcipher_walk walk;
        u8 __aligned(8) mac[AES_BLOCK_SIZE];
-       u8 buf[AES_BLOCK_SIZE];
+       u8 orig_iv[AES_BLOCK_SIZE];
        u32 len = req->cryptlen;
        int err;
 
@@ -158,42 +191,55 @@ static int ccm_encrypt(struct aead_request *req)
                return err;
 
        /* preserve the original iv for the final round */
-       memcpy(buf, req->iv, AES_BLOCK_SIZE);
+       memcpy(orig_iv, req->iv, AES_BLOCK_SIZE);
 
        err = skcipher_walk_aead_encrypt(&walk, req, false);
+       if (unlikely(err))
+               return err;
 
        kernel_neon_begin();
 
        if (req->assoclen)
                ccm_calculate_auth_mac(req, mac);
 
-       while (walk.nbytes) {
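+       /*
+        * Enter the loop at least once so that the MAC is always finalized,
+        * even when there is no payload to encrypt.
+        */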
+       do {
                u32 tail = walk.nbytes % AES_BLOCK_SIZE;
-               bool final = walk.nbytes == walk.total;
+               const u8 *src = walk.src.virt.addr;
+               u8 *dst = walk.dst.virt.addr;
+               u8 buf[AES_BLOCK_SIZE];
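+               /* non-NULL only on the last call, replacing ce_aes_ccm_final() */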
+               u8 *final_iv = NULL;
 
-               if (final)
+               if (walk.nbytes == walk.total) {
                        tail = 0;
+                       final_iv = orig_iv;
+               }
 
-               ce_aes_ccm_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                                  walk.nbytes - tail, ctx->key_enc,
-                                  num_rounds(ctx), mac, walk.iv);
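+               /*
+                * Bounce sub-block inputs through the end of a stack buffer
+                * so the asm code can use full block loads and stores without
+                * going out of bounds.
+                */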
+               if (unlikely(walk.nbytes < AES_BLOCK_SIZE))
+                       src = dst = memcpy(&buf[sizeof(buf) - walk.nbytes],
+                                          src, walk.nbytes);
 
-               if (!final)
-                       kernel_neon_end();
-               err = skcipher_walk_done(&walk, tail);
-               if (!final)
-                       kernel_neon_begin();
-       }
+               ce_aes_ccm_encrypt(dst, src, walk.nbytes - tail,
+                                  ctx->key_enc, num_rounds(ctx),
+                                  mac, walk.iv, final_iv);
+
+               if (unlikely(walk.nbytes < AES_BLOCK_SIZE))
+                       memcpy(walk.dst.virt.addr, dst, walk.nbytes);
 
-       ce_aes_ccm_final(mac, buf, ctx->key_enc, num_rounds(ctx));
+               if (walk.nbytes)
+                       err = skcipher_walk_done(&walk, tail);
+       } while (walk.nbytes);
 
        kernel_neon_end();
 
+       if (unlikely(err))
+               return err;
+
        /* copy authtag to end of dst */
        scatterwalk_map_and_copy(mac, req->dst, req->assoclen + req->cryptlen,
                                 crypto_aead_authsize(aead), 1);
 
-       return err;
+       return 0;
 }
 
 static int ccm_decrypt(struct aead_request *req)
@@ -203,7 +249,7 @@ static int ccm_decrypt(struct aead_request *req)
        unsigned int authsize = crypto_aead_authsize(aead);
        struct skcipher_walk walk;
        u8 __aligned(8) mac[AES_BLOCK_SIZE];
-       u8 buf[AES_BLOCK_SIZE];
+       u8 orig_iv[AES_BLOCK_SIZE];
        u32 len = req->cryptlen - authsize;
        int err;
 
@@ -212,34 +258,44 @@ static int ccm_decrypt(struct aead_request *req)
                return err;
 
        /* preserve the original iv for the final round */
-       memcpy(buf, req->iv, AES_BLOCK_SIZE);
+       memcpy(orig_iv, req->iv, AES_BLOCK_SIZE);
 
        err = skcipher_walk_aead_decrypt(&walk, req, false);
+       if (unlikely(err))
+               return err;
 
        kernel_neon_begin();
 
        if (req->assoclen)
                ccm_calculate_auth_mac(req, mac);
 
-       while (walk.nbytes) {
+       do {
                u32 tail = walk.nbytes % AES_BLOCK_SIZE;
-               bool final = walk.nbytes == walk.total;
+               const u8 *src = walk.src.virt.addr;
+               u8 *dst = walk.dst.virt.addr;
+               u8 buf[AES_BLOCK_SIZE];
+               u8 *final_iv = NULL;
 
-               if (final)
+               if (walk.nbytes == walk.total) {
                        tail = 0;
+                       final_iv = orig_iv;
+               }
 
-               ce_aes_ccm_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                                  walk.nbytes - tail, ctx->key_enc,
-                                  num_rounds(ctx), mac, walk.iv);
+               if (unlikely(walk.nbytes < AES_BLOCK_SIZE))
+                       src = dst = memcpy(&buf[sizeof(buf) - walk.nbytes],
+                                          src, walk.nbytes);
 
-               if (!final)
-                       kernel_neon_end();
-               err = skcipher_walk_done(&walk, tail);
-               if (!final)
-                       kernel_neon_begin();
-       }
+               ce_aes_ccm_decrypt(dst, src, walk.nbytes - tail,
+                                  ctx->key_enc, num_rounds(ctx),
+                                  mac, walk.iv, final_iv);
+
+               if (unlikely(walk.nbytes < AES_BLOCK_SIZE))
+                       memcpy(walk.dst.virt.addr, dst, walk.nbytes);
 
-       ce_aes_ccm_final(mac, buf, ctx->key_enc, num_rounds(ctx));
+               if (walk.nbytes)
+                       err = skcipher_walk_done(&walk, tail);
+       } while (walk.nbytes);
 
        kernel_neon_end();
 
@@ -247,11 +303,11 @@ static int ccm_decrypt(struct aead_request *req)
                return err;
 
        /* compare calculated auth tag with the stored one */
-       scatterwalk_map_and_copy(buf, req->src,
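+       /* orig_iv is no longer needed: reuse it to hold the received tag */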
+       scatterwalk_map_and_copy(orig_iv, req->src,
                                 req->assoclen + req->cryptlen - authsize,
                                 authsize, 0);
 
-       if (crypto_memneq(mac, buf, authsize))
+       if (crypto_memneq(mac, orig_iv, authsize))
                return -EBADMSG;
        return 0;
 }
@@ -290,6 +346,6 @@ module_init(aes_mod_init);
 module_exit(aes_mod_exit);
 
 MODULE_DESCRIPTION("Synchronous AES in CCM mode using ARMv8 Crypto Extensions");
-MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
+MODULE_AUTHOR("Ard Biesheuvel <ardb@kernel.org>");
 MODULE_LICENSE("GPL v2");
 MODULE_ALIAS_CRYPTO("ccm(aes)");