Merge branch 'apw' (xfrm_user fixes)
[sfrench/cifs-2.6.git] / arch / arm / crypto / aes-neonbs-glue.c
1 /*
2  * Bit sliced AES using NEON instructions
3  *
4  * Copyright (C) 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License version 2 as
8  * published by the Free Software Foundation.
9  */
10
11 #include <asm/neon.h>
12 #include <crypto/aes.h>
13 #include <crypto/cbc.h>
14 #include <crypto/internal/simd.h>
15 #include <crypto/internal/skcipher.h>
16 #include <crypto/xts.h>
17 #include <linux/module.h>
18
19 MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
20 MODULE_LICENSE("GPL v2");
21
22 MODULE_ALIAS_CRYPTO("ecb(aes)");
23 MODULE_ALIAS_CRYPTO("cbc(aes)");
24 MODULE_ALIAS_CRYPTO("ctr(aes)");
25 MODULE_ALIAS_CRYPTO("xts(aes)");
26
27 asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);
28
29 asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
30                                   int rounds, int blocks);
31 asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
32                                   int rounds, int blocks);
33
34 asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
35                                   int rounds, int blocks, u8 iv[]);
36
37 asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
38                                   int rounds, int blocks, u8 ctr[], u8 final[]);
39
40 asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
41                                   int rounds, int blocks, u8 iv[]);
42 asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
43                                   int rounds, int blocks, u8 iv[]);
44
45 asmlinkage void __aes_arm_encrypt(const u32 rk[], int rounds, const u8 in[],
46                                   u8 out[]);
47
48 struct aesbs_ctx {
49         int     rounds;
50         u8      rk[13 * (8 * AES_BLOCK_SIZE) + 32] __aligned(AES_BLOCK_SIZE);
51 };
52
53 struct aesbs_cbc_ctx {
54         struct aesbs_ctx        key;
55         u32                     enc[AES_MAX_KEYLENGTH_U32];
56 };
57
58 struct aesbs_xts_ctx {
59         struct aesbs_ctx        key;
60         u32                     twkey[AES_MAX_KEYLENGTH_U32];
61 };
62
63 static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
64                         unsigned int key_len)
65 {
66         struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
67         struct crypto_aes_ctx rk;
68         int err;
69
70         err = crypto_aes_expand_key(&rk, in_key, key_len);
71         if (err)
72                 return err;
73
74         ctx->rounds = 6 + key_len / 4;
75
76         kernel_neon_begin();
77         aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
78         kernel_neon_end();
79
80         return 0;
81 }
82
83 static int __ecb_crypt(struct skcipher_request *req,
84                        void (*fn)(u8 out[], u8 const in[], u8 const rk[],
85                                   int rounds, int blocks))
86 {
87         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
88         struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
89         struct skcipher_walk walk;
90         int err;
91
92         err = skcipher_walk_virt(&walk, req, true);
93
94         kernel_neon_begin();
95         while (walk.nbytes >= AES_BLOCK_SIZE) {
96                 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
97
98                 if (walk.nbytes < walk.total)
99                         blocks = round_down(blocks,
100                                             walk.stride / AES_BLOCK_SIZE);
101
102                 fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
103                    ctx->rounds, blocks);
104                 err = skcipher_walk_done(&walk,
105                                          walk.nbytes - blocks * AES_BLOCK_SIZE);
106         }
107         kernel_neon_end();
108
109         return err;
110 }
111
112 static int ecb_encrypt(struct skcipher_request *req)
113 {
114         return __ecb_crypt(req, aesbs_ecb_encrypt);
115 }
116
117 static int ecb_decrypt(struct skcipher_request *req)
118 {
119         return __ecb_crypt(req, aesbs_ecb_decrypt);
120 }
121
122 static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
123                             unsigned int key_len)
124 {
125         struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
126         struct crypto_aes_ctx rk;
127         int err;
128
129         err = crypto_aes_expand_key(&rk, in_key, key_len);
130         if (err)
131                 return err;
132
133         ctx->key.rounds = 6 + key_len / 4;
134
135         memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));
136
137         kernel_neon_begin();
138         aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
139         kernel_neon_end();
140
141         return 0;
142 }
143
144 static void cbc_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
145 {
146         struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
147
148         __aes_arm_encrypt(ctx->enc, ctx->key.rounds, src, dst);
149 }
150
151 static int cbc_encrypt(struct skcipher_request *req)
152 {
153         return crypto_cbc_encrypt_walk(req, cbc_encrypt_one);
154 }
155
156 static int cbc_decrypt(struct skcipher_request *req)
157 {
158         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
159         struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
160         struct skcipher_walk walk;
161         int err;
162
163         err = skcipher_walk_virt(&walk, req, true);
164
165         kernel_neon_begin();
166         while (walk.nbytes >= AES_BLOCK_SIZE) {
167                 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
168
169                 if (walk.nbytes < walk.total)
170                         blocks = round_down(blocks,
171                                             walk.stride / AES_BLOCK_SIZE);
172
173                 aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
174                                   ctx->key.rk, ctx->key.rounds, blocks,
175                                   walk.iv);
176                 err = skcipher_walk_done(&walk,
177                                          walk.nbytes - blocks * AES_BLOCK_SIZE);
178         }
179         kernel_neon_end();
180
181         return err;
182 }
183
184 static int ctr_encrypt(struct skcipher_request *req)
185 {
186         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
187         struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
188         struct skcipher_walk walk;
189         u8 buf[AES_BLOCK_SIZE];
190         int err;
191
192         err = skcipher_walk_virt(&walk, req, true);
193
194         kernel_neon_begin();
195         while (walk.nbytes > 0) {
196                 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
197                 u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;
198
199                 if (walk.nbytes < walk.total) {
200                         blocks = round_down(blocks,
201                                             walk.stride / AES_BLOCK_SIZE);
202                         final = NULL;
203                 }
204
205                 aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
206                                   ctx->rk, ctx->rounds, blocks, walk.iv, final);
207
208                 if (final) {
209                         u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
210                         u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;
211
212                         if (dst != src)
213                                 memcpy(dst, src, walk.total % AES_BLOCK_SIZE);
214                         crypto_xor(dst, final, walk.total % AES_BLOCK_SIZE);
215
216                         err = skcipher_walk_done(&walk, 0);
217                         break;
218                 }
219                 err = skcipher_walk_done(&walk,
220                                          walk.nbytes - blocks * AES_BLOCK_SIZE);
221         }
222         kernel_neon_end();
223
224         return err;
225 }
226
227 static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
228                             unsigned int key_len)
229 {
230         struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
231         struct crypto_aes_ctx rk;
232         int err;
233
234         err = xts_verify_key(tfm, in_key, key_len);
235         if (err)
236                 return err;
237
238         key_len /= 2;
239         err = crypto_aes_expand_key(&rk, in_key + key_len, key_len);
240         if (err)
241                 return err;
242
243         memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey));
244
245         return aesbs_setkey(tfm, in_key, key_len);
246 }
247
248 static int __xts_crypt(struct skcipher_request *req,
249                        void (*fn)(u8 out[], u8 const in[], u8 const rk[],
250                                   int rounds, int blocks, u8 iv[]))
251 {
252         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
253         struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
254         struct skcipher_walk walk;
255         int err;
256
257         err = skcipher_walk_virt(&walk, req, true);
258
259         __aes_arm_encrypt(ctx->twkey, ctx->key.rounds, walk.iv, walk.iv);
260
261         kernel_neon_begin();
262         while (walk.nbytes >= AES_BLOCK_SIZE) {
263                 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
264
265                 if (walk.nbytes < walk.total)
266                         blocks = round_down(blocks,
267                                             walk.stride / AES_BLOCK_SIZE);
268
269                 fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
270                    ctx->key.rounds, blocks, walk.iv);
271                 err = skcipher_walk_done(&walk,
272                                          walk.nbytes - blocks * AES_BLOCK_SIZE);
273         }
274         kernel_neon_end();
275
276         return err;
277 }
278
279 static int xts_encrypt(struct skcipher_request *req)
280 {
281         return __xts_crypt(req, aesbs_xts_encrypt);
282 }
283
284 static int xts_decrypt(struct skcipher_request *req)
285 {
286         return __xts_crypt(req, aesbs_xts_decrypt);
287 }
288
289 static struct skcipher_alg aes_algs[] = { {
290         .base.cra_name          = "__ecb(aes)",
291         .base.cra_driver_name   = "__ecb-aes-neonbs",
292         .base.cra_priority      = 250,
293         .base.cra_blocksize     = AES_BLOCK_SIZE,
294         .base.cra_ctxsize       = sizeof(struct aesbs_ctx),
295         .base.cra_module        = THIS_MODULE,
296         .base.cra_flags         = CRYPTO_ALG_INTERNAL,
297
298         .min_keysize            = AES_MIN_KEY_SIZE,
299         .max_keysize            = AES_MAX_KEY_SIZE,
300         .walksize               = 8 * AES_BLOCK_SIZE,
301         .setkey                 = aesbs_setkey,
302         .encrypt                = ecb_encrypt,
303         .decrypt                = ecb_decrypt,
304 }, {
305         .base.cra_name          = "__cbc(aes)",
306         .base.cra_driver_name   = "__cbc-aes-neonbs",
307         .base.cra_priority      = 250,
308         .base.cra_blocksize     = AES_BLOCK_SIZE,
309         .base.cra_ctxsize       = sizeof(struct aesbs_cbc_ctx),
310         .base.cra_module        = THIS_MODULE,
311         .base.cra_flags         = CRYPTO_ALG_INTERNAL,
312
313         .min_keysize            = AES_MIN_KEY_SIZE,
314         .max_keysize            = AES_MAX_KEY_SIZE,
315         .walksize               = 8 * AES_BLOCK_SIZE,
316         .ivsize                 = AES_BLOCK_SIZE,
317         .setkey                 = aesbs_cbc_setkey,
318         .encrypt                = cbc_encrypt,
319         .decrypt                = cbc_decrypt,
320 }, {
321         .base.cra_name          = "__ctr(aes)",
322         .base.cra_driver_name   = "__ctr-aes-neonbs",
323         .base.cra_priority      = 250,
324         .base.cra_blocksize     = 1,
325         .base.cra_ctxsize       = sizeof(struct aesbs_ctx),
326         .base.cra_module        = THIS_MODULE,
327         .base.cra_flags         = CRYPTO_ALG_INTERNAL,
328
329         .min_keysize            = AES_MIN_KEY_SIZE,
330         .max_keysize            = AES_MAX_KEY_SIZE,
331         .chunksize              = AES_BLOCK_SIZE,
332         .walksize               = 8 * AES_BLOCK_SIZE,
333         .ivsize                 = AES_BLOCK_SIZE,
334         .setkey                 = aesbs_setkey,
335         .encrypt                = ctr_encrypt,
336         .decrypt                = ctr_encrypt,
337 }, {
338         .base.cra_name          = "__xts(aes)",
339         .base.cra_driver_name   = "__xts-aes-neonbs",
340         .base.cra_priority      = 250,
341         .base.cra_blocksize     = AES_BLOCK_SIZE,
342         .base.cra_ctxsize       = sizeof(struct aesbs_xts_ctx),
343         .base.cra_module        = THIS_MODULE,
344         .base.cra_flags         = CRYPTO_ALG_INTERNAL,
345
346         .min_keysize            = 2 * AES_MIN_KEY_SIZE,
347         .max_keysize            = 2 * AES_MAX_KEY_SIZE,
348         .walksize               = 8 * AES_BLOCK_SIZE,
349         .ivsize                 = AES_BLOCK_SIZE,
350         .setkey                 = aesbs_xts_setkey,
351         .encrypt                = xts_encrypt,
352         .decrypt                = xts_decrypt,
353 } };
354
355 static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
356
357 static void aes_exit(void)
358 {
359         int i;
360
361         for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
362                 if (aes_simd_algs[i])
363                         simd_skcipher_free(aes_simd_algs[i]);
364
365         crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
366 }
367
368 static int __init aes_init(void)
369 {
370         struct simd_skcipher_alg *simd;
371         const char *basename;
372         const char *algname;
373         const char *drvname;
374         int err;
375         int i;
376
377         if (!(elf_hwcap & HWCAP_NEON))
378                 return -ENODEV;
379
380         err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
381         if (err)
382                 return err;
383
384         for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
385                 if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
386                         continue;
387
388                 algname = aes_algs[i].base.cra_name + 2;
389                 drvname = aes_algs[i].base.cra_driver_name + 2;
390                 basename = aes_algs[i].base.cra_driver_name;
391                 simd = simd_skcipher_create_compat(algname, drvname, basename);
392                 err = PTR_ERR(simd);
393                 if (IS_ERR(simd))
394                         goto unregister_simds;
395
396                 aes_simd_algs[i] = simd;
397         }
398         return 0;
399
400 unregister_simds:
401         aes_exit();
402         return err;
403 }
404
405 module_init(aes_init);
406 module_exit(aes_exit);