Merge tag 'driver-core-5.5-rc1' of git://git.kernel.org/pub/scm/linux/kernel/git...
[sfrench/cifs-2.6.git] / arch / s390 / crypto / aes_s390.c
1 // SPDX-License-Identifier: GPL-2.0+
2 /*
3  * Cryptographic API.
4  *
5  * s390 implementation of the AES Cipher Algorithm.
6  *
7  * s390 Version:
8  *   Copyright IBM Corp. 2005, 2017
9  *   Author(s): Jan Glauber (jang@de.ibm.com)
10  *              Sebastian Siewior <sebastian@breakpoint.cc> SW-Fallback
11  *              Patrick Steuer <patrick.steuer@de.ibm.com>
12  *              Harald Freudenberger <freude@de.ibm.com>
13  *
14  * Derived from "crypto/aes_generic.c"
15  */
16
17 #define KMSG_COMPONENT "aes_s390"
18 #define pr_fmt(fmt) KMSG_COMPONENT ": " fmt
19
20 #include <crypto/aes.h>
21 #include <crypto/algapi.h>
22 #include <crypto/ghash.h>
23 #include <crypto/internal/aead.h>
24 #include <crypto/internal/skcipher.h>
25 #include <crypto/scatterwalk.h>
26 #include <linux/err.h>
27 #include <linux/module.h>
28 #include <linux/cpufeature.h>
29 #include <linux/init.h>
30 #include <linux/mutex.h>
31 #include <linux/fips.h>
32 #include <linux/string.h>
33 #include <crypto/xts.h>
34 #include <asm/cpacf.h>
35
36 static u8 *ctrblk;
37 static DEFINE_MUTEX(ctrblk_lock);
38
39 static cpacf_mask_t km_functions, kmc_functions, kmctr_functions,
40                     kma_functions;
41
42 struct s390_aes_ctx {
43         u8 key[AES_MAX_KEY_SIZE];
44         int key_len;
45         unsigned long fc;
46         union {
47                 struct crypto_skcipher *skcipher;
48                 struct crypto_cipher *cip;
49         } fallback;
50 };
51
52 struct s390_xts_ctx {
53         u8 key[32];
54         u8 pcc_key[32];
55         int key_len;
56         unsigned long fc;
57         struct crypto_skcipher *fallback;
58 };
59
60 struct gcm_sg_walk {
61         struct scatter_walk walk;
62         unsigned int walk_bytes;
63         u8 *walk_ptr;
64         unsigned int walk_bytes_remain;
65         u8 buf[AES_BLOCK_SIZE];
66         unsigned int buf_bytes;
67         u8 *ptr;
68         unsigned int nbytes;
69 };
70
71 static int setkey_fallback_cip(struct crypto_tfm *tfm, const u8 *in_key,
72                 unsigned int key_len)
73 {
74         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
75         int ret;
76
77         sctx->fallback.cip->base.crt_flags &= ~CRYPTO_TFM_REQ_MASK;
78         sctx->fallback.cip->base.crt_flags |= (tfm->crt_flags &
79                         CRYPTO_TFM_REQ_MASK);
80
81         ret = crypto_cipher_setkey(sctx->fallback.cip, in_key, key_len);
82         if (ret) {
83                 tfm->crt_flags &= ~CRYPTO_TFM_RES_MASK;
84                 tfm->crt_flags |= (sctx->fallback.cip->base.crt_flags &
85                                 CRYPTO_TFM_RES_MASK);
86         }
87         return ret;
88 }
89
90 static int aes_set_key(struct crypto_tfm *tfm, const u8 *in_key,
91                        unsigned int key_len)
92 {
93         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
94         unsigned long fc;
95
96         /* Pick the correct function code based on the key length */
97         fc = (key_len == 16) ? CPACF_KM_AES_128 :
98              (key_len == 24) ? CPACF_KM_AES_192 :
99              (key_len == 32) ? CPACF_KM_AES_256 : 0;
100
101         /* Check if the function code is available */
102         sctx->fc = (fc && cpacf_test_func(&km_functions, fc)) ? fc : 0;
103         if (!sctx->fc)
104                 return setkey_fallback_cip(tfm, in_key, key_len);
105
106         sctx->key_len = key_len;
107         memcpy(sctx->key, in_key, key_len);
108         return 0;
109 }
110
111 static void crypto_aes_encrypt(struct crypto_tfm *tfm, u8 *out, const u8 *in)
112 {
113         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
114
115         if (unlikely(!sctx->fc)) {
116                 crypto_cipher_encrypt_one(sctx->fallback.cip, out, in);
117                 return;
118         }
119         cpacf_km(sctx->fc, &sctx->key, out, in, AES_BLOCK_SIZE);
120 }
121
122 static void crypto_aes_decrypt(struct crypto_tfm *tfm, u8 *out, const u8 *in)
123 {
124         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
125
126         if (unlikely(!sctx->fc)) {
127                 crypto_cipher_decrypt_one(sctx->fallback.cip, out, in);
128                 return;
129         }
130         cpacf_km(sctx->fc | CPACF_DECRYPT,
131                  &sctx->key, out, in, AES_BLOCK_SIZE);
132 }
133
134 static int fallback_init_cip(struct crypto_tfm *tfm)
135 {
136         const char *name = tfm->__crt_alg->cra_name;
137         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
138
139         sctx->fallback.cip = crypto_alloc_cipher(name, 0,
140                                                  CRYPTO_ALG_NEED_FALLBACK);
141
142         if (IS_ERR(sctx->fallback.cip)) {
143                 pr_err("Allocating AES fallback algorithm %s failed\n",
144                        name);
145                 return PTR_ERR(sctx->fallback.cip);
146         }
147
148         return 0;
149 }
150
151 static void fallback_exit_cip(struct crypto_tfm *tfm)
152 {
153         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
154
155         crypto_free_cipher(sctx->fallback.cip);
156         sctx->fallback.cip = NULL;
157 }
158
159 static struct crypto_alg aes_alg = {
160         .cra_name               =       "aes",
161         .cra_driver_name        =       "aes-s390",
162         .cra_priority           =       300,
163         .cra_flags              =       CRYPTO_ALG_TYPE_CIPHER |
164                                         CRYPTO_ALG_NEED_FALLBACK,
165         .cra_blocksize          =       AES_BLOCK_SIZE,
166         .cra_ctxsize            =       sizeof(struct s390_aes_ctx),
167         .cra_module             =       THIS_MODULE,
168         .cra_init               =       fallback_init_cip,
169         .cra_exit               =       fallback_exit_cip,
170         .cra_u                  =       {
171                 .cipher = {
172                         .cia_min_keysize        =       AES_MIN_KEY_SIZE,
173                         .cia_max_keysize        =       AES_MAX_KEY_SIZE,
174                         .cia_setkey             =       aes_set_key,
175                         .cia_encrypt            =       crypto_aes_encrypt,
176                         .cia_decrypt            =       crypto_aes_decrypt,
177                 }
178         }
179 };
180
181 static int setkey_fallback_skcipher(struct crypto_skcipher *tfm, const u8 *key,
182                                     unsigned int len)
183 {
184         struct s390_aes_ctx *sctx = crypto_skcipher_ctx(tfm);
185         int ret;
186
187         crypto_skcipher_clear_flags(sctx->fallback.skcipher,
188                                     CRYPTO_TFM_REQ_MASK);
189         crypto_skcipher_set_flags(sctx->fallback.skcipher,
190                                   crypto_skcipher_get_flags(tfm) &
191                                   CRYPTO_TFM_REQ_MASK);
192         ret = crypto_skcipher_setkey(sctx->fallback.skcipher, key, len);
193         crypto_skcipher_set_flags(tfm,
194                                   crypto_skcipher_get_flags(sctx->fallback.skcipher) &
195                                   CRYPTO_TFM_RES_MASK);
196         return ret;
197 }
198
199 static int fallback_skcipher_crypt(struct s390_aes_ctx *sctx,
200                                    struct skcipher_request *req,
201                                    unsigned long modifier)
202 {
203         struct skcipher_request *subreq = skcipher_request_ctx(req);
204
205         *subreq = *req;
206         skcipher_request_set_tfm(subreq, sctx->fallback.skcipher);
207         return (modifier & CPACF_DECRYPT) ?
208                 crypto_skcipher_decrypt(subreq) :
209                 crypto_skcipher_encrypt(subreq);
210 }
211
212 static int ecb_aes_set_key(struct crypto_skcipher *tfm, const u8 *in_key,
213                            unsigned int key_len)
214 {
215         struct s390_aes_ctx *sctx = crypto_skcipher_ctx(tfm);
216         unsigned long fc;
217
218         /* Pick the correct function code based on the key length */
219         fc = (key_len == 16) ? CPACF_KM_AES_128 :
220              (key_len == 24) ? CPACF_KM_AES_192 :
221              (key_len == 32) ? CPACF_KM_AES_256 : 0;
222
223         /* Check if the function code is available */
224         sctx->fc = (fc && cpacf_test_func(&km_functions, fc)) ? fc : 0;
225         if (!sctx->fc)
226                 return setkey_fallback_skcipher(tfm, in_key, key_len);
227
228         sctx->key_len = key_len;
229         memcpy(sctx->key, in_key, key_len);
230         return 0;
231 }
232
233 static int ecb_aes_crypt(struct skcipher_request *req, unsigned long modifier)
234 {
235         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
236         struct s390_aes_ctx *sctx = crypto_skcipher_ctx(tfm);
237         struct skcipher_walk walk;
238         unsigned int nbytes, n;
239         int ret;
240
241         if (unlikely(!sctx->fc))
242                 return fallback_skcipher_crypt(sctx, req, modifier);
243
244         ret = skcipher_walk_virt(&walk, req, false);
245         while ((nbytes = walk.nbytes) != 0) {
246                 /* only use complete blocks */
247                 n = nbytes & ~(AES_BLOCK_SIZE - 1);
248                 cpacf_km(sctx->fc | modifier, sctx->key,
249                          walk.dst.virt.addr, walk.src.virt.addr, n);
250                 ret = skcipher_walk_done(&walk, nbytes - n);
251         }
252         return ret;
253 }
254
255 static int ecb_aes_encrypt(struct skcipher_request *req)
256 {
257         return ecb_aes_crypt(req, 0);
258 }
259
260 static int ecb_aes_decrypt(struct skcipher_request *req)
261 {
262         return ecb_aes_crypt(req, CPACF_DECRYPT);
263 }
264
265 static int fallback_init_skcipher(struct crypto_skcipher *tfm)
266 {
267         const char *name = crypto_tfm_alg_name(&tfm->base);
268         struct s390_aes_ctx *sctx = crypto_skcipher_ctx(tfm);
269
270         sctx->fallback.skcipher = crypto_alloc_skcipher(name, 0,
271                                 CRYPTO_ALG_NEED_FALLBACK | CRYPTO_ALG_ASYNC);
272
273         if (IS_ERR(sctx->fallback.skcipher)) {
274                 pr_err("Allocating AES fallback algorithm %s failed\n",
275                        name);
276                 return PTR_ERR(sctx->fallback.skcipher);
277         }
278
279         crypto_skcipher_set_reqsize(tfm, sizeof(struct skcipher_request) +
280                                     crypto_skcipher_reqsize(sctx->fallback.skcipher));
281         return 0;
282 }
283
284 static void fallback_exit_skcipher(struct crypto_skcipher *tfm)
285 {
286         struct s390_aes_ctx *sctx = crypto_skcipher_ctx(tfm);
287
288         crypto_free_skcipher(sctx->fallback.skcipher);
289 }
290
291 static struct skcipher_alg ecb_aes_alg = {
292         .base.cra_name          =       "ecb(aes)",
293         .base.cra_driver_name   =       "ecb-aes-s390",
294         .base.cra_priority      =       401,    /* combo: aes + ecb + 1 */
295         .base.cra_flags         =       CRYPTO_ALG_NEED_FALLBACK,
296         .base.cra_blocksize     =       AES_BLOCK_SIZE,
297         .base.cra_ctxsize       =       sizeof(struct s390_aes_ctx),
298         .base.cra_module        =       THIS_MODULE,
299         .init                   =       fallback_init_skcipher,
300         .exit                   =       fallback_exit_skcipher,
301         .min_keysize            =       AES_MIN_KEY_SIZE,
302         .max_keysize            =       AES_MAX_KEY_SIZE,
303         .setkey                 =       ecb_aes_set_key,
304         .encrypt                =       ecb_aes_encrypt,
305         .decrypt                =       ecb_aes_decrypt,
306 };
307
308 static int cbc_aes_set_key(struct crypto_skcipher *tfm, const u8 *in_key,
309                            unsigned int key_len)
310 {
311         struct s390_aes_ctx *sctx = crypto_skcipher_ctx(tfm);
312         unsigned long fc;
313
314         /* Pick the correct function code based on the key length */
315         fc = (key_len == 16) ? CPACF_KMC_AES_128 :
316              (key_len == 24) ? CPACF_KMC_AES_192 :
317              (key_len == 32) ? CPACF_KMC_AES_256 : 0;
318
319         /* Check if the function code is available */
320         sctx->fc = (fc && cpacf_test_func(&kmc_functions, fc)) ? fc : 0;
321         if (!sctx->fc)
322                 return setkey_fallback_skcipher(tfm, in_key, key_len);
323
324         sctx->key_len = key_len;
325         memcpy(sctx->key, in_key, key_len);
326         return 0;
327 }
328
329 static int cbc_aes_crypt(struct skcipher_request *req, unsigned long modifier)
330 {
331         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
332         struct s390_aes_ctx *sctx = crypto_skcipher_ctx(tfm);
333         struct skcipher_walk walk;
334         unsigned int nbytes, n;
335         int ret;
336         struct {
337                 u8 iv[AES_BLOCK_SIZE];
338                 u8 key[AES_MAX_KEY_SIZE];
339         } param;
340
341         if (unlikely(!sctx->fc))
342                 return fallback_skcipher_crypt(sctx, req, modifier);
343
344         ret = skcipher_walk_virt(&walk, req, false);
345         if (ret)
346                 return ret;
347         memcpy(param.iv, walk.iv, AES_BLOCK_SIZE);
348         memcpy(param.key, sctx->key, sctx->key_len);
349         while ((nbytes = walk.nbytes) != 0) {
350                 /* only use complete blocks */
351                 n = nbytes & ~(AES_BLOCK_SIZE - 1);
352                 cpacf_kmc(sctx->fc | modifier, &param,
353                           walk.dst.virt.addr, walk.src.virt.addr, n);
354                 memcpy(walk.iv, param.iv, AES_BLOCK_SIZE);
355                 ret = skcipher_walk_done(&walk, nbytes - n);
356         }
357         return ret;
358 }
359
360 static int cbc_aes_encrypt(struct skcipher_request *req)
361 {
362         return cbc_aes_crypt(req, 0);
363 }
364
365 static int cbc_aes_decrypt(struct skcipher_request *req)
366 {
367         return cbc_aes_crypt(req, CPACF_DECRYPT);
368 }
369
370 static struct skcipher_alg cbc_aes_alg = {
371         .base.cra_name          =       "cbc(aes)",
372         .base.cra_driver_name   =       "cbc-aes-s390",
373         .base.cra_priority      =       402,    /* ecb-aes-s390 + 1 */
374         .base.cra_flags         =       CRYPTO_ALG_NEED_FALLBACK,
375         .base.cra_blocksize     =       AES_BLOCK_SIZE,
376         .base.cra_ctxsize       =       sizeof(struct s390_aes_ctx),
377         .base.cra_module        =       THIS_MODULE,
378         .init                   =       fallback_init_skcipher,
379         .exit                   =       fallback_exit_skcipher,
380         .min_keysize            =       AES_MIN_KEY_SIZE,
381         .max_keysize            =       AES_MAX_KEY_SIZE,
382         .ivsize                 =       AES_BLOCK_SIZE,
383         .setkey                 =       cbc_aes_set_key,
384         .encrypt                =       cbc_aes_encrypt,
385         .decrypt                =       cbc_aes_decrypt,
386 };
387
388 static int xts_fallback_setkey(struct crypto_skcipher *tfm, const u8 *key,
389                                unsigned int len)
390 {
391         struct s390_xts_ctx *xts_ctx = crypto_skcipher_ctx(tfm);
392         int ret;
393
394         crypto_skcipher_clear_flags(xts_ctx->fallback, CRYPTO_TFM_REQ_MASK);
395         crypto_skcipher_set_flags(xts_ctx->fallback,
396                                   crypto_skcipher_get_flags(tfm) &
397                                   CRYPTO_TFM_REQ_MASK);
398         ret = crypto_skcipher_setkey(xts_ctx->fallback, key, len);
399         crypto_skcipher_set_flags(tfm,
400                                   crypto_skcipher_get_flags(xts_ctx->fallback) &
401                                   CRYPTO_TFM_RES_MASK);
402         return ret;
403 }
404
405 static int xts_aes_set_key(struct crypto_skcipher *tfm, const u8 *in_key,
406                            unsigned int key_len)
407 {
408         struct s390_xts_ctx *xts_ctx = crypto_skcipher_ctx(tfm);
409         unsigned long fc;
410         int err;
411
412         err = xts_fallback_setkey(tfm, in_key, key_len);
413         if (err)
414                 return err;
415
416         /* In fips mode only 128 bit or 256 bit keys are valid */
417         if (fips_enabled && key_len != 32 && key_len != 64) {
418                 crypto_skcipher_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);
419                 return -EINVAL;
420         }
421
422         /* Pick the correct function code based on the key length */
423         fc = (key_len == 32) ? CPACF_KM_XTS_128 :
424              (key_len == 64) ? CPACF_KM_XTS_256 : 0;
425
426         /* Check if the function code is available */
427         xts_ctx->fc = (fc && cpacf_test_func(&km_functions, fc)) ? fc : 0;
428         if (!xts_ctx->fc)
429                 return 0;
430
431         /* Split the XTS key into the two subkeys */
432         key_len = key_len / 2;
433         xts_ctx->key_len = key_len;
434         memcpy(xts_ctx->key, in_key, key_len);
435         memcpy(xts_ctx->pcc_key, in_key + key_len, key_len);
436         return 0;
437 }
438
439 static int xts_aes_crypt(struct skcipher_request *req, unsigned long modifier)
440 {
441         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
442         struct s390_xts_ctx *xts_ctx = crypto_skcipher_ctx(tfm);
443         struct skcipher_walk walk;
444         unsigned int offset, nbytes, n;
445         int ret;
446         struct {
447                 u8 key[32];
448                 u8 tweak[16];
449                 u8 block[16];
450                 u8 bit[16];
451                 u8 xts[16];
452         } pcc_param;
453         struct {
454                 u8 key[32];
455                 u8 init[16];
456         } xts_param;
457
458         if (req->cryptlen < AES_BLOCK_SIZE)
459                 return -EINVAL;
460
461         if (unlikely(!xts_ctx->fc || (req->cryptlen % AES_BLOCK_SIZE) != 0)) {
462                 struct skcipher_request *subreq = skcipher_request_ctx(req);
463
464                 *subreq = *req;
465                 skcipher_request_set_tfm(subreq, xts_ctx->fallback);
466                 return (modifier & CPACF_DECRYPT) ?
467                         crypto_skcipher_decrypt(subreq) :
468                         crypto_skcipher_encrypt(subreq);
469         }
470
471         ret = skcipher_walk_virt(&walk, req, false);
472         if (ret)
473                 return ret;
474         offset = xts_ctx->key_len & 0x10;
475         memset(pcc_param.block, 0, sizeof(pcc_param.block));
476         memset(pcc_param.bit, 0, sizeof(pcc_param.bit));
477         memset(pcc_param.xts, 0, sizeof(pcc_param.xts));
478         memcpy(pcc_param.tweak, walk.iv, sizeof(pcc_param.tweak));
479         memcpy(pcc_param.key + offset, xts_ctx->pcc_key, xts_ctx->key_len);
480         cpacf_pcc(xts_ctx->fc, pcc_param.key + offset);
481
482         memcpy(xts_param.key + offset, xts_ctx->key, xts_ctx->key_len);
483         memcpy(xts_param.init, pcc_param.xts, 16);
484
485         while ((nbytes = walk.nbytes) != 0) {
486                 /* only use complete blocks */
487                 n = nbytes & ~(AES_BLOCK_SIZE - 1);
488                 cpacf_km(xts_ctx->fc | modifier, xts_param.key + offset,
489                          walk.dst.virt.addr, walk.src.virt.addr, n);
490                 ret = skcipher_walk_done(&walk, nbytes - n);
491         }
492         return ret;
493 }
494
495 static int xts_aes_encrypt(struct skcipher_request *req)
496 {
497         return xts_aes_crypt(req, 0);
498 }
499
500 static int xts_aes_decrypt(struct skcipher_request *req)
501 {
502         return xts_aes_crypt(req, CPACF_DECRYPT);
503 }
504
505 static int xts_fallback_init(struct crypto_skcipher *tfm)
506 {
507         const char *name = crypto_tfm_alg_name(&tfm->base);
508         struct s390_xts_ctx *xts_ctx = crypto_skcipher_ctx(tfm);
509
510         xts_ctx->fallback = crypto_alloc_skcipher(name, 0,
511                                 CRYPTO_ALG_NEED_FALLBACK | CRYPTO_ALG_ASYNC);
512
513         if (IS_ERR(xts_ctx->fallback)) {
514                 pr_err("Allocating XTS fallback algorithm %s failed\n",
515                        name);
516                 return PTR_ERR(xts_ctx->fallback);
517         }
518         crypto_skcipher_set_reqsize(tfm, sizeof(struct skcipher_request) +
519                                     crypto_skcipher_reqsize(xts_ctx->fallback));
520         return 0;
521 }
522
523 static void xts_fallback_exit(struct crypto_skcipher *tfm)
524 {
525         struct s390_xts_ctx *xts_ctx = crypto_skcipher_ctx(tfm);
526
527         crypto_free_skcipher(xts_ctx->fallback);
528 }
529
530 static struct skcipher_alg xts_aes_alg = {
531         .base.cra_name          =       "xts(aes)",
532         .base.cra_driver_name   =       "xts-aes-s390",
533         .base.cra_priority      =       402,    /* ecb-aes-s390 + 1 */
534         .base.cra_flags         =       CRYPTO_ALG_NEED_FALLBACK,
535         .base.cra_blocksize     =       AES_BLOCK_SIZE,
536         .base.cra_ctxsize       =       sizeof(struct s390_xts_ctx),
537         .base.cra_module        =       THIS_MODULE,
538         .init                   =       xts_fallback_init,
539         .exit                   =       xts_fallback_exit,
540         .min_keysize            =       2 * AES_MIN_KEY_SIZE,
541         .max_keysize            =       2 * AES_MAX_KEY_SIZE,
542         .ivsize                 =       AES_BLOCK_SIZE,
543         .setkey                 =       xts_aes_set_key,
544         .encrypt                =       xts_aes_encrypt,
545         .decrypt                =       xts_aes_decrypt,
546 };
547
548 static int ctr_aes_set_key(struct crypto_skcipher *tfm, const u8 *in_key,
549                            unsigned int key_len)
550 {
551         struct s390_aes_ctx *sctx = crypto_skcipher_ctx(tfm);
552         unsigned long fc;
553
554         /* Pick the correct function code based on the key length */
555         fc = (key_len == 16) ? CPACF_KMCTR_AES_128 :
556              (key_len == 24) ? CPACF_KMCTR_AES_192 :
557              (key_len == 32) ? CPACF_KMCTR_AES_256 : 0;
558
559         /* Check if the function code is available */
560         sctx->fc = (fc && cpacf_test_func(&kmctr_functions, fc)) ? fc : 0;
561         if (!sctx->fc)
562                 return setkey_fallback_skcipher(tfm, in_key, key_len);
563
564         sctx->key_len = key_len;
565         memcpy(sctx->key, in_key, key_len);
566         return 0;
567 }
568
569 static unsigned int __ctrblk_init(u8 *ctrptr, u8 *iv, unsigned int nbytes)
570 {
571         unsigned int i, n;
572
573         /* only use complete blocks, max. PAGE_SIZE */
574         memcpy(ctrptr, iv, AES_BLOCK_SIZE);
575         n = (nbytes > PAGE_SIZE) ? PAGE_SIZE : nbytes & ~(AES_BLOCK_SIZE - 1);
576         for (i = (n / AES_BLOCK_SIZE) - 1; i > 0; i--) {
577                 memcpy(ctrptr + AES_BLOCK_SIZE, ctrptr, AES_BLOCK_SIZE);
578                 crypto_inc(ctrptr + AES_BLOCK_SIZE, AES_BLOCK_SIZE);
579                 ctrptr += AES_BLOCK_SIZE;
580         }
581         return n;
582 }
583
584 static int ctr_aes_crypt(struct skcipher_request *req)
585 {
586         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
587         struct s390_aes_ctx *sctx = crypto_skcipher_ctx(tfm);
588         u8 buf[AES_BLOCK_SIZE], *ctrptr;
589         struct skcipher_walk walk;
590         unsigned int n, nbytes;
591         int ret, locked;
592
593         if (unlikely(!sctx->fc))
594                 return fallback_skcipher_crypt(sctx, req, 0);
595
596         locked = mutex_trylock(&ctrblk_lock);
597
598         ret = skcipher_walk_virt(&walk, req, false);
599         while ((nbytes = walk.nbytes) >= AES_BLOCK_SIZE) {
600                 n = AES_BLOCK_SIZE;
601
602                 if (nbytes >= 2*AES_BLOCK_SIZE && locked)
603                         n = __ctrblk_init(ctrblk, walk.iv, nbytes);
604                 ctrptr = (n > AES_BLOCK_SIZE) ? ctrblk : walk.iv;
605                 cpacf_kmctr(sctx->fc, sctx->key, walk.dst.virt.addr,
606                             walk.src.virt.addr, n, ctrptr);
607                 if (ctrptr == ctrblk)
608                         memcpy(walk.iv, ctrptr + n - AES_BLOCK_SIZE,
609                                AES_BLOCK_SIZE);
610                 crypto_inc(walk.iv, AES_BLOCK_SIZE);
611                 ret = skcipher_walk_done(&walk, nbytes - n);
612         }
613         if (locked)
614                 mutex_unlock(&ctrblk_lock);
615         /*
616          * final block may be < AES_BLOCK_SIZE, copy only nbytes
617          */
618         if (nbytes) {
619                 cpacf_kmctr(sctx->fc, sctx->key, buf, walk.src.virt.addr,
620                             AES_BLOCK_SIZE, walk.iv);
621                 memcpy(walk.dst.virt.addr, buf, nbytes);
622                 crypto_inc(walk.iv, AES_BLOCK_SIZE);
623                 ret = skcipher_walk_done(&walk, 0);
624         }
625
626         return ret;
627 }
628
629 static struct skcipher_alg ctr_aes_alg = {
630         .base.cra_name          =       "ctr(aes)",
631         .base.cra_driver_name   =       "ctr-aes-s390",
632         .base.cra_priority      =       402,    /* ecb-aes-s390 + 1 */
633         .base.cra_flags         =       CRYPTO_ALG_NEED_FALLBACK,
634         .base.cra_blocksize     =       1,
635         .base.cra_ctxsize       =       sizeof(struct s390_aes_ctx),
636         .base.cra_module        =       THIS_MODULE,
637         .init                   =       fallback_init_skcipher,
638         .exit                   =       fallback_exit_skcipher,
639         .min_keysize            =       AES_MIN_KEY_SIZE,
640         .max_keysize            =       AES_MAX_KEY_SIZE,
641         .ivsize                 =       AES_BLOCK_SIZE,
642         .setkey                 =       ctr_aes_set_key,
643         .encrypt                =       ctr_aes_crypt,
644         .decrypt                =       ctr_aes_crypt,
645         .chunksize              =       AES_BLOCK_SIZE,
646 };
647
648 static int gcm_aes_setkey(struct crypto_aead *tfm, const u8 *key,
649                           unsigned int keylen)
650 {
651         struct s390_aes_ctx *ctx = crypto_aead_ctx(tfm);
652
653         switch (keylen) {
654         case AES_KEYSIZE_128:
655                 ctx->fc = CPACF_KMA_GCM_AES_128;
656                 break;
657         case AES_KEYSIZE_192:
658                 ctx->fc = CPACF_KMA_GCM_AES_192;
659                 break;
660         case AES_KEYSIZE_256:
661                 ctx->fc = CPACF_KMA_GCM_AES_256;
662                 break;
663         default:
664                 return -EINVAL;
665         }
666
667         memcpy(ctx->key, key, keylen);
668         ctx->key_len = keylen;
669         return 0;
670 }
671
/*
 * setauthsize of gcm(aes): accept tag lengths of 4, 8 and 12..16 bytes,
 * reject everything else with -EINVAL. @tfm is not consulted.
 */
