crypto: arm64/aes-ce-gcm - operate on two input blocks at a time
authorArd Biesheuvel <ard.biesheuvel@linaro.org>
Mon, 30 Jul 2018 21:06:40 +0000 (23:06 +0200)
committerHerbert Xu <herbert@gondor.apana.org.au>
Tue, 7 Aug 2018 09:38:04 +0000 (17:38 +0800)
Update the core AES/GCM transform and the associated plumbing to operate
on 2 AES/GHASH blocks at a time. By itself, this is not expected to
result in a noticeable speedup, but it paves the way for reimplementing
the GHASH component using 2-way aggregation.

Signed-off-by: Ard Biesheuvel <ard.biesheuvel@linaro.org>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
arch/arm64/crypto/ghash-ce-core.S
arch/arm64/crypto/ghash-ce-glue.c

index c723647b37db0387f58d3ea88f899147fdbc2727..dac0df29d19485cb279b9fc728f8d3243718b1dd 100644 (file)
@@ -286,9 +286,10 @@ ENTRY(pmull_ghash_update_p8)
        __pmull_ghash   p8
 ENDPROC(pmull_ghash_update_p8)
 
-       KS              .req    v8
-       CTR             .req    v9
-       INP             .req    v10
+       KS0             .req    v8
+       KS1             .req    v9
+       INP0            .req    v10
+       INP1            .req    v11
 
        .macro          load_round_keys, rounds, rk
        cmp             \rounds, #12
@@ -336,84 +337,146 @@ CPU_LE(  rev             x8, x8          )
 
        .if             \enc == 1
        ldr             x10, [sp]
-       ld1             {KS.16b}, [x10]
+       ld1             {KS0.16b-KS1.16b}, [x10]
        .endif
 
-0:     ld1             {CTR.8b}, [x5]                  // load upper counter
-       ld1             {INP.16b}, [x3], #16
+0:     ld1             {INP0.16b-INP1.16b}, [x3], #32
+
        rev             x9, x8
-       add             x8, x8, #1
-       sub             w0, w0, #1
-       ins             CTR.d[1], x9                    // set lower counter
+       add             x11, x8, #1
+       add             x8, x8, #2
 
        .if             \enc == 1
-       eor             INP.16b, INP.16b, KS.16b        // encrypt input
-       st1             {INP.16b}, [x2], #16
+       eor             INP0.16b, INP0.16b, KS0.16b     // encrypt input
+       eor             INP1.16b, INP1.16b, KS1.16b
        .endif
 
-       rev64           T1.16b, INP.16b
+       ld1             {KS0.8b}, [x5]                  // load upper counter
+       rev             x11, x11
+       sub             w0, w0, #2
+       mov             KS1.8b, KS0.8b
+       ins             KS0.d[1], x9                    // set lower counter
+       ins             KS1.d[1], x11
+
+       rev64           T1.16b, INP0.16b
 
        cmp             w7, #12
        b.ge            2f                              // AES-192/256?
 
-1:     enc_round       CTR, v21
+1:     enc_round       KS0, v21
 
        ext             T2.16b, XL.16b, XL.16b, #8
        ext             IN1.16b, T1.16b, T1.16b, #8
 
-       enc_round       CTR, v22
+       enc_round       KS1, v21
 
        eor             T1.16b, T1.16b, T2.16b
        eor             XL.16b, XL.16b, IN1.16b
 
-       enc_round       CTR, v23
+       enc_round       KS0, v22
 
        pmull2          XH.1q, SHASH.2d, XL.2d          // a1 * b1
        eor             T1.16b, T1.16b, XL.16b
 
-       enc_round       CTR, v24
+       enc_round       KS1, v22
 
        pmull           XL.1q, SHASH.1d, XL.1d          // a0 * b0
        pmull           XM.1q, SHASH2.1d, T1.1d         // (a1 + a0)(b1 + b0)
 
-       enc_round       CTR, v25
+       enc_round       KS0, v23
 
        ext             T1.16b, XL.16b, XH.16b, #8
        eor             T2.16b, XL.16b, XH.16b
        eor             XM.16b, XM.16b, T1.16b
 
-       enc_round       CTR, v26
+       enc_round       KS1, v23
 
        eor             XM.16b, XM.16b, T2.16b
        pmull           T2.1q, XL.1d, MASK.1d
 
-       enc_round       CTR, v27
+       enc_round       KS0, v24
 
        mov             XH.d[0], XM.d[1]
        mov             XM.d[1], XL.d[0]
 
-       enc_round       CTR, v28
+       enc_round       KS1, v24
 
        eor             XL.16b, XM.16b, T2.16b
 
-       enc_round       CTR, v29
+       enc_round       KS0, v25
 
        ext             T2.16b, XL.16b, XL.16b, #8
 
