Merge git://git.kernel.org/pub/scm/linux/kernel/git/pablo/nf
[sfrench/cifs-2.6.git] / arch / arm64 / crypto / aes-glue.c
1 /*
2  * linux/arch/arm64/crypto/aes-glue.c - wrapper code for ARMv8 AES
3  *
4  * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License version 2 as
8  * published by the Free Software Foundation.
9  */
10
11 #include <asm/neon.h>
12 #include <asm/hwcap.h>
13 #include <asm/simd.h>
14 #include <crypto/aes.h>
15 #include <crypto/internal/hash.h>
16 #include <crypto/internal/simd.h>
17 #include <crypto/internal/skcipher.h>
18 #include <crypto/scatterwalk.h>
19 #include <linux/module.h>
20 #include <linux/cpufeature.h>
21 #include <crypto/xts.h>
22
23 #include "aes-ce-setkey.h"
24 #include "aes-ctr-fallback.h"
25
26 #ifdef USE_V8_CRYPTO_EXTENSIONS
27 #define MODE                    "ce"
28 #define PRIO                    300
29 #define aes_setkey              ce_aes_setkey
30 #define aes_expandkey           ce_aes_expandkey
31 #define aes_ecb_encrypt         ce_aes_ecb_encrypt
32 #define aes_ecb_decrypt         ce_aes_ecb_decrypt
33 #define aes_cbc_encrypt         ce_aes_cbc_encrypt
34 #define aes_cbc_decrypt         ce_aes_cbc_decrypt
35 #define aes_cbc_cts_encrypt     ce_aes_cbc_cts_encrypt
36 #define aes_cbc_cts_decrypt     ce_aes_cbc_cts_decrypt
37 #define aes_ctr_encrypt         ce_aes_ctr_encrypt
38 #define aes_xts_encrypt         ce_aes_xts_encrypt
39 #define aes_xts_decrypt         ce_aes_xts_decrypt
40 #define aes_mac_update          ce_aes_mac_update
41 MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
42 #else
43 #define MODE                    "neon"
44 #define PRIO                    200
45 #define aes_setkey              crypto_aes_set_key
46 #define aes_expandkey           crypto_aes_expand_key
47 #define aes_ecb_encrypt         neon_aes_ecb_encrypt
48 #define aes_ecb_decrypt         neon_aes_ecb_decrypt
49 #define aes_cbc_encrypt         neon_aes_cbc_encrypt
50 #define aes_cbc_decrypt         neon_aes_cbc_decrypt
51 #define aes_cbc_cts_encrypt     neon_aes_cbc_cts_encrypt
52 #define aes_cbc_cts_decrypt     neon_aes_cbc_cts_decrypt
53 #define aes_ctr_encrypt         neon_aes_ctr_encrypt
54 #define aes_xts_encrypt         neon_aes_xts_encrypt
55 #define aes_xts_decrypt         neon_aes_xts_decrypt
56 #define aes_mac_update          neon_aes_mac_update
57 MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 NEON");
58 MODULE_ALIAS_CRYPTO("ecb(aes)");
59 MODULE_ALIAS_CRYPTO("cbc(aes)");
60 MODULE_ALIAS_CRYPTO("ctr(aes)");
61 MODULE_ALIAS_CRYPTO("xts(aes)");
62 MODULE_ALIAS_CRYPTO("cmac(aes)");
63 MODULE_ALIAS_CRYPTO("xcbc(aes)");
64 MODULE_ALIAS_CRYPTO("cbcmac(aes)");
65 #endif
66
67 MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
68 MODULE_LICENSE("GPL v2");
69
70 /* defined in aes-modes.S */
71 asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
72                                 int rounds, int blocks);
73 asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
74                                 int rounds, int blocks);
75
76 asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
77                                 int rounds, int blocks, u8 iv[]);
78 asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
79                                 int rounds, int blocks, u8 iv[]);
80
81 asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
82                                 int rounds, int bytes, u8 const iv[]);
83 asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
84                                 int rounds, int bytes, u8 const iv[]);
85
86 asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
87                                 int rounds, int blocks, u8 ctr[]);
88
89 asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
90                                 int rounds, int blocks, u32 const rk2[], u8 iv[],
91                                 int first);
92 asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
93                                 int rounds, int blocks, u32 const rk2[], u8 iv[],
94                                 int first);
95
96 asmlinkage void aes_mac_update(u8 const in[], u32 const rk[], int rounds,
97                                int blocks, u8 dg[], int enc_before,
98                                int enc_after);
99
100 struct cts_cbc_req_ctx {
101         struct scatterlist sg_src[2];
102         struct scatterlist sg_dst[2];
103         struct skcipher_request subreq;
104 };
105
106 struct crypto_aes_xts_ctx {
107         struct crypto_aes_ctx key1;
108         struct crypto_aes_ctx __aligned(8) key2;
109 };
110
111 struct mac_tfm_ctx {
112         struct crypto_aes_ctx key;
113         u8 __aligned(8) consts[];
114 };
115
116 struct mac_desc_ctx {
117         unsigned int len;
118         u8 dg[AES_BLOCK_SIZE];
119 };
120
121 static int skcipher_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
122                                unsigned int key_len)
123 {
124         return aes_setkey(crypto_skcipher_tfm(tfm), in_key, key_len);
125 }
126
127 static int xts_set_key(struct crypto_skcipher *tfm, const u8 *in_key,
128                        unsigned int key_len)
129 {
130         struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
131         int ret;
132
133         ret = xts_verify_key(tfm, in_key, key_len);
134         if (ret)
135                 return ret;
136
137         ret = aes_expandkey(&ctx->key1, in_key, key_len / 2);
138         if (!ret)
139                 ret = aes_expandkey(&ctx->key2, &in_key[key_len / 2],
140                                     key_len / 2);
141         if (!ret)
142                 return 0;
143
144         crypto_skcipher_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);
145         return -EINVAL;
146 }
147
148 static int ecb_encrypt(struct skcipher_request *req)
149 {
150         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
151         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
152         int err, rounds = 6 + ctx->key_length / 4;
153         struct skcipher_walk walk;
154         unsigned int blocks;
155
156         err = skcipher_walk_virt(&walk, req, false);
157
158         while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
159                 kernel_neon_begin();
160                 aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
161                                 ctx->key_enc, rounds, blocks);
162                 kernel_neon_end();
163                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
164         }
165         return err;
166 }
167
168 static int ecb_decrypt(struct skcipher_request *req)
169 {
170         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
171         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
172         int err, rounds = 6 + ctx->key_length / 4;
173         struct skcipher_walk walk;
174         unsigned int blocks;
175
176         err = skcipher_walk_virt(&walk, req, false);
177
178         while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
179                 kernel_neon_begin();
180                 aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
181                                 ctx->key_dec, rounds, blocks);
182                 kernel_neon_end();
183                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
184         }
185         return err;
186 }
187
188 static int cbc_encrypt(struct skcipher_request *req)
189 {
190         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
191         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
192         int err, rounds = 6 + ctx->key_length / 4;
193         struct skcipher_walk walk;
194         unsigned int blocks;
195
196         err = skcipher_walk_virt(&walk, req, false);
197
198         while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
199                 kernel_neon_begin();
200                 aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
201                                 ctx->key_enc, rounds, blocks, walk.iv);
202                 kernel_neon_end();
203                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
204         }
205         return err;
206 }
207
208 static int cbc_decrypt(struct skcipher_request *req)
209 {
210         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
211         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
212         int err, rounds = 6 + ctx->key_length / 4;
213         struct skcipher_walk walk;
214         unsigned int blocks;
215
216         err = skcipher_walk_virt(&walk, req, false);
217
218         while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
219                 kernel_neon_begin();
220                 aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
221                                 ctx->key_dec, rounds, blocks, walk.iv);
222                 kernel_neon_end();
223                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
224         }
225         return err;
226 }
227
228 static int cts_cbc_init_tfm(struct crypto_skcipher *tfm)
229 {
230         crypto_skcipher_set_reqsize(tfm, sizeof(struct cts_cbc_req_ctx));
231         return 0;
232 }
233
234 static int cts_cbc_encrypt(struct skcipher_request *req)
235 {
236         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
237         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
238         struct cts_cbc_req_ctx *rctx = skcipher_request_ctx(req);
239         int err, rounds = 6 + ctx->key_length / 4;
240         int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
241         struct scatterlist *src = req->src, *dst = req->dst;
242         struct skcipher_walk walk;
243
244         skcipher_request_set_tfm(&rctx->subreq, tfm);
245
246         if (req->cryptlen <= AES_BLOCK_SIZE) {
247                 if (req->cryptlen < AES_BLOCK_SIZE)
248                         return -EINVAL;
249                 cbc_blocks = 1;
250         }
251
252         if (cbc_blocks > 0) {
253                 unsigned int blocks;
254
255                 skcipher_request_set_crypt(&rctx->subreq, req->src, req->dst,
256                                            cbc_blocks * AES_BLOCK_SIZE,
257                                            req->iv);
258
259                 err = skcipher_walk_virt(&walk, &rctx->subreq, false);
260
261                 while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
262                         kernel_neon_begin();
263                         aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
264                                         ctx->key_enc, rounds, blocks, walk.iv);
265                         kernel_neon_end();
266                         err = skcipher_walk_done(&walk,
267                                                  walk.nbytes % AES_BLOCK_SIZE);
268                 }
269                 if (err)
270                         return err;
271
272                 if (req->cryptlen == AES_BLOCK_SIZE)
273                         return 0;
274
275                 dst = src = scatterwalk_ffwd(rctx->sg_src, req->src,
276                                              rctx->subreq.cryptlen);
277                 if (req->dst != req->src)
278                         dst = scatterwalk_ffwd(rctx->sg_dst, req->dst,
279                                                rctx->subreq.cryptlen);
280         }
281
282         /* handle ciphertext stealing */
283         skcipher_request_set_crypt(&rctx->subreq, src, dst,
284                                    req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
285                                    req->iv);
286
287         err = skcipher_walk_virt(&walk, &rctx->subreq, false);
288         if (err)
289                 return err;
290
291         kernel_neon_begin();
292         aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
293                             ctx->key_enc, rounds, walk.nbytes, walk.iv);
294         kernel_neon_end();
295
296         return skcipher_walk_done(&walk, 0);
297 }
298
299 static int cts_cbc_decrypt(struct skcipher_request *req)
300 {
301         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
302         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
303         struct cts_cbc_req_ctx *rctx = skcipher_request_ctx(req);
304         int err, rounds = 6 + ctx->key_length / 4;
305         int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
306         struct scatterlist *src = req->src, *dst = req->dst;
307         struct skcipher_walk walk;
308
309         skcipher_request_set_tfm(&rctx->subreq, tfm);
310
311         if (req->cryptlen <= AES_BLOCK_SIZE) {
312                 if (req->cryptlen < AES_BLOCK_SIZE)
313                         return -EINVAL;
314                 cbc_blocks = 1;
315         }
316
317         if (cbc_blocks > 0) {
318                 unsigned int blocks;
319
320                 skcipher_request_set_crypt(&rctx->subreq, req->src, req->dst,
321                                            cbc_blocks * AES_BLOCK_SIZE,
322                                            req->iv);
323
324                 err = skcipher_walk_virt(&walk, &rctx->subreq, false);
325
326                 while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
327                         kernel_neon_begin();
328                         aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
329                                         ctx->key_dec, rounds, blocks, walk.iv);
330                         kernel_neon_end();
331                         err = skcipher_walk_done(&walk,
332                                                  walk.nbytes % AES_BLOCK_SIZE);
333                 }
334                 if (err)
335                         return err;
336
337                 if (req->cryptlen == AES_BLOCK_SIZE)
338                         return 0;
339
340                 dst = src = scatterwalk_ffwd(rctx->sg_src, req->src,
341                                              rctx->subreq.cryptlen);
342                 if (req->dst != req->src)
343                         dst = scatterwalk_ffwd(rctx->sg_dst, req->dst,
344                                                rctx->subreq.cryptlen);
345         }
346
347         /* handle ciphertext stealing */
348         skcipher_request_set_crypt(&rctx->subreq, src, dst,
349                                    req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
350                                    req->iv);
351
352         err = skcipher_walk_virt(&walk, &rctx->subreq, false);
353         if (err)
354                 return err;
355
356         kernel_neon_begin();
357         aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
358                             ctx->key_dec, rounds, walk.nbytes, walk.iv);
359         kernel_neon_end();
360
361         return skcipher_walk_done(&walk, 0);
362 }
363
364 static int ctr_encrypt(struct skcipher_request *req)
365 {
366         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
367         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
368         int err, rounds = 6 + ctx->key_length / 4;
369         struct skcipher_walk walk;
370         int blocks;
371
372         err = skcipher_walk_virt(&walk, req, false);
373
374         while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
375                 kernel_neon_begin();
376                 aes_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
377                                 ctx->key_enc, rounds, blocks, walk.iv);
378                 kernel_neon_end();
379                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
380         }
381         if (walk.nbytes) {
382                 u8 __aligned(8) tail[AES_BLOCK_SIZE];
383                 unsigned int nbytes = walk.nbytes;
384                 u8 *tdst = walk.dst.virt.addr;
385                 u8 *tsrc = walk.src.virt.addr;
386
387                 /*
388                  * Tell aes_ctr_encrypt() to process a tail block.
389                  */
390                 blocks = -1;
391
392                 kernel_neon_begin();
393                 aes_ctr_encrypt(tail, NULL, ctx->key_enc, rounds,
394                                 blocks, walk.iv);
395                 kernel_neon_end();
396                 crypto_xor_cpy(tdst, tsrc, tail, nbytes);
397                 err = skcipher_walk_done(&walk, 0);
398         }
399
400         return err;
401 }
402
403 static int ctr_encrypt_sync(struct skcipher_request *req)
404 {
405         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
406         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
407
408         if (!may_use_simd())
409                 return aes_ctr_encrypt_fallback(ctx, req);
410
411         return ctr_encrypt(req);
412 }
413
414 static int xts_encrypt(struct skcipher_request *req)
415 {
416         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
417         struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
418         int err, first, rounds = 6 + ctx->key1.key_length / 4;
419         struct skcipher_walk walk;
420         unsigned int blocks;
421
422         err = skcipher_walk_virt(&walk, req, false);
423
424         for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
425                 kernel_neon_begin();
426                 aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
427                                 ctx->key1.key_enc, rounds, blocks,
428                                 ctx->key2.key_enc, walk.iv, first);
429                 kernel_neon_end();
430                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
431         }
432
433         return err;
434 }
435
436 static int xts_decrypt(struct skcipher_request *req)
437 {
438         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
439         struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
440         int err, first, rounds = 6 + ctx->key1.key_length / 4;
441         struct skcipher_walk walk;
442         unsigned int blocks;
443
444         err = skcipher_walk_virt(&walk, req, false);
445
446         for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
447                 kernel_neon_begin();
448                 aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
449                                 ctx->key1.key_dec, rounds, blocks,
450                                 ctx->key2.key_enc, walk.iv, first);
451                 kernel_neon_end();
452                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
453         }
454
455         return err;
456 }
457
458 static struct skcipher_alg aes_algs[] = { {
459         .base = {
460                 .cra_name               = "__ecb(aes)",
461                 .cra_driver_name        = "__ecb-aes-" MODE,
462                 .cra_priority           = PRIO,
463                 .cra_flags              = CRYPTO_ALG_INTERNAL,
464                 .cra_blocksize          = AES_BLOCK_SIZE,
465                 .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
466                 .cra_module             = THIS_MODULE,
467         },
468         .min_keysize    = AES_MIN_KEY_SIZE,
469         .max_keysize    = AES_MAX_KEY_SIZE,
470         .setkey         = skcipher_aes_setkey,
471         .encrypt        = ecb_encrypt,
472         .decrypt        = ecb_decrypt,
473 }, {
474         .base = {
475                 .cra_name               = "__cbc(aes)",
476                 .cra_driver_name        = "__cbc-aes-" MODE,
477                 .cra_priority           = PRIO,
478                 .cra_flags              = CRYPTO_ALG_INTERNAL,
479                 .cra_blocksize          = AES_BLOCK_SIZE,
480                 .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
481                 .cra_module             = THIS_MODULE,
482         },
483         .min_keysize    = AES_MIN_KEY_SIZE,
484         .max_keysize    = AES_MAX_KEY_SIZE,
485         .ivsize         = AES_BLOCK_SIZE,
486         .setkey         = skcipher_aes_setkey,
487         .encrypt        = cbc_encrypt,
488         .decrypt        = cbc_decrypt,
489 }, {
490         .base = {
491                 .cra_name               = "__cts(cbc(aes))",
492                 .cra_driver_name        = "__cts-cbc-aes-" MODE,
493                 .cra_priority           = PRIO,
494                 .cra_flags              = CRYPTO_ALG_INTERNAL,
495                 .cra_blocksize          = AES_BLOCK_SIZE,
496                 .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
497                 .cra_module             = THIS_MODULE,
498         },
499         .min_keysize    = AES_MIN_KEY_SIZE,
500         .max_keysize    = AES_MAX_KEY_SIZE,
501         .ivsize         = AES_BLOCK_SIZE,
502         .walksize       = 2 * AES_BLOCK_SIZE,
503         .setkey         = skcipher_aes_setkey,
504         .encrypt        = cts_cbc_encrypt,
505         .decrypt        = cts_cbc_decrypt,
506         .init           = cts_cbc_init_tfm,
507 }, {
508         .base = {
509                 .cra_name               = "__ctr(aes)",
510                 .cra_driver_name        = "__ctr-aes-" MODE,
511                 .cra_priority           = PRIO,
512                 .cra_flags              = CRYPTO_ALG_INTERNAL,
513                 .cra_blocksize          = 1,
514                 .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
515                 .cra_module             = THIS_MODULE,
516         },
517         .min_keysize    = AES_MIN_KEY_SIZE,
518         .max_keysize    = AES_MAX_KEY_SIZE,
519         .ivsize         = AES_BLOCK_SIZE,
520         .chunksize      = AES_BLOCK_SIZE,
521         .setkey         = skcipher_aes_setkey,
522         .encrypt        = ctr_encrypt,
523         .decrypt        = ctr_encrypt,
524 }, {
525         .base = {
526                 .cra_name               = "ctr(aes)",
527                 .cra_driver_name        = "ctr-aes-" MODE,
528                 .cra_priority           = PRIO - 1,
529                 .cra_blocksize          = 1,
530                 .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
531                 .cra_module             = THIS_MODULE,
532         },
533         .min_keysize    = AES_MIN_KEY_SIZE,
534         .max_keysize    = AES_MAX_KEY_SIZE,
535         .ivsize         = AES_BLOCK_SIZE,
536         .chunksize      = AES_BLOCK_SIZE,
537         .setkey         = skcipher_aes_setkey,
538         .encrypt        = ctr_encrypt_sync,
539         .decrypt        = ctr_encrypt_sync,
540 }, {
541         .base = {
542                 .cra_name               = "__xts(aes)",
543                 .cra_driver_name        = "__xts-aes-" MODE,
544                 .cra_priority           = PRIO,
545                 .cra_flags              = CRYPTO_ALG_INTERNAL,
546                 .cra_blocksize          = AES_BLOCK_SIZE,
547                 .cra_ctxsize            = sizeof(struct crypto_aes_xts_ctx),
548                 .cra_module             = THIS_MODULE,
549         },
550         .min_keysize    = 2 * AES_MIN_KEY_SIZE,
551         .max_keysize    = 2 * AES_MAX_KEY_SIZE,
552         .ivsize         = AES_BLOCK_SIZE,
553         .setkey         = xts_set_key,
554         .encrypt        = xts_encrypt,
555         .decrypt        = xts_decrypt,
556 } };
557
558 static int cbcmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
559                          unsigned int key_len)
560 {
561         struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
562         int err;
563
564         err = aes_expandkey(&ctx->key, in_key, key_len);
565         if (err)
566                 crypto_shash_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);
567
568         return err;
569 }
570
571 static void cmac_gf128_mul_by_x(be128 *y, const be128 *x)
572 {
573         u64 a = be64_to_cpu(x->a);
574         u64 b = be64_to_cpu(x->b);
575
576         y->a = cpu_to_be64((a << 1) | (b >> 63));
577         y->b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));
578 }
579
580 static int cmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
581                        unsigned int key_len)
582 {
583         struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
584         be128 *consts = (be128 *)ctx->consts;
585         int rounds = 6 + key_len / 4;
586         int err;
587
588         err = cbcmac_setkey(tfm, in_key, key_len);
589         if (err)
590                 return err;
591
592         /* encrypt the zero vector */
593         kernel_neon_begin();
594         aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, ctx->key.key_enc,
595                         rounds, 1);
596         kernel_neon_end();
597
598         cmac_gf128_mul_by_x(consts, consts);
599         cmac_gf128_mul_by_x(consts + 1, consts);
600
601         return 0;
602 }
603
604 static int xcbc_setkey(struct crypto_shash *tfm, const u8 *in_key,
605                        unsigned int key_len)
606 {
607         static u8 const ks[3][AES_BLOCK_SIZE] = {
608                 { [0 ... AES_BLOCK_SIZE - 1] = 0x1 },
609                 { [0 ... AES_BLOCK_SIZE - 1] = 0x2 },
610                 { [0 ... AES_BLOCK_SIZE - 1] = 0x3 },
611         };
612
613         struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
614         int rounds = 6 + key_len / 4;
615         u8 key[AES_BLOCK_SIZE];
616         int err;
617
618         err = cbcmac_setkey(tfm, in_key, key_len);
619         if (err)
620                 return err;
621
622         kernel_neon_begin();
623         aes_ecb_encrypt(key, ks[0], ctx->key.key_enc, rounds, 1);
624         aes_ecb_encrypt(ctx->consts, ks[1], ctx->key.key_enc, rounds, 2);
625         kernel_neon_end();
626
627         return cbcmac_setkey(tfm, key, sizeof(key));
628 }
629
630 static int mac_init(struct shash_desc *desc)
631 {
632         struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
633
634         memset(ctx->dg, 0, AES_BLOCK_SIZE);
635         ctx->len = 0;
636
637         return 0;
638 }
639
640 static void mac_do_update(struct crypto_aes_ctx *ctx, u8 const in[], int blocks,
641                           u8 dg[], int enc_before, int enc_after)
642 {
643         int rounds = 6 + ctx->key_length / 4;
644
645         if (may_use_simd()) {
646                 kernel_neon_begin();
647                 aes_mac_update(in, ctx->key_enc, rounds, blocks, dg, enc_before,
648                                enc_after);
649                 kernel_neon_end();
650         } else {
651                 if (enc_before)
652                         __aes_arm64_encrypt(ctx->key_enc, dg, dg, rounds);
653
654                 while (blocks--) {
655                         crypto_xor(dg, in, AES_BLOCK_SIZE);
656                         in += AES_BLOCK_SIZE;
657
658                         if (blocks || enc_after)
659                                 __aes_arm64_encrypt(ctx->key_enc, dg, dg,
660                                                     rounds);
661                 }
662         }
663 }
664
665 static int mac_update(struct shash_desc *desc, const u8 *p, unsigned int len)
666 {
667         struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
668         struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
669
670         while (len > 0) {
671                 unsigned int l;
672
673                 if ((ctx->len % AES_BLOCK_SIZE) == 0 &&
674                     (ctx->len + len) > AES_BLOCK_SIZE) {
675
676                         int blocks = len / AES_BLOCK_SIZE;
677
678                         len %= AES_BLOCK_SIZE;
679
680                         mac_do_update(&tctx->key, p, blocks, ctx->dg,
681                                       (ctx->len != 0), (len != 0));
682
683                         p += blocks * AES_BLOCK_SIZE;
684
685                         if (!len) {
686                                 ctx->len = AES_BLOCK_SIZE;
687                                 break;
688                         }
689                         ctx->len = 0;
690                 }
691
692                 l = min(len, AES_BLOCK_SIZE - ctx->len);
693
694                 if (l <= AES_BLOCK_SIZE) {
695                         crypto_xor(ctx->dg + ctx->len, p, l);
696                         ctx->len += l;
697                         len -= l;
698                         p += l;
699                 }
700         }
701
702         return 0;
703 }
704
705 static int cbcmac_final(struct shash_desc *desc, u8 *out)
706 {
707         struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
708         struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
709
710         mac_do_update(&tctx->key, NULL, 0, ctx->dg, 1, 0);
711
712         memcpy(out, ctx->dg, AES_BLOCK_SIZE);
713
714         return 0;
715 }
716
717 static int cmac_final(struct shash_desc *desc, u8 *out)
718 {
719         struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
720         struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
721         u8 *consts = tctx->consts;
722
723         if (ctx->len != AES_BLOCK_SIZE) {
724                 ctx->dg[ctx->len] ^= 0x80;
725                 consts += AES_BLOCK_SIZE;
726         }
727
728         mac_do_update(&tctx->key, consts, 1, ctx->dg, 0, 1);
729
730         memcpy(out, ctx->dg, AES_BLOCK_SIZE);
731
732         return 0;
733 }
734
735 static struct shash_alg mac_algs[] = { {
736         .base.cra_name          = "cmac(aes)",
737         .base.cra_driver_name   = "cmac-aes-" MODE,
738         .base.cra_priority      = PRIO,
739         .base.cra_blocksize     = AES_BLOCK_SIZE,
740         .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx) +
741                                   2 * AES_BLOCK_SIZE,
742         .base.cra_module        = THIS_MODULE,
743
744         .digestsize             = AES_BLOCK_SIZE,
745         .init                   = mac_init,
746         .update                 = mac_update,
747         .final                  = cmac_final,
748         .setkey                 = cmac_setkey,
749         .descsize               = sizeof(struct mac_desc_ctx),
750 }, {
751         .base.cra_name          = "xcbc(aes)",
752         .base.cra_driver_name   = "xcbc-aes-" MODE,
753         .base.cra_priority      = PRIO,
754         .base.cra_blocksize     = AES_BLOCK_SIZE,
755         .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx) +
756                                   2 * AES_BLOCK_SIZE,
757         .base.cra_module        = THIS_MODULE,
758
759         .digestsize             = AES_BLOCK_SIZE,
760         .init                   = mac_init,
761         .update                 = mac_update,
762         .final                  = cmac_final,
763         .setkey                 = xcbc_setkey,
764         .descsize               = sizeof(struct mac_desc_ctx),
765 }, {
766         .base.cra_name          = "cbcmac(aes)",
767         .base.cra_driver_name   = "cbcmac-aes-" MODE,
768         .base.cra_priority      = PRIO,
769         .base.cra_blocksize     = 1,
770         .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx),
771         .base.cra_module        = THIS_MODULE,
772
773         .digestsize             = AES_BLOCK_SIZE,
774         .init                   = mac_init,
775         .update                 = mac_update,
776         .final                  = cbcmac_final,
777         .setkey                 = cbcmac_setkey,
778         .descsize               = sizeof(struct mac_desc_ctx),
779 } };
780
781 static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
782
783 static void aes_exit(void)
784 {
785         int i;
786
787         for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
788                 if (aes_simd_algs[i])
789                         simd_skcipher_free(aes_simd_algs[i]);
790
791         crypto_unregister_shashes(mac_algs, ARRAY_SIZE(mac_algs));
792         crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
793 }
794
795 static int __init aes_init(void)
796 {
797         struct simd_skcipher_alg *simd;
798         const char *basename;
799         const char *algname;
800         const char *drvname;
801         int err;
802         int i;
803
804         err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
805         if (err)
806                 return err;
807
808         err = crypto_register_shashes(mac_algs, ARRAY_SIZE(mac_algs));
809         if (err)
810                 goto unregister_ciphers;
811
812         for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
813                 if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
814                         continue;
815
816                 algname = aes_algs[i].base.cra_name + 2;
817                 drvname = aes_algs[i].base.cra_driver_name + 2;
818                 basename = aes_algs[i].base.cra_driver_name;
819                 simd = simd_skcipher_create_compat(algname, drvname, basename);
820                 err = PTR_ERR(simd);
821                 if (IS_ERR(simd))
822                         goto unregister_simds;
823
824                 aes_simd_algs[i] = simd;
825         }
826
827         return 0;
828
829 unregister_simds:
830         aes_exit();
831         return err;
832 unregister_ciphers:
833         crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
834         return err;
835 }
836
837 #ifdef USE_V8_CRYPTO_EXTENSIONS
838 module_cpu_feature_match(AES, aes_init);
839 #else
840 module_init(aes_init);
841 EXPORT_SYMBOL(neon_aes_ecb_encrypt);
842 EXPORT_SYMBOL(neon_aes_cbc_encrypt);
843 #endif
844 module_exit(aes_exit);