static int gcm_aes_setauthsize(struct crypto_aead *tfm, unsigned int authsize)
{
	if (authsize == 4 || authsize == 8)
		return 0;
	if (authsize >= 12 && authsize <= 16)
		return 0;
	return -EINVAL;
}
689
690 static void gcm_walk_start(struct gcm_sg_walk *gw, struct scatterlist *sg,
691                            unsigned int len)
692 {
693         memset(gw, 0, sizeof(*gw));
694         gw->walk_bytes_remain = len;
695         scatterwalk_start(&gw->walk, sg);
696 }
697
698 static inline unsigned int _gcm_sg_clamp_and_map(struct gcm_sg_walk *gw)
699 {
700         struct scatterlist *nextsg;
701
702         gw->walk_bytes = scatterwalk_clamp(&gw->walk, gw->walk_bytes_remain);
703         while (!gw->walk_bytes) {
704                 nextsg = sg_next(gw->walk.sg);
705                 if (!nextsg)
706                         return 0;
707                 scatterwalk_start(&gw->walk, nextsg);
708                 gw->walk_bytes = scatterwalk_clamp(&gw->walk,
709                                                    gw->walk_bytes_remain);
710         }
711         gw->walk_ptr = scatterwalk_map(&gw->walk);
712         return gw->walk_bytes;
713 }
714
715 static inline void _gcm_sg_unmap_and_advance(struct gcm_sg_walk *gw,
716                                              unsigned int nbytes)
717 {
718         gw->walk_bytes_remain -= nbytes;
719         scatterwalk_unmap(&gw->walk);
720         scatterwalk_advance(&gw->walk, nbytes);
721         scatterwalk_done(&gw->walk, 0, gw->walk_bytes_remain);
722         gw->walk_ptr = NULL;
723 }
724
725 static int gcm_in_walk_go(struct gcm_sg_walk *gw, unsigned int minbytesneeded)
726 {
727         int n;
728
729         if (gw->buf_bytes && gw->buf_bytes >= minbytesneeded) {
730                 gw->ptr = gw->buf;
731                 gw->nbytes = gw->buf_bytes;
732                 goto out;
733         }
734
735         if (gw->walk_bytes_remain == 0) {
736                 gw->ptr = NULL;
737                 gw->nbytes = 0;
738                 goto out;
739         }
740
741         if (!_gcm_sg_clamp_and_map(gw)) {
742                 gw->ptr = NULL;
743                 gw->nbytes = 0;
744                 goto out;
745         }
746
747         if (!gw->buf_bytes && gw->walk_bytes >= minbytesneeded) {
748                 gw->ptr = gw->walk_ptr;
749                 gw->nbytes = gw->walk_bytes;
750                 goto out;
751         }
752
753         while (1) {
754                 n = min(gw->walk_bytes, AES_BLOCK_SIZE - gw->buf_bytes);
755                 memcpy(gw->buf + gw->buf_bytes, gw->walk_ptr, n);
756                 gw->buf_bytes += n;
757                 _gcm_sg_unmap_and_advance(gw, n);
758                 if (gw->buf_bytes >= minbytesneeded) {
759                         gw->ptr = gw->buf;
760                         gw->nbytes = gw->buf_bytes;
761                         goto out;
762                 }
763                 if (!_gcm_sg_clamp_and_map(gw)) {
764                         gw->ptr = NULL;
765                         gw->nbytes = 0;
766                         goto out;
767                 }
768         }
769
770 out:
771         return gw->nbytes;
772 }
773
774 static int gcm_out_walk_go(struct gcm_sg_walk *gw, unsigned int minbytesneeded)
775 {
776         if (gw->walk_bytes_remain == 0) {
777                 gw->ptr = NULL;
778                 gw->nbytes = 0;
779                 goto out;
780         }
781
782         if (!_gcm_sg_clamp_and_map(gw)) {
783                 gw->ptr = NULL;
784                 gw->nbytes = 0;
785                 goto out;
786         }
787
788         if (gw->walk_bytes >= minbytesneeded) {
789                 gw->ptr = gw->walk_ptr;
790                 gw->nbytes = gw->walk_bytes;
791                 goto out;
792         }
793
794         scatterwalk_unmap(&gw->walk);
795         gw->walk_ptr = NULL;
796
797         gw->ptr = gw->buf;
798         gw->nbytes = sizeof(gw->buf);
799
800 out:
801         return gw->nbytes;
802 }
803
804 static int gcm_in_walk_done(struct gcm_sg_walk *gw, unsigned int bytesdone)
805 {
806         if (gw->ptr == NULL)
807                 return 0;
808
809         if (gw->ptr == gw->buf) {
810                 int n = gw->buf_bytes - bytesdone;
811                 if (n > 0) {
812                         memmove(gw->buf, gw->buf + bytesdone, n);
813                         gw->buf_bytes = n;
814                 } else
815                         gw->buf_bytes = 0;
816         } else
817                 _gcm_sg_unmap_and_advance(gw, bytesdone);
818
819         return bytesdone;
820 }
821
822 static int gcm_out_walk_done(struct gcm_sg_walk *gw, unsigned int bytesdone)
823 {
824         int i, n;
825
826         if (gw->ptr == NULL)
827                 return 0;
828
829         if (gw->ptr == gw->buf) {
830                 for (i = 0; i < bytesdone; i += n) {
831                         if (!_gcm_sg_clamp_and_map(gw))
832                                 return i;
833                         n = min(gw->walk_bytes, bytesdone - i);
834                         memcpy(gw->walk_ptr, gw->buf + i, n);
835                         _gcm_sg_unmap_and_advance(gw, n);
836                 }
837         } else
838                 _gcm_sg_unmap_and_advance(gw, bytesdone);
839
840         return bytesdone;
841 }
842
843 static int gcm_aes_crypt(struct aead_request *req, unsigned int flags)
844 {
845         struct crypto_aead *tfm = crypto_aead_reqtfm(req);
846         struct s390_aes_ctx *ctx = crypto_aead_ctx(tfm);
847         unsigned int ivsize = crypto_aead_ivsize(tfm);
848         unsigned int taglen = crypto_aead_authsize(tfm);
849         unsigned int aadlen = req->assoclen;
850         unsigned int pclen = req->cryptlen;
851         int ret = 0;
852
853         unsigned int n, len, in_bytes, out_bytes,
854                      min_bytes, bytes, aad_bytes, pc_bytes;
855         struct gcm_sg_walk gw_in, gw_out;
856         u8 tag[GHASH_DIGEST_SIZE];
857
858         struct {
859                 u32 _[3];               /* reserved */
860                 u32 cv;                 /* Counter Value */
861                 u8 t[GHASH_DIGEST_SIZE];/* Tag */
862                 u8 h[AES_BLOCK_SIZE];   /* Hash-subkey */
863                 u64 taadl;              /* Total AAD Length */
864                 u64 tpcl;               /* Total Plain-/Cipher-text Length */
865                 u8 j0[GHASH_BLOCK_SIZE];/* initial counter value */
866                 u8 k[AES_MAX_KEY_SIZE]; /* Key */
867         } param;
868
869         /*
870          * encrypt
871          *   req->src: aad||plaintext
872          *   req->dst: aad||ciphertext||tag
873          * decrypt
874          *   req->src: aad||ciphertext||tag
875          *   req->dst: aad||plaintext, return 0 or -EBADMSG
876          * aad, plaintext and ciphertext may be empty.
877          */
878         if (flags & CPACF_DECRYPT)
879                 pclen -= taglen;
880         len = aadlen + pclen;
881
882         memset(&param, 0, sizeof(param));
883         param.cv = 1;
884         param.taadl = aadlen * 8;
885         param.tpcl = pclen * 8;
886         memcpy(param.j0, req->iv, ivsize);
887         *(u32 *)(param.j0 + ivsize) = 1;
888         memcpy(param.k, ctx->key, ctx->key_len);
889
890         gcm_walk_start(&gw_in, req->src, len);
891         gcm_walk_start(&gw_out, req->dst, len);
892
893         do {
894                 min_bytes = min_t(unsigned int,
895                                   aadlen > 0 ? aadlen : pclen, AES_BLOCK_SIZE);
896                 in_bytes = gcm_in_walk_go(&gw_in, min_bytes);
897                 out_bytes = gcm_out_walk_go(&gw_out, min_bytes);
898                 bytes = min(in_bytes, out_bytes);
899
900                 if (aadlen + pclen <= bytes) {
901                         aad_bytes = aadlen;
902                         pc_bytes = pclen;
903                         flags |= CPACF_KMA_LAAD | CPACF_KMA_LPC;
904                 } else {
905                         if (aadlen <= bytes) {
906                                 aad_bytes = aadlen;
907                                 pc_bytes = (bytes - aadlen) &
908                                            ~(AES_BLOCK_SIZE - 1);
909                                 flags |= CPACF_KMA_LAAD;
910                         } else {
911                                 aad_bytes = bytes & ~(AES_BLOCK_SIZE - 1);
912                                 pc_bytes = 0;
913                         }
914                 }
915
916                 if (aad_bytes > 0)
917                         memcpy(gw_out.ptr, gw_in.ptr, aad_bytes);
918
919                 cpacf_kma(ctx->fc | flags, &param,
920                           gw_out.ptr + aad_bytes,
921                           gw_in.ptr + aad_bytes, pc_bytes,
922                           gw_in.ptr, aad_bytes);
923
924                 n = aad_bytes + pc_bytes;
925                 if (gcm_in_walk_done(&gw_in, n) != n)
926                         return -ENOMEM;
927                 if (gcm_out_walk_done(&gw_out, n) != n)
928                         return -ENOMEM;
929                 aadlen -= aad_bytes;
930                 pclen -= pc_bytes;
931         } while (aadlen + pclen > 0);
932
933         if (flags & CPACF_DECRYPT) {
934                 scatterwalk_map_and_copy(tag, req->src, len, taglen, 0);
935                 if (crypto_memneq(tag, param.t, taglen))
936                         ret = -EBADMSG;
937         } else
938                 scatterwalk_map_and_copy(param.t, req->dst, len, taglen, 1);
939
940         memzero_explicit(&param, sizeof(param));
941         return ret;
942 }
943
944 static int gcm_aes_encrypt(struct aead_request *req)
945 {
946         return gcm_aes_crypt(req, CPACF_ENCRYPT);
947 }
948
949 static int gcm_aes_decrypt(struct aead_request *req)
950 {
951         return gcm_aes_crypt(req, CPACF_DECRYPT);
952 }
953
954 static struct aead_alg gcm_aes_aead = {
955         .setkey                 = gcm_aes_setkey,
956         .setauthsize            = gcm_aes_setauthsize,
957         .encrypt                = gcm_aes_encrypt,
958         .decrypt                = gcm_aes_decrypt,
959
960         .ivsize                 = GHASH_BLOCK_SIZE - sizeof(u32),
961         .maxauthsize            = GHASH_DIGEST_SIZE,
962         .chunksize              = AES_BLOCK_SIZE,
963
964         .base                   = {
965                 .cra_blocksize          = 1,
966                 .cra_ctxsize            = sizeof(struct s390_aes_ctx),
967                 .cra_priority           = 900,
968                 .cra_name               = "gcm(aes)",
969                 .cra_driver_name        = "gcm-aes-s390",
970                 .cra_module             = THIS_MODULE,
971         },
972 };
973
/*
 * Record of what aes_s390_init() registered, so that aes_s390_fini() can
 * undo exactly that.  A pointer stays NULL if its algorithm was not
 * registered.
 */