-       aese            CTR.16b, v30.16b
+       enc_round       KS1, v25
 
        pmull           XL.1q, XL.1d, MASK.1d
        eor             T2.16b, T2.16b, XH.16b
 
-       eor             KS.16b, CTR.16b, v31.16b
+       enc_round       KS0, v26
+
+       eor             XL.16b, XL.16b, T2.16b
+       rev64           T1.16b, INP1.16b
+
+       enc_round       KS1, v26
+
+       ext             T2.16b, XL.16b, XL.16b, #8
+       ext             IN1.16b, T1.16b, T1.16b, #8
+
+       enc_round       KS0, v27
+
+       eor             T1.16b, T1.16b, T2.16b
+       eor             XL.16b, XL.16b, IN1.16b
+
+       enc_round       KS1, v27
+
+       pmull2          XH.1q, SHASH.2d, XL.2d          // a1 * b1
+       eor             T1.16b, T1.16b, XL.16b
+
+       enc_round       KS0, v28
+
+       pmull           XL.1q, SHASH.1d, XL.1d          // a0 * b0
+       pmull           XM.1q, SHASH2.1d, T1.1d         // (a1 + a0)(b1 + b0)
+
+       enc_round       KS1, v28
+
+       ext             T1.16b, XL.16b, XH.16b, #8
+       eor             T2.16b, XL.16b, XH.16b
+       eor             XM.16b, XM.16b, T1.16b
+
+       enc_round       KS0, v29
+
+       eor             XM.16b, XM.16b, T2.16b
+       pmull           T2.1q, XL.1d, MASK.1d
+
+       enc_round       KS1, v29
+
+       mov             XH.d[0], XM.d[1]
+       mov             XM.d[1], XL.d[0]
+
+       aese            KS0.16b, v30.16b
+
+       eor             XL.16b, XM.16b, T2.16b
+
+       aese            KS1.16b, v30.16b
+
+       ext             T2.16b, XL.16b, XL.16b, #8
+
+       eor             KS0.16b, KS0.16b, v31.16b
+
+       pmull           XL.1q, XL.1d, MASK.1d
+       eor             T2.16b, T2.16b, XH.16b
+
+       eor             KS1.16b, KS1.16b, v31.16b
 
        eor             XL.16b, XL.16b, T2.16b
 
        .if             \enc == 0
-       eor             INP.16b, INP.16b, KS.16b
-       st1             {INP.16b}, [x2], #16
+       eor             INP0.16b, INP0.16b, KS0.16b
+       eor             INP1.16b, INP1.16b, KS1.16b
        .endif
 
+       st1             {INP0.16b-INP1.16b}, [x2], #32
+
        cbnz            w0, 0b
 
 CPU_LE(        rev             x8, x8          )
