treewide: Replace GPLv2 boilerplate/reference with SPDX - rule 500
[sfrench/cifs-2.6.git] / arch / arm / crypto / aes-ce-glue.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * aes-ce-glue.c - wrapper code for ARMv8 AES
4  *
5  * Copyright (C) 2015 Linaro Ltd <ard.biesheuvel@linaro.org>
6  */
7
8 #include <asm/hwcap.h>
9 #include <asm/neon.h>
10 #include <crypto/aes.h>
11 #include <crypto/internal/simd.h>
12 #include <crypto/internal/skcipher.h>
13 #include <linux/cpufeature.h>
14 #include <linux/module.h>
15 #include <crypto/xts.h>
16
17 MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
18 MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
19 MODULE_LICENSE("GPL v2");
20
21 /* defined in aes-ce-core.S */
22 asmlinkage u32 ce_aes_sub(u32 input);
23 asmlinkage void ce_aes_invert(void *dst, void *src);
24
25 asmlinkage void ce_aes_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
26                                    int rounds, int blocks);
27 asmlinkage void ce_aes_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
28                                    int rounds, int blocks);
29
30 asmlinkage void ce_aes_cbc_encrypt(u8 out[], u8 const in[], u8 const rk[],
31                                    int rounds, int blocks, u8 iv[]);
32 asmlinkage void ce_aes_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
33                                    int rounds, int blocks, u8 iv[]);
34
35 asmlinkage void ce_aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
36                                    int rounds, int blocks, u8 ctr[]);
37
38 asmlinkage void ce_aes_xts_encrypt(u8 out[], u8 const in[], u8 const rk1[],
39                                    int rounds, int blocks, u8 iv[],
40                                    u8 const rk2[], int first);
41 asmlinkage void ce_aes_xts_decrypt(u8 out[], u8 const in[], u8 const rk1[],
42                                    int rounds, int blocks, u8 iv[],
43                                    u8 const rk2[], int first);
44
45 struct aes_block {
46         u8 b[AES_BLOCK_SIZE];
47 };
48
49 static int num_rounds(struct crypto_aes_ctx *ctx)
50 {
51         /*
52          * # of rounds specified by AES:
53          * 128 bit key          10 rounds
54          * 192 bit key          12 rounds
55          * 256 bit key          14 rounds
56          * => n byte key        => 6 + (n/4) rounds
57          */
58         return 6 + ctx->key_length / 4;
59 }
60
61 static int ce_aes_expandkey(struct crypto_aes_ctx *ctx, const u8 *in_key,
62                             unsigned int key_len)
63 {
64         /*
65          * The AES key schedule round constants
66          */
67         static u8 const rcon[] = {
68                 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36,
69         };
70
71         u32 kwords = key_len / sizeof(u32);
72         struct aes_block *key_enc, *key_dec;
73         int i, j;
74
75         if (key_len != AES_KEYSIZE_128 &&
76             key_len != AES_KEYSIZE_192 &&
77             key_len != AES_KEYSIZE_256)
78                 return -EINVAL;
79
80         memcpy(ctx->key_enc, in_key, key_len);
81         ctx->key_length = key_len;
82
83         kernel_neon_begin();
84         for (i = 0; i < sizeof(rcon); i++) {
85                 u32 *rki = ctx->key_enc + (i * kwords);
86                 u32 *rko = rki + kwords;
87
88 #ifndef CONFIG_CPU_BIG_ENDIAN
89                 rko[0] = ror32(ce_aes_sub(rki[kwords - 1]), 8);
90                 rko[0] = rko[0] ^ rki[0] ^ rcon[i];
91 #else
92                 rko[0] = rol32(ce_aes_sub(rki[kwords - 1]), 8);
93                 rko[0] = rko[0] ^ rki[0] ^ (rcon[i] << 24);
94 #endif
95                 rko[1] = rko[0] ^ rki[1];
96                 rko[2] = rko[1] ^ rki[2];
97                 rko[3] = rko[2] ^ rki[3];
98
99                 if (key_len == AES_KEYSIZE_192) {
100                         if (i >= 7)
101                                 break;
102                         rko[4] = rko[3] ^ rki[4];
103                         rko[5] = rko[4] ^ rki[5];
104                 } else if (key_len == AES_KEYSIZE_256) {
105                         if (i >= 6)
106                                 break;
107                         rko[4] = ce_aes_sub(rko[3]) ^ rki[4];
108                         rko[5] = rko[4] ^ rki[5];
109                         rko[6] = rko[5] ^ rki[6];
110                         rko[7] = rko[6] ^ rki[7];
111                 }
112         }
113
114         /*
115          * Generate the decryption keys for the Equivalent Inverse Cipher.
116          * This involves reversing the order of the round keys, and applying
117          * the Inverse Mix Columns transformation on all but the first and
118          * the last one.
119          */
120         key_enc = (struct aes_block *)ctx->key_enc;
121         key_dec = (struct aes_block *)ctx->key_dec;
122         j = num_rounds(ctx);
123
124         key_dec[0] = key_enc[j];
125         for (i = 1, j--; j > 0; i++, j--)
126                 ce_aes_invert(key_dec + i, key_enc + j);
127         key_dec[i] = key_enc[0];
128
129         kernel_neon_end();
130         return 0;
131 }
132
133 static int ce_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
134                          unsigned int key_len)
135 {
136         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
137         int ret;
138
139         ret = ce_aes_expandkey(ctx, in_key, key_len);
140         if (!ret)
141                 return 0;
142
143         crypto_skcipher_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);
144         return -EINVAL;
145 }
146
147 struct crypto_aes_xts_ctx {
148         struct crypto_aes_ctx key1;
149         struct crypto_aes_ctx __aligned(8) key2;
150 };
151
152 static int xts_set_key(struct crypto_skcipher *tfm, const u8 *in_key,
153                        unsigned int key_len)
154 {
155         struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
156         int ret;
157
158         ret = xts_verify_key(tfm, in_key, key_len);
159         if (ret)
160                 return ret;
161
162         ret = ce_aes_expandkey(&ctx->key1, in_key, key_len / 2);
163         if (!ret)
164                 ret = ce_aes_expandkey(&ctx->key2, &in_key[key_len / 2],
165                                        key_len / 2);
166         if (!ret)
167                 return 0;
168
169         crypto_skcipher_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);
170         return -EINVAL;
171 }
172
173 static int ecb_encrypt(struct skcipher_request *req)
174 {
175         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
176         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
177         struct skcipher_walk walk;
178         unsigned int blocks;
179         int err;
180
181         err = skcipher_walk_virt(&walk, req, true);
182
183         kernel_neon_begin();
184         while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
185                 ce_aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
186                                    (u8 *)ctx->key_enc, num_rounds(ctx), blocks);
187                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
188         }
189         kernel_neon_end();
190         return err;
191 }
192
193 static int ecb_decrypt(struct skcipher_request *req)
194 {
195         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
196         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
197         struct skcipher_walk walk;
198         unsigned int blocks;
199         int err;
200
201         err = skcipher_walk_virt(&walk, req, true);
202
203         kernel_neon_begin();
204         while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
205                 ce_aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
206                                    (u8 *)ctx->key_dec, num_rounds(ctx), blocks);
207                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
208         }
209         kernel_neon_end();
210         return err;
211 }
212
213 static int cbc_encrypt(struct skcipher_request *req)
214 {
215         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
216         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
217         struct skcipher_walk walk;
218         unsigned int blocks;
219         int err;
220
221         err = skcipher_walk_virt(&walk, req, true);
222
223         kernel_neon_begin();
224         while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
225                 ce_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
226                                    (u8 *)ctx->key_enc, num_rounds(ctx), blocks,
227                                    walk.iv);
228                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
229         }
230         kernel_neon_end();
231         return err;
232 }
233
234 static int cbc_decrypt(struct skcipher_request *req)
235 {
236         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
237         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
238         struct skcipher_walk walk;
239         unsigned int blocks;
240         int err;
241
242         err = skcipher_walk_virt(&walk, req, true);
243
244         kernel_neon_begin();
245         while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
246                 ce_aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
247                                    (u8 *)ctx->key_dec, num_rounds(ctx), blocks,
248                                    walk.iv);
249                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
250         }
251         kernel_neon_end();
252         return err;
253 }
254
255 static int ctr_encrypt(struct skcipher_request *req)
256 {
257         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
258         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
259         struct skcipher_walk walk;
260         int err, blocks;
261
262         err = skcipher_walk_virt(&walk, req, true);
263
264         kernel_neon_begin();
265         while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
266                 ce_aes_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
267                                    (u8 *)ctx->key_enc, num_rounds(ctx), blocks,
268                                    walk.iv);
269                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
270         }
271         if (walk.nbytes) {
272                 u8 __aligned(8) tail[AES_BLOCK_SIZE];
273                 unsigned int nbytes = walk.nbytes;
274                 u8 *tdst = walk.dst.virt.addr;
275                 u8 *tsrc = walk.src.virt.addr;
276
277                 /*
278                  * Tell aes_ctr_encrypt() to process a tail block.
279                  */
280                 blocks = -1;
281
282                 ce_aes_ctr_encrypt(tail, NULL, (u8 *)ctx->key_enc,
283                                    num_rounds(ctx), blocks, walk.iv);
284                 crypto_xor_cpy(tdst, tsrc, tail, nbytes);
285                 err = skcipher_walk_done(&walk, 0);
286         }
287         kernel_neon_end();
288
289         return err;
290 }
291
292 static int xts_encrypt(struct skcipher_request *req)
293 {
294         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
295         struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
296         int err, first, rounds = num_rounds(&ctx->key1);
297         struct skcipher_walk walk;
298         unsigned int blocks;
299
300         err = skcipher_walk_virt(&walk, req, true);
301
302         kernel_neon_begin();
303         for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
304                 ce_aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
305                                    (u8 *)ctx->key1.key_enc, rounds, blocks,
306                                    walk.iv, (u8 *)ctx->key2.key_enc, first);
307                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
308         }
309         kernel_neon_end();
310
311         return err;
312 }
313
314 static int xts_decrypt(struct skcipher_request *req)
315 {
316         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
317         struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
318         int err, first, rounds = num_rounds(&ctx->key1);
319         struct skcipher_walk walk;
320         unsigned int blocks;
321
322         err = skcipher_walk_virt(&walk, req, true);
323
324         kernel_neon_begin();
325         for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
326                 ce_aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
327                                    (u8 *)ctx->key1.key_dec, rounds, blocks,
328                                    walk.iv, (u8 *)ctx->key2.key_enc, first);
329                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
330         }
331         kernel_neon_end();
332
333         return err;
334 }
335
336 static struct skcipher_alg aes_algs[] = { {
337         .base = {
338                 .cra_name               = "__ecb(aes)",
339                 .cra_driver_name        = "__ecb-aes-ce",
340                 .cra_priority           = 300,
341                 .cra_flags              = CRYPTO_ALG_INTERNAL,
342                 .cra_blocksize          = AES_BLOCK_SIZE,
343                 .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
344                 .cra_module             = THIS_MODULE,
345         },
346         .min_keysize    = AES_MIN_KEY_SIZE,
347         .max_keysize    = AES_MAX_KEY_SIZE,
348         .setkey         = ce_aes_setkey,
349         .encrypt        = ecb_encrypt,
350         .decrypt        = ecb_decrypt,
351 }, {
352         .base = {
353                 .cra_name               = "__cbc(aes)",
354                 .cra_driver_name        = "__cbc-aes-ce",
355                 .cra_priority           = 300,
356                 .cra_flags              = CRYPTO_ALG_INTERNAL,
357                 .cra_blocksize          = AES_BLOCK_SIZE,
358                 .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
359                 .cra_module             = THIS_MODULE,
360         },
361         .min_keysize    = AES_MIN_KEY_SIZE,
362         .max_keysize    = AES_MAX_KEY_SIZE,
363         .ivsize         = AES_BLOCK_SIZE,
364         .setkey         = ce_aes_setkey,
365         .encrypt        = cbc_encrypt,
366         .decrypt        = cbc_decrypt,
367 }, {
368         .base = {
369                 .cra_name               = "__ctr(aes)",
370                 .cra_driver_name        = "__ctr-aes-ce",
371                 .cra_priority           = 300,
372                 .cra_flags              = CRYPTO_ALG_INTERNAL,
373                 .cra_blocksize          = 1,
374                 .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
375                 .cra_module             = THIS_MODULE,
376         },
377         .min_keysize    = AES_MIN_KEY_SIZE,
378         .max_keysize    = AES_MAX_KEY_SIZE,
379         .ivsize         = AES_BLOCK_SIZE,
380         .chunksize      = AES_BLOCK_SIZE,
381         .setkey         = ce_aes_setkey,
382         .encrypt        = ctr_encrypt,
383         .decrypt        = ctr_encrypt,
384 }, {
385         .base = {
386                 .cra_name               = "__xts(aes)",
387                 .cra_driver_name        = "__xts-aes-ce",
388                 .cra_priority           = 300,
389                 .cra_flags              = CRYPTO_ALG_INTERNAL,
390                 .cra_blocksize          = AES_BLOCK_SIZE,
391                 .cra_ctxsize            = sizeof(struct crypto_aes_xts_ctx),
392                 .cra_module             = THIS_MODULE,
393         },
394         .min_keysize    = 2 * AES_MIN_KEY_SIZE,
395         .max_keysize    = 2 * AES_MAX_KEY_SIZE,
396         .ivsize         = AES_BLOCK_SIZE,
397         .setkey         = xts_set_key,
398         .encrypt        = xts_encrypt,
399         .decrypt        = xts_decrypt,
400 } };
401
402 static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
403
404 static void aes_exit(void)
405 {
406         int i;
407
408         for (i = 0; i < ARRAY_SIZE(aes_simd_algs) && aes_simd_algs[i]; i++)
409                 simd_skcipher_free(aes_simd_algs[i]);
410
411         crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
412 }
413
414 static int __init aes_init(void)
415 {
416         struct simd_skcipher_alg *simd;
417         const char *basename;
418         const char *algname;
419         const char *drvname;
420         int err;
421         int i;
422
423         err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
424         if (err)
425                 return err;
426
427         for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
428                 algname = aes_algs[i].base.cra_name + 2;
429                 drvname = aes_algs[i].base.cra_driver_name + 2;
430                 basename = aes_algs[i].base.cra_driver_name;
431                 simd = simd_skcipher_create_compat(algname, drvname, basename);
432                 err = PTR_ERR(simd);
433                 if (IS_ERR(simd))
434                         goto unregister_simds;
435
436                 aes_simd_algs[i] = simd;
437         }
438
439         return 0;
440
441 unregister_simds:
442         aes_exit();
443         return err;
444 }
445
446 module_cpu_feature_match(AES, aes_init);
447 module_exit(aes_exit);