static struct crypto_alg *aes_s390_alg;
/* Registered skciphers in registration order; at most ECB, CBC, XTS, CTR. */
static struct skcipher_alg *aes_s390_skcipher_algs[4];
/* Number of valid entries in aes_s390_skcipher_algs[]. */
static int aes_s390_skciphers_num;
static struct aead_alg *aes_s390_aead_alg;
978
979 static int aes_s390_register_skcipher(struct skcipher_alg *alg)
980 {
981         int ret;
982
983         ret = crypto_register_skcipher(alg);
984         if (!ret)
985                 aes_s390_skcipher_algs[aes_s390_skciphers_num++] = alg;
986         return ret;
987 }
988
989 static void aes_s390_fini(void)
990 {
991         if (aes_s390_alg)
992                 crypto_unregister_alg(aes_s390_alg);
993         while (aes_s390_skciphers_num--)
994                 crypto_unregister_skcipher(aes_s390_skcipher_algs[aes_s390_skciphers_num]);
995         if (ctrblk)
996                 free_page((unsigned long) ctrblk);
997
998         if (aes_s390_aead_alg)
999                 crypto_unregister_aead(aes_s390_aead_alg);
1000 }
1001
1002 static int __init aes_s390_init(void)
1003 {
1004         int ret;
1005
1006         /* Query available functions for KM, KMC, KMCTR and KMA */
1007         cpacf_query(CPACF_KM, &km_functions);
1008         cpacf_query(CPACF_KMC, &kmc_functions);
1009         cpacf_query(CPACF_KMCTR, &kmctr_functions);
1010         cpacf_query(CPACF_KMA, &kma_functions);
1011
1012         if (cpacf_test_func(&km_functions, CPACF_KM_AES_128) ||
1013             cpacf_test_func(&km_functions, CPACF_KM_AES_192) ||
1014             cpacf_test_func(&km_functions, CPACF_KM_AES_256)) {
1015                 ret = crypto_register_alg(&aes_alg);
1016                 if (ret)
1017                         goto out_err;
1018                 aes_s390_alg = &aes_alg;
1019                 ret = aes_s390_register_skcipher(&ecb_aes_alg);
1020                 if (ret)
1021                         goto out_err;
1022         }
1023
1024         if (cpacf_test_func(&kmc_functions, CPACF_KMC_AES_128) ||
1025             cpacf_test_func(&kmc_functions, CPACF_KMC_AES_192) ||
1026             cpacf_test_func(&kmc_functions, CPACF_KMC_AES_256)) {
1027                 ret = aes_s390_register_skcipher(&cbc_aes_alg);
1028                 if (ret)
1029                         goto out_err;
1030         }
1031
1032         if (cpacf_test_func(&km_functions, CPACF_KM_XTS_128) ||
1033             cpacf_test_func(&km_functions, CPACF_KM_XTS_256)) {
1034                 ret = aes_s390_register_skcipher(&xts_aes_alg);
1035                 if (ret)
1036                         goto out_err;
1037         }
1038
1039         if (cpacf_test_func(&kmctr_functions, CPACF_KMCTR_AES_128) ||
1040             cpacf_test_func(&kmctr_functions, CPACF_KMCTR_AES_192) ||
1041             cpacf_test_func(&kmctr_functions, CPACF_KMCTR_AES_256)) {
1042                 ctrblk = (u8 *) __get_free_page(GFP_KERNEL);
1043                 if (!ctrblk) {
1044                         ret = -ENOMEM;
1045                         goto out_err;
1046                 }
1047                 ret = aes_s390_register_skcipher(&ctr_aes_alg);
1048                 if (ret)
1049                         goto out_err;
1050         }
1051
1052         if (cpacf_test_func(&kma_functions, CPACF_KMA_GCM_AES_128) ||
1053             cpacf_test_func(&kma_functions, CPACF_KMA_GCM_AES_192) ||
1054             cpacf_test_func(&kma_functions, CPACF_KMA_GCM_AES_256)) {
1055                 ret = crypto_register_aead(&gcm_aes_aead);
1056                 if (ret)
1057                         goto out_err;
1058                 aes_s390_aead_alg = &gcm_aes_aead;
1059         }
1060
1061         return 0;
1062 out_err:
1063         aes_s390_fini();
1064         return ret;
1065 }
1066
/*
 * aes_s390_init() is the module init routine, bound to the s390 MSA CPU
 * feature via module_cpu_feature_match(); aes_s390_fini() is the exit
 * routine.
 */
module_cpu_feature_match(MSA, aes_s390_init);
module_exit(aes_s390_fini);

/* Crypto module alias "aes-all" for this module. */
MODULE_ALIAS_CRYPTO("aes-all");

MODULE_DESCRIPTION("Rijndael (AES) Cipher Algorithm");
MODULE_LICENSE("GPL");