@@ -421,16 +484,20 @@ CPU_LE(   rev             x8, x8          )
        str             x8, [x5, #8]                    // store lower counter
 
        .if             \enc == 1
-       st1             {KS.16b}, [x10]
+       st1             {KS0.16b-KS1.16b}, [x10]
        .endif
 
        ret
 
 2:     b.eq            3f                              // AES-192?
-       enc_round       CTR, v17
-       enc_round       CTR, v18
-3:     enc_round       CTR, v19
-       enc_round       CTR, v20
+       enc_round       KS0, v17
+       enc_round       KS1, v17
+       enc_round       KS0, v18
+       enc_round       KS1, v18
+3:     enc_round       KS0, v19
+       enc_round       KS1, v19
+       enc_round       KS0, v20
+       enc_round       KS1, v20
        b               1b
        .endm
 
index 18b4d1b96a7e914c7176c464949692c3033be994..8ff6732c4fb580c65b97391110c73db4b7a692e6 100644 (file)
@@ -348,9 +348,10 @@ static int gcm_encrypt(struct aead_request *req)
        struct gcm_aes_ctx *ctx = crypto_aead_ctx(aead);
        struct skcipher_walk walk;
        u8 iv[AES_BLOCK_SIZE];
-       u8 ks[AES_BLOCK_SIZE];
+       u8 ks[2 * AES_BLOCK_SIZE];
        u8 tag[AES_BLOCK_SIZE];
        u64 dg[2] = {};
+       int nrounds = num_rounds(&ctx->aes_key);
        int err;
 
        if (req->assoclen)
@@ -362,32 +363,31 @@ static int gcm_encrypt(struct aead_request *req)
        if (likely(may_use_simd())) {
                kernel_neon_begin();
 
-               pmull_gcm_encrypt_block(tag, iv, ctx->aes_key.key_enc,
-                                       num_rounds(&ctx->aes_key));
+               pmull_gcm_encrypt_block(tag, iv, ctx->aes_key.key_enc, nrounds);
                put_unaligned_be32(2, iv + GCM_IV_SIZE);
-               pmull_gcm_encrypt_block(ks, iv, NULL,
-                                       num_rounds(&ctx->aes_key));
+               pmull_gcm_encrypt_block(ks, iv, NULL, nrounds);
                put_unaligned_be32(3, iv + GCM_IV_SIZE);
+               pmull_gcm_encrypt_block(ks + AES_BLOCK_SIZE, iv, NULL, nrounds);
+               put_unaligned_be32(4, iv + GCM_IV_SIZE);
                kernel_neon_end();
 
                err = skcipher_walk_aead_encrypt(&walk, req, false);
 
-               while (walk.nbytes >= AES_BLOCK_SIZE) {
-                       int blocks = walk.nbytes / AES_BLOCK_SIZE;
+               while (walk.nbytes >= 2 * AES_BLOCK_SIZE) {
+                       int blocks = walk.nbytes / (2 * AES_BLOCK_SIZE) * 2;
 
                        kernel_neon_begin();
                        pmull_gcm_encrypt(blocks, dg, walk.dst.virt.addr,
                                          walk.src.virt.addr, &ctx->ghash_key,
-                                         iv, ctx->aes_key.key_enc,
-                                         num_rounds(&ctx->aes_key), ks);
+                                         iv, ctx->aes_key.key_enc, nrounds,
+                                         ks);
                        kernel_neon_end();
 
                        err = skcipher_walk_done(&walk,
-                                                walk.nbytes % AES_BLOCK_SIZE);
+                                       walk.nbytes % (2 * AES_BLOCK_SIZE));
                }
        } else {
-               __aes_arm64_encrypt(ctx->aes_key.key_enc, tag, iv,
-                                   num_rounds(&ctx->aes_key));
+               __aes_arm64_encrypt(ctx->aes_key.key_enc, tag, iv, nrounds);
                put_unaligned_be32(2, iv + GCM_IV_SIZE);
 
                err = skcipher_walk_aead_encrypt(&walk, req, false);
@@ -399,8 +399,7 @@ static int gcm_encrypt(struct aead_request *req)
 
                        do {
                                __aes_arm64_encrypt(ctx->aes_key.key_enc,
-                                                   ks, iv,
-                                                   num_rounds(&ctx->aes_key));
+                                                   ks, iv, nrounds);
                                crypto_xor_cpy(dst, src, ks, AES_BLOCK_SIZE);
                                crypto_inc(iv, AES_BLOCK_SIZE);
 
@@ -417,19 +416,28 @@ static int gcm_encrypt(struct aead_request *req)
                }
                if (walk.nbytes)
                        __aes_arm64_encrypt(ctx->aes_key.key_enc, ks, iv,
-                                           num_rounds(&ctx->aes_key));
+                                           nrounds);
        }
 
        /* handle the tail */
        if (walk.nbytes) {
                u8 buf[GHASH_BLOCK_SIZE];
+               unsigned int nbytes = walk.nbytes;
+               u8 *dst = walk.dst.virt.addr;
+               u8 *head = NULL;
 
                crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr, ks,
                               walk.nbytes);
 
-               memcpy(buf, walk.dst.virt.addr, walk.nbytes);
-               memset(buf + walk.nbytes, 0, GHASH_BLOCK_SIZE - walk.nbytes);
-               ghash_do_update(1, dg, buf, &ctx->ghash_key, NULL);
+               if (walk.nbytes > GHASH_BLOCK_SIZE) {
+                       head = dst;
+                       dst += GHASH_BLOCK_SIZE;
+                       nbytes %= GHASH_BLOCK_SIZE;
+               }
+
+               memcpy(buf, dst, nbytes);
+               memset(buf + nbytes, 0, GHASH_BLOCK_SIZE - nbytes);
+               ghash_do_update(!!nbytes, dg, buf, &ctx->ghash_key, head);
 
                err = skcipher_walk_done(&walk, 0);
        }
@@ -452,10 +460,11 @@ static int gcm_decrypt(struct aead_request *req)
        struct gcm_aes_ctx *ctx = crypto_aead_ctx(aead);
        unsigned int authsize = crypto_aead_authsize(aead);
        struct skcipher_walk walk;
-       u8 iv[AES_BLOCK_SIZE];
+       u8 iv[2 * AES_BLOCK_SIZE];
        u8 tag[AES_BLOCK_SIZE];
-       u8 buf[GHASH_BLOCK_SIZE];
+       u8 buf[2 * GHASH_BLOCK_SIZE];
        u64 dg[2] = {};
+       int nrounds = num_rounds(&ctx->aes_key);
        int err;
 
        if (req->assoclen)
@@ -466,37 +475,44 @@ static int gcm_decrypt(struct aead_request *req)
 
        if (likely(may_use_simd())) {
                kernel_neon_begin();
-
-               pmull_gcm_encrypt_block(tag, iv, ctx->aes_key.key_enc,
-                                       num_rounds(&ctx->aes_key));
+               pmull_gcm_encrypt_block(tag, iv, ctx->aes_key.key_enc, nrounds);
                put_unaligned_be32(2, iv + GCM_IV_SIZE);
                kernel_neon_end();
 
                err = skcipher_walk_aead_decrypt(&walk, req, false);
 
-               while (walk.nbytes >= AES_BLOCK_SIZE) {
-                       int blocks = walk.nbytes / AES_BLOCK_SIZE;
+               while (walk.nbytes >= 2 * AES_BLOCK_SIZE) {
+                       int blocks = walk.nbytes / (2 * AES_BLOCK_SIZE) * 2;
 
                        kernel_neon_begin();
                        pmull_gcm_decrypt(blocks, dg, walk.dst.virt.addr,
                                          walk.src.virt.addr, &ctx->ghash_key,
-                                         iv, ctx->aes_key.key_enc,
-                                         num_rounds(&ctx->aes_key));
+                                         iv, ctx->aes_key.key_enc, nrounds);
                        kernel_neon_end();
 
                        err = skcipher_walk_done(&walk,
-                                                walk.nbytes % AES_BLOCK_SIZE);
+                                       walk.nbytes % (2 * AES_BLOCK_SIZE));
                }
+
                if (walk.nbytes) {
+                       u8 *iv2 = iv + AES_BLOCK_SIZE;
+
+                       if (walk.nbytes > AES_BLOCK_SIZE) {
+                               memcpy(iv2, iv, AES_BLOCK_SIZE);
+                               crypto_inc(iv2, AES_BLOCK_SIZE);
+                       }
+
                        kernel_neon_begin();
                        pmull_gcm_encrypt_block(iv, iv, ctx->aes_key.key_enc,
-                                               num_rounds(&ctx->aes_key));
+                                               nrounds);
+
+                       if (walk.nbytes > AES_BLOCK_SIZE)
+                               pmull_gcm_encrypt_block(iv2, iv2, NULL,
+                                                       nrounds);
                        kernel_neon_end();
                }
-
        } else {
-               __aes_arm64_encrypt(ctx->aes_key.key_enc, tag, iv,
-                                   num_rounds(&ctx->aes_key));
+               __aes_arm64_encrypt(ctx->aes_key.key_enc, tag, iv, nrounds);
                put_unaligned_be32(2, iv + GCM_IV_SIZE);
 
                err = skcipher_walk_aead_decrypt(&walk, req, false);
@@ -511,8 +527,7 @@ static int gcm_decrypt(struct aead_request *req)
 
                        do {
                                __aes_arm64_encrypt(ctx->aes_key.key_enc,
-                                                   buf, iv,
-                                                   num_rounds(&ctx->aes_key));
+                                                   buf, iv, nrounds);
                                crypto_xor_cpy(dst, src, buf, AES_BLOCK_SIZE);
                                crypto_inc(iv, AES_BLOCK_SIZE);
 
@@ -525,14 +540,24 @@ static int gcm_decrypt(struct aead_request *req)
                }
                if (walk.nbytes)
                        __aes_arm64_encrypt(ctx->aes_key.key_enc, iv, iv,
-                                           num_rounds(&ctx->aes_key));
+                                           nrounds);
        }
 
        /* handle the tail */
        if (walk.nbytes) {
-               memcpy(buf, walk.src.virt.addr, walk.nbytes);
-               memset(buf + walk.nbytes, 0, GHASH_BLOCK_SIZE - walk.nbytes);
-               ghash_do_update(1, dg, buf, &ctx->ghash_key, NULL);
+               const u8 *src = walk.src.virt.addr;
+               const u8 *head = NULL;
+               unsigned int nbytes = walk.nbytes;
+
+               if (walk.nbytes > GHASH_BLOCK_SIZE) {
+                       head = src;
+                       src += GHASH_BLOCK_SIZE;
+                       nbytes %= GHASH_BLOCK_SIZE;
+               }
+
+               memcpy(buf, src, nbytes);
+               memset(buf + nbytes, 0, GHASH_BLOCK_SIZE - nbytes);
+               ghash_do_update(!!nbytes, dg, buf, &ctx->ghash_key, head);
 
                crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr, iv,
                               walk.nbytes);
@@ -557,7 +582,7 @@ static int gcm_decrypt(struct aead_request *req)
 
 static struct aead_alg gcm_aes_alg = {
        .ivsize                 = GCM_IV_SIZE,
-       .chunksize              = AES_BLOCK_SIZE,
+       .chunksize              = 2 * AES_BLOCK_SIZE,
        .maxauthsize            = AES_BLOCK_SIZE,
        .setkey                 = gcm_setkey,
        .setauthsize            = gcm_setauthsize,