Merge tag 'scsi-fixes' of git://git.kernel.org/pub/scm/linux/kernel/git/jejb/scsi
[sfrench/cifs-2.6.git] / arch / s390 / crypto / aes_s390.c
1 // SPDX-License-Identifier: GPL-2.0+
2 /*
3  * Cryptographic API.
4  *
5  * s390 implementation of the AES Cipher Algorithm.
6  *
7  * s390 Version:
8  *   Copyright IBM Corp. 2005, 2017
9  *   Author(s): Jan Glauber (jang@de.ibm.com)
10  *              Sebastian Siewior <sebastian@breakpoint.cc> SW-Fallback
11  *              Patrick Steuer <patrick.steuer@de.ibm.com>
12  *              Harald Freudenberger <freude@de.ibm.com>
13  *
14  * Derived from "crypto/aes_generic.c"
15  */
16
17 #define KMSG_COMPONENT "aes_s390"
18 #define pr_fmt(fmt) KMSG_COMPONENT ": " fmt
19
20 #include <crypto/aes.h>
21 #include <crypto/algapi.h>
22 #include <crypto/ghash.h>
23 #include <crypto/internal/aead.h>
24 #include <crypto/internal/cipher.h>
25 #include <crypto/internal/skcipher.h>
26 #include <crypto/scatterwalk.h>
27 #include <linux/err.h>
28 #include <linux/module.h>
29 #include <linux/cpufeature.h>
30 #include <linux/init.h>
31 #include <linux/mutex.h>
32 #include <linux/fips.h>
33 #include <linux/string.h>
34 #include <crypto/xts.h>
35 #include <asm/cpacf.h>
36
37 static u8 *ctrblk;
38 static DEFINE_MUTEX(ctrblk_lock);
39
40 static cpacf_mask_t km_functions, kmc_functions, kmctr_functions,
41                     kma_functions;
42
43 struct s390_aes_ctx {
44         u8 key[AES_MAX_KEY_SIZE];
45         int key_len;
46         unsigned long fc;
47         union {
48                 struct crypto_skcipher *skcipher;
49                 struct crypto_cipher *cip;
50         } fallback;
51 };
52
53 struct s390_xts_ctx {
54         u8 key[32];
55         u8 pcc_key[32];
56         int key_len;
57         unsigned long fc;
58         struct crypto_skcipher *fallback;
59 };
60
61 struct gcm_sg_walk {
62         struct scatter_walk walk;
63         unsigned int walk_bytes;
64         u8 *walk_ptr;
65         unsigned int walk_bytes_remain;
66         u8 buf[AES_BLOCK_SIZE];
67         unsigned int buf_bytes;
68         u8 *ptr;
69         unsigned int nbytes;
70 };
71
72 static int setkey_fallback_cip(struct crypto_tfm *tfm, const u8 *in_key,
73                 unsigned int key_len)
74 {
75         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
76
77         sctx->fallback.cip->base.crt_flags &= ~CRYPTO_TFM_REQ_MASK;
78         sctx->fallback.cip->base.crt_flags |= (tfm->crt_flags &
79                         CRYPTO_TFM_REQ_MASK);
80
81         return crypto_cipher_setkey(sctx->fallback.cip, in_key, key_len);
82 }
83
84 static int aes_set_key(struct crypto_tfm *tfm, const u8 *in_key,
85                        unsigned int key_len)
86 {
87         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
88         unsigned long fc;
89
90         /* Pick the correct function code based on the key length */
91         fc = (key_len == 16) ? CPACF_KM_AES_128 :
92              (key_len == 24) ? CPACF_KM_AES_192 :
93              (key_len == 32) ? CPACF_KM_AES_256 : 0;
94
95         /* Check if the function code is available */
96         sctx->fc = (fc && cpacf_test_func(&km_functions, fc)) ? fc : 0;
97         if (!sctx->fc)
98                 return setkey_fallback_cip(tfm, in_key, key_len);
99
100         sctx->key_len = key_len;
101         memcpy(sctx->key, in_key, key_len);
102         return 0;
103 }
104
105 static void crypto_aes_encrypt(struct crypto_tfm *tfm, u8 *out, const u8 *in)
106 {
107         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
108
109         if (unlikely(!sctx->fc)) {
110                 crypto_cipher_encrypt_one(sctx->fallback.cip, out, in);
111                 return;
112         }
113         cpacf_km(sctx->fc, &sctx->key, out, in, AES_BLOCK_SIZE);
114 }
115
116 static void crypto_aes_decrypt(struct crypto_tfm *tfm, u8 *out, const u8 *in)
117 {
118         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
119
120         if (unlikely(!sctx->fc)) {
121                 crypto_cipher_decrypt_one(sctx->fallback.cip, out, in);
122                 return;
123         }
124         cpacf_km(sctx->fc | CPACF_DECRYPT,
125                  &sctx->key, out, in, AES_BLOCK_SIZE);
126 }
127
128 static int fallback_init_cip(struct crypto_tfm *tfm)
129 {
130         const char *name = tfm->__crt_alg->cra_name;
131         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
132
133         sctx->fallback.cip = crypto_alloc_cipher(name, 0,
134                                                  CRYPTO_ALG_NEED_FALLBACK);
135
136         if (IS_ERR(sctx->fallback.cip)) {
137                 pr_err("Allocating AES fallback algorithm %s failed\n",
138                        name);
139                 return PTR_ERR(sctx->fallback.cip);
140         }
141
142         return 0;
143 }
144
145 static void fallback_exit_cip(struct crypto_tfm *tfm)
146 {
147         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
148
149         crypto_free_cipher(sctx->fallback.cip);
150         sctx->fallback.cip = NULL;
151 }
152
153 static struct crypto_alg aes_alg = {
154         .cra_name               =       "aes",
155         .cra_driver_name        =       "aes-s390",
156         .cra_priority           =       300,
157         .cra_flags              =       CRYPTO_ALG_TYPE_CIPHER |
158                                         CRYPTO_ALG_NEED_FALLBACK,
159         .cra_blocksize          =       AES_BLOCK_SIZE,
160         .cra_ctxsize            =       sizeof(struct s390_aes_ctx),
161         .cra_module             =       THIS_MODULE,
162         .cra_init               =       fallback_init_cip,
163         .cra_exit               =       fallback_exit_cip,
164         .cra_u                  =       {
165                 .cipher = {
166                         .cia_min_keysize        =       AES_MIN_KEY_SIZE,
167                         .cia_max_keysize        =       AES_MAX_KEY_SIZE,
168                         .cia_setkey             =       aes_set_key,
169                         .cia_encrypt            =       crypto_aes_encrypt,
170                         .cia_decrypt            =       crypto_aes_decrypt,
171                 }
172         }
173 };
174
175 static int setkey_fallback_skcipher(struct crypto_skcipher *tfm, const u8 *key,
176                                     unsigned int len)
177 {
178         struct s390_aes_ctx *sctx = crypto_skcipher_ctx(tfm);
179
180         crypto_skcipher_clear_flags(sctx->fallback.skcipher,
181                                     CRYPTO_TFM_REQ_MASK);
182         crypto_skcipher_set_flags(sctx->fallback.skcipher,
183                                   crypto_skcipher_get_flags(tfm) &
184                                   CRYPTO_TFM_REQ_MASK);
185         return crypto_skcipher_setkey(sctx->fallback.skcipher, key, len);
186 }
187
188 static int fallback_skcipher_crypt(struct s390_aes_ctx *sctx,
189                                    struct skcipher_request *req,
190                                    unsigned long modifier)
191 {
192         struct skcipher_request *subreq = skcipher_request_ctx(req);
193
194         *subreq = *req;
195         skcipher_request_set_tfm(subreq, sctx->fallback.skcipher);
196         return (modifier & CPACF_DECRYPT) ?
197                 crypto_skcipher_decrypt(subreq) :
198                 crypto_skcipher_encrypt(subreq);
199 }
200
201 static int ecb_aes_set_key(struct crypto_skcipher *tfm, const u8 *in_key,
202                            unsigned int key_len)
203 {
204         struct s390_aes_ctx *sctx = crypto_skcipher_ctx(tfm);
205         unsigned long fc;
206
207         /* Pick the correct function code based on the key length */
208         fc = (key_len == 16) ? CPACF_KM_AES_128 :
209              (key_len == 24) ? CPACF_KM_AES_192 :
210              (key_len == 32) ? CPACF_KM_AES_256 : 0;
211
212         /* Check if the function code is available */
213         sctx->fc = (fc && cpacf_test_func(&km_functions, fc)) ? fc : 0;
214         if (!sctx->fc)
215                 return setkey_fallback_skcipher(tfm, in_key, key_len);
216
217         sctx->key_len = key_len;
218         memcpy(sctx->key, in_key, key_len);
219         return 0;
220 }
221
222 static int ecb_aes_crypt(struct skcipher_request *req, unsigned long modifier)
223 {
224         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
225         struct s390_aes_ctx *sctx = crypto_skcipher_ctx(tfm);
226         struct skcipher_walk walk;
227         unsigned int nbytes, n;
228         int ret;
229
230         if (unlikely(!sctx->fc))
231                 return fallback_skcipher_crypt(sctx, req, modifier);
232
233         ret = skcipher_walk_virt(&walk, req, false);
234         while ((nbytes = walk.nbytes) != 0) {
235                 /* only use complete blocks */
236                 n = nbytes & ~(AES_BLOCK_SIZE - 1);
237                 cpacf_km(sctx->fc | modifier, sctx->key,
238                          walk.dst.virt.addr, walk.src.virt.addr, n);
239                 ret = skcipher_walk_done(&walk, nbytes - n);
240         }
241         return ret;
242 }
243
244 static int ecb_aes_encrypt(struct skcipher_request *req)
245 {
246         return ecb_aes_crypt(req, 0);
247 }
248
249 static int ecb_aes_decrypt(struct skcipher_request *req)
250 {
251         return ecb_aes_crypt(req, CPACF_DECRYPT);
252 }
253
254 static int fallback_init_skcipher(struct crypto_skcipher *tfm)
255 {
256         const char *name = crypto_tfm_alg_name(&tfm->base);
257         struct s390_aes_ctx *sctx = crypto_skcipher_ctx(tfm);
258
259         sctx->fallback.skcipher = crypto_alloc_skcipher(name, 0,
260                                 CRYPTO_ALG_NEED_FALLBACK | CRYPTO_ALG_ASYNC);
261
262         if (IS_ERR(sctx->fallback.skcipher)) {
263                 pr_err("Allocating AES fallback algorithm %s failed\n",
264                        name);
265                 return PTR_ERR(sctx->fallback.skcipher);
266         }
267
268         crypto_skcipher_set_reqsize(tfm, sizeof(struct skcipher_request) +
269                                     crypto_skcipher_reqsize(sctx->fallback.skcipher));
270         return 0;
271 }
272
273 static void fallback_exit_skcipher(struct crypto_skcipher *tfm)
274 {
275         struct s390_aes_ctx *sctx = crypto_skcipher_ctx(tfm);
276
277         crypto_free_skcipher(sctx->fallback.skcipher);
278 }
279
280 static struct skcipher_alg ecb_aes_alg = {
281         .base.cra_name          =       "ecb(aes)",
282         .base.cra_driver_name   =       "ecb-aes-s390",
283         .base.cra_priority      =       401,    /* combo: aes + ecb + 1 */
284         .base.cra_flags         =       CRYPTO_ALG_NEED_FALLBACK,
285         .base.cra_blocksize     =       AES_BLOCK_SIZE,
286         .base.cra_ctxsize       =       sizeof(struct s390_aes_ctx),
287         .base.cra_module        =       THIS_MODULE,
288         .init                   =       fallback_init_skcipher,
289         .exit                   =       fallback_exit_skcipher,
290         .min_keysize            =       AES_MIN_KEY_SIZE,
291         .max_keysize            =       AES_MAX_KEY_SIZE,
292         .setkey                 =       ecb_aes_set_key,
293         .encrypt                =       ecb_aes_encrypt,
294         .decrypt                =       ecb_aes_decrypt,
295 };
296
297 static int cbc_aes_set_key(struct crypto_skcipher *tfm, const u8 *in_key,
298                            unsigned int key_len)
299 {
300         struct s390_aes_ctx *sctx = crypto_skcipher_ctx(tfm);
301         unsigned long fc;
302
303         /* Pick the correct function code based on the key length */
304         fc = (key_len == 16) ? CPACF_KMC_AES_128 :
305              (key_len == 24) ? CPACF_KMC_AES_192 :
306              (key_len == 32) ? CPACF_KMC_AES_256 : 0;
307
308         /* Check if the function code is available */
309         sctx->fc = (fc && cpacf_test_func(&kmc_functions, fc)) ? fc : 0;
310         if (!sctx->fc)
311                 return setkey_fallback_skcipher(tfm, in_key, key_len);
312
313         sctx->key_len = key_len;
314         memcpy(sctx->key, in_key, key_len);
315         return 0;
316 }
317
318 static int cbc_aes_crypt(struct skcipher_request *req, unsigned long modifier)
319 {
320         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
321         struct s390_aes_ctx *sctx = crypto_skcipher_ctx(tfm);
322         struct skcipher_walk walk;
323         unsigned int nbytes, n;
324         int ret;
325         struct {
326                 u8 iv[AES_BLOCK_SIZE];
327                 u8 key[AES_MAX_KEY_SIZE];
328         } param;
329
330         if (unlikely(!sctx->fc))
331                 return fallback_skcipher_crypt(sctx, req, modifier);
332
333         ret = skcipher_walk_virt(&walk, req, false);
334         if (ret)
335                 return ret;
336         memcpy(param.iv, walk.iv, AES_BLOCK_SIZE);
337         memcpy(param.key, sctx->key, sctx->key_len);
338         while ((nbytes = walk.nbytes) != 0) {
339                 /* only use complete blocks */
340                 n = nbytes & ~(AES_BLOCK_SIZE - 1);
341                 cpacf_kmc(sctx->fc | modifier, &param,
342                           walk.dst.virt.addr, walk.src.virt.addr, n);
343                 memcpy(walk.iv, param.iv, AES_BLOCK_SIZE);
344                 ret = skcipher_walk_done(&walk, nbytes - n);
345         }
346         memzero_explicit(&param, sizeof(param));
347         return ret;
348 }
349
350 static int cbc_aes_encrypt(struct skcipher_request *req)
351 {
352         return cbc_aes_crypt(req, 0);
353 }
354
355 static int cbc_aes_decrypt(struct skcipher_request *req)
356 {
357         return cbc_aes_crypt(req, CPACF_DECRYPT);
358 }
359
360 static struct skcipher_alg cbc_aes_alg = {
361         .base.cra_name          =       "cbc(aes)",
362         .base.cra_driver_name   =       "cbc-aes-s390",
363         .base.cra_priority      =       402,    /* ecb-aes-s390 + 1 */
364         .base.cra_flags         =       CRYPTO_ALG_NEED_FALLBACK,
365         .base.cra_blocksize     =       AES_BLOCK_SIZE,
366         .base.cra_ctxsize       =       sizeof(struct s390_aes_ctx),
367         .base.cra_module        =       THIS_MODULE,
368         .init                   =       fallback_init_skcipher,
369         .exit                   =       fallback_exit_skcipher,
370         .min_keysize            =       AES_MIN_KEY_SIZE,
371         .max_keysize            =       AES_MAX_KEY_SIZE,
372         .ivsize                 =       AES_BLOCK_SIZE,
373         .setkey                 =       cbc_aes_set_key,
374         .encrypt                =       cbc_aes_encrypt,
375         .decrypt                =       cbc_aes_decrypt,
376 };
377
378 static int xts_fallback_setkey(struct crypto_skcipher *tfm, const u8 *key,
379                                unsigned int len)
380 {
381         struct s390_xts_ctx *xts_ctx = crypto_skcipher_ctx(tfm);
382
383         crypto_skcipher_clear_flags(xts_ctx->fallback, CRYPTO_TFM_REQ_MASK);
384         crypto_skcipher_set_flags(xts_ctx->fallback,
385                                   crypto_skcipher_get_flags(tfm) &
386                                   CRYPTO_TFM_REQ_MASK);
387         return crypto_skcipher_setkey(xts_ctx->fallback, key, len);
388 }
389
390 static int xts_aes_set_key(struct crypto_skcipher *tfm, const u8 *in_key,
391                            unsigned int key_len)
392 {
393         struct s390_xts_ctx *xts_ctx = crypto_skcipher_ctx(tfm);
394         unsigned long fc;
395         int err;
396
397         err = xts_fallback_setkey(tfm, in_key, key_len);
398         if (err)
399                 return err;
400
401         /* Pick the correct function code based on the key length */
402         fc = (key_len == 32) ? CPACF_KM_XTS_128 :
403              (key_len == 64) ? CPACF_KM_XTS_256 : 0;
404
405         /* Check if the function code is available */
406         xts_ctx->fc = (fc && cpacf_test_func(&km_functions, fc)) ? fc : 0;
407         if (!xts_ctx->fc)
408                 return 0;
409
410         /* Split the XTS key into the two subkeys */
411         key_len = key_len / 2;
412         xts_ctx->key_len = key_len;
413         memcpy(xts_ctx->key, in_key, key_len);
414         memcpy(xts_ctx->pcc_key, in_key + key_len, key_len);
415         return 0;
416 }
417
418 static int xts_aes_crypt(struct skcipher_request *req, unsigned long modifier)
419 {
420         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
421         struct s390_xts_ctx *xts_ctx = crypto_skcipher_ctx(tfm);
422         struct skcipher_walk walk;
423         unsigned int offset, nbytes, n;
424         int ret;
425         struct {
426                 u8 key[32];
427                 u8 tweak[16];
428                 u8 block[16];
429                 u8 bit[16];
430                 u8 xts[16];
431         } pcc_param;
432         struct {
433                 u8 key[32];
434                 u8 init[16];
435         } xts_param;
436
437         if (req->cryptlen < AES_BLOCK_SIZE)
438                 return -EINVAL;
439
440         if (unlikely(!xts_ctx->fc || (req->cryptlen % AES_BLOCK_SIZE) != 0)) {
441                 struct skcipher_request *subreq = skcipher_request_ctx(req);
442
443                 *subreq = *req;
444                 skcipher_request_set_tfm(subreq, xts_ctx->fallback);
445                 return (modifier & CPACF_DECRYPT) ?
446                         crypto_skcipher_decrypt(subreq) :
447                         crypto_skcipher_encrypt(subreq);
448         }
449
450         ret = skcipher_walk_virt(&walk, req, false);
451         if (ret)
452                 return ret;
453         offset = xts_ctx->key_len & 0x10;
454         memset(pcc_param.block, 0, sizeof(pcc_param.block));
455         memset(pcc_param.bit, 0, sizeof(pcc_param.bit));
456         memset(pcc_param.xts, 0, sizeof(pcc_param.xts));
457         memcpy(pcc_param.tweak, walk.iv, sizeof(pcc_param.tweak));
458         memcpy(pcc_param.key + offset, xts_ctx->pcc_key, xts_ctx->key_len);
459         cpacf_pcc(xts_ctx->fc, pcc_param.key + offset);
460
461         memcpy(xts_param.key + offset, xts_ctx->key, xts_ctx->key_len);
462         memcpy(xts_param.init, pcc_param.xts, 16);
463
464         while ((nbytes = walk.nbytes) != 0) {
465                 /* only use complete blocks */
466                 n = nbytes & ~(AES_BLOCK_SIZE - 1);
467                 cpacf_km(xts_ctx->fc | modifier, xts_param.key + offset,
468                          walk.dst.virt.addr, walk.src.virt.addr, n);
469                 ret = skcipher_walk_done(&walk, nbytes - n);
470         }
471         memzero_explicit(&pcc_param, sizeof(pcc_param));
472         memzero_explicit(&xts_param, sizeof(xts_param));
473         return ret;
474 }
475
476 static int xts_aes_encrypt(struct skcipher_request *req)
477 {
478         return xts_aes_crypt(req, 0);
479 }
480
481 static int xts_aes_decrypt(struct skcipher_request *req)
482 {
483         return xts_aes_crypt(req, CPACF_DECRYPT);
484 }
485
486 static int xts_fallback_init(struct crypto_skcipher *tfm)
487 {
488         const char *name = crypto_tfm_alg_name(&tfm->base);
489         struct s390_xts_ctx *xts_ctx = crypto_skcipher_ctx(tfm);
490
491         xts_ctx->fallback = crypto_alloc_skcipher(name, 0,
492                                 CRYPTO_ALG_NEED_FALLBACK | CRYPTO_ALG_ASYNC);
493
494         if (IS_ERR(xts_ctx->fallback)) {
495                 pr_err("Allocating XTS fallback algorithm %s failed\n",
496                        name);
497                 return PTR_ERR(xts_ctx->fallback);
498         }
499         crypto_skcipher_set_reqsize(tfm, sizeof(struct skcipher_request) +
500                                     crypto_skcipher_reqsize(xts_ctx->fallback));
501         return 0;
502 }
503
504 static void xts_fallback_exit(struct crypto_skcipher *tfm)
505 {
506         struct s390_xts_ctx *xts_ctx = crypto_skcipher_ctx(tfm);
507
508         crypto_free_skcipher(xts_ctx->fallback);
509 }
510
511 static struct skcipher_alg xts_aes_alg = {
512         .base.cra_name          =       "xts(aes)",
513         .base.cra_driver_name   =       "xts-aes-s390",
514         .base.cra_priority      =       402,    /* ecb-aes-s390 + 1 */
515         .base.cra_flags         =       CRYPTO_ALG_NEED_FALLBACK,
516         .base.cra_blocksize     =       AES_BLOCK_SIZE,
517         .base.cra_ctxsize       =       sizeof(struct s390_xts_ctx),
518         .base.cra_module        =       THIS_MODULE,
519         .init                   =       xts_fallback_init,
520         .exit                   =       xts_fallback_exit,
521         .min_keysize            =       2 * AES_MIN_KEY_SIZE,
522         .max_keysize            =       2 * AES_MAX_KEY_SIZE,
523         .ivsize                 =       AES_BLOCK_SIZE,
524         .setkey                 =       xts_aes_set_key,
525         .encrypt                =       xts_aes_encrypt,
526         .decrypt                =       xts_aes_decrypt,
527 };
528
529 static int ctr_aes_set_key(struct crypto_skcipher *tfm, const u8 *in_key,
530                            unsigned int key_len)
531 {
532         struct s390_aes_ctx *sctx = crypto_skcipher_ctx(tfm);
533         unsigned long fc;
534
535         /* Pick the correct function code based on the key length */
536         fc = (key_len == 16) ? CPACF_KMCTR_AES_128 :
537              (key_len == 24) ? CPACF_KMCTR_AES_192 :
538              (key_len == 32) ? CPACF_KMCTR_AES_256 : 0;
539
540         /* Check if the function code is available */
541         sctx->fc = (fc && cpacf_test_func(&kmctr_functions, fc)) ? fc : 0;
542         if (!sctx->fc)
543                 return setkey_fallback_skcipher(tfm, in_key, key_len);
544
545         sctx->key_len = key_len;
546         memcpy(sctx->key, in_key, key_len);
547         return 0;
548 }
549
550 static unsigned int __ctrblk_init(u8 *ctrptr, u8 *iv, unsigned int nbytes)
551 {
552         unsigned int i, n;
553
554         /* only use complete blocks, max. PAGE_SIZE */
555         memcpy(ctrptr, iv, AES_BLOCK_SIZE);
556         n = (nbytes > PAGE_SIZE) ? PAGE_SIZE : nbytes & ~(AES_BLOCK_SIZE - 1);
557         for (i = (n / AES_BLOCK_SIZE) - 1; i > 0; i--) {
558                 memcpy(ctrptr + AES_BLOCK_SIZE, ctrptr, AES_BLOCK_SIZE);
559                 crypto_inc(ctrptr + AES_BLOCK_SIZE, AES_BLOCK_SIZE);
560                 ctrptr += AES_BLOCK_SIZE;
561         }
562         return n;
563 }
564
565 static int ctr_aes_crypt(struct skcipher_request *req)
566 {
567         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
568         struct s390_aes_ctx *sctx = crypto_skcipher_ctx(tfm);
569         u8 buf[AES_BLOCK_SIZE], *ctrptr;
570         struct skcipher_walk walk;
571         unsigned int n, nbytes;
572         int ret, locked;
573
574         if (unlikely(!sctx->fc))
575                 return fallback_skcipher_crypt(sctx, req, 0);
576
577         locked = mutex_trylock(&ctrblk_lock);
578
579         ret = skcipher_walk_virt(&walk, req, false);
580         while ((nbytes = walk.nbytes) >= AES_BLOCK_SIZE) {
581                 n = AES_BLOCK_SIZE;
582
583                 if (nbytes >= 2*AES_BLOCK_SIZE && locked)
584                         n = __ctrblk_init(ctrblk, walk.iv, nbytes);
585                 ctrptr = (n > AES_BLOCK_SIZE) ? ctrblk : walk.iv;
586                 cpacf_kmctr(sctx->fc, sctx->key, walk.dst.virt.addr,
587                             walk.src.virt.addr, n, ctrptr);
588                 if (ctrptr == ctrblk)
589                         memcpy(walk.iv, ctrptr + n - AES_BLOCK_SIZE,
590                                AES_BLOCK_SIZE);
591                 crypto_inc(walk.iv, AES_BLOCK_SIZE);
592                 ret = skcipher_walk_done(&walk, nbytes - n);
593         }
594         if (locked)
595                 mutex_unlock(&ctrblk_lock);
596         /*
597          * final block may be < AES_BLOCK_SIZE, copy only nbytes
598          */
599         if (nbytes) {
600                 memset(buf, 0, AES_BLOCK_SIZE);
601                 memcpy(buf, walk.src.virt.addr, nbytes);
602                 cpacf_kmctr(sctx->fc, sctx->key, buf, buf,
603                             AES_BLOCK_SIZE, walk.iv);
604                 memcpy(walk.dst.virt.addr, buf, nbytes);
605                 crypto_inc(walk.iv, AES_BLOCK_SIZE);
606                 ret = skcipher_walk_done(&walk, 0);
607         }
608
609         return ret;
610 }
611
612 static struct skcipher_alg ctr_aes_alg = {
613         .base.cra_name          =       "ctr(aes)",
614         .base.cra_driver_name   =       "ctr-aes-s390",
615         .base.cra_priority      =       402,    /* ecb-aes-s390 + 1 */
616         .base.cra_flags         =       CRYPTO_ALG_NEED_FALLBACK,
617         .base.cra_blocksize     =       1,
618         .base.cra_ctxsize       =       sizeof(struct s390_aes_ctx),
619         .base.cra_module        =       THIS_MODULE,
620         .init                   =       fallback_init_skcipher,
621         .exit                   =       fallback_exit_skcipher,
622         .min_keysize            =       AES_MIN_KEY_SIZE,
623         .max_keysize            =       AES_MAX_KEY_SIZE,
624         .ivsize                 =       AES_BLOCK_SIZE,
625         .setkey                 =       ctr_aes_set_key,
626         .encrypt                =       ctr_aes_crypt,
627         .decrypt                =       ctr_aes_crypt,
628         .chunksize              =       AES_BLOCK_SIZE,
629 };
630
631 static int gcm_aes_setkey(struct crypto_aead *tfm, const u8 *key,
632                           unsigned int keylen)
633 {
634         struct s390_aes_ctx *ctx = crypto_aead_ctx(tfm);
635
636         switch (keylen) {
637         case AES_KEYSIZE_128:
638                 ctx->fc = CPACF_KMA_GCM_AES_128;
639                 break;
640         case AES_KEYSIZE_192:
641                 ctx->fc = CPACF_KMA_GCM_AES_192;
642                 break;
643         case AES_KEYSIZE_256:
644                 ctx->fc = CPACF_KMA_GCM_AES_256;
645                 break;
646         default:
647                 return -EINVAL;
648         }
649
650         memcpy(ctx->key, key, keylen);
651         ctx->key_len = keylen;
652         return 0;
653 }
654
/*
 * setauthsize for the GCM AEAD.  Accepted tag lengths are 4, 8 and
 * 12..16 bytes; anything else yields -EINVAL.  @tfm is not used.
 */
static int gcm_aes_setauthsize(struct crypto_aead *tfm, unsigned int authsize)
{
	if (authsize == 4 || authsize == 8)
		return 0;
	if (authsize >= 12 && authsize <= 16)
		return 0;
	return -EINVAL;
}
672
673 static void gcm_walk_start(struct gcm_sg_walk *gw, struct scatterlist *sg,
674                            unsigned int len)
675 {
676         memset(gw, 0, sizeof(*gw));
677         gw->walk_bytes_remain = len;
678         scatterwalk_start(&gw->walk, sg);
679 }
680
681 static inline unsigned int _gcm_sg_clamp_and_map(struct gcm_sg_walk *gw)
682 {
683         struct scatterlist *nextsg;
684
685         gw->walk_bytes = scatterwalk_clamp(&gw->walk, gw->walk_bytes_remain);
686         while (!gw->walk_bytes) {
687                 nextsg = sg_next(gw->walk.sg);
688                 if (!nextsg)
689                         return 0;
690                 scatterwalk_start(&gw->walk, nextsg);
691                 gw->walk_bytes = scatterwalk_clamp(&gw->walk,
692                                                    gw->walk_bytes_remain);
693         }
694         gw->walk_ptr = scatterwalk_map(&gw->walk);
695         return gw->walk_bytes;
696 }
697
698 static inline void _gcm_sg_unmap_and_advance(struct gcm_sg_walk *gw,
699                                              unsigned int nbytes)
700 {
701         gw->walk_bytes_remain -= nbytes;
702         scatterwalk_unmap(gw->walk_ptr);
703         scatterwalk_advance(&gw->walk, nbytes);
704         scatterwalk_done(&gw->walk, 0, gw->walk_bytes_remain);
705         gw->walk_ptr = NULL;
706 }
707
708 static int gcm_in_walk_go(struct gcm_sg_walk *gw, unsigned int minbytesneeded)
709 {
710         int n;
711
712         if (gw->buf_bytes && gw->buf_bytes >= minbytesneeded) {
713                 gw->ptr = gw->buf;
714                 gw->nbytes = gw->buf_bytes;
715                 goto out;
716         }
717
718         if (gw->walk_bytes_remain == 0) {
719                 gw->ptr = NULL;
720                 gw->nbytes = 0;
721                 goto out;
722         }
723
724         if (!_gcm_sg_clamp_and_map(gw)) {
725                 gw->ptr = NULL;
726                 gw->nbytes = 0;
727                 goto out;
728         }
729
730         if (!gw->buf_bytes && gw->walk_bytes >= minbytesneeded) {
731                 gw->ptr = gw->walk_ptr;
732                 gw->nbytes = gw->walk_bytes;
733                 goto out;
734         }
735
736         while (1) {
737                 n = min(gw->walk_bytes, AES_BLOCK_SIZE - gw->buf_bytes);
738                 memcpy(gw->buf + gw->buf_bytes, gw->walk_ptr, n);
739                 gw->buf_bytes += n;
740                 _gcm_sg_unmap_and_advance(gw, n);
741                 if (gw->buf_bytes >= minbytesneeded) {
742                         gw->ptr = gw->buf;
743                         gw->nbytes = gw->buf_bytes;
744                         goto out;
745                 }
746                 if (!_gcm_sg_clamp_and_map(gw)) {
747                         gw->ptr = NULL;
748                         gw->nbytes = 0;
749                         goto out;
750                 }
751         }
752
753 out:
754         return gw->nbytes;
755 }
756
757 static int gcm_out_walk_go(struct gcm_sg_walk *gw, unsigned int minbytesneeded)
758 {
759         if (gw->walk_bytes_remain == 0) {
760                 gw->ptr = NULL;
761                 gw->nbytes = 0;
762                 goto out;
763         }
764
765         if (!_gcm_sg_clamp_and_map(gw)) {
766                 gw->ptr = NULL;
767                 gw->nbytes = 0;
768                 goto out;
769         }
770
771         if (gw->walk_bytes >= minbytesneeded) {
772                 gw->ptr = gw->walk_ptr;
773                 gw->nbytes = gw->walk_bytes;
774                 goto out;
775         }
776
777         scatterwalk_unmap(gw->walk_ptr);
778         gw->walk_ptr = NULL;
779
780         gw->ptr = gw->buf;
781         gw->nbytes = sizeof(gw->buf);
782
783 out:
784         return gw->nbytes;
785 }
786
787 static int gcm_in_walk_done(struct gcm_sg_walk *gw, unsigned int bytesdone)
788 {
789         if (gw->ptr == NULL)
790                 return 0;
791
792         if (gw->ptr == gw->buf) {
793                 int n = gw->buf_bytes - bytesdone;
794                 if (n > 0) {
795                         memmove(gw->buf, gw->buf + bytesdone, n);
796                         gw->buf_bytes = n;
797                 } else
798                         gw->buf_bytes = 0;
799         } else
800                 _gcm_sg_unmap_and_advance(gw, bytesdone);
801
802         return bytesdone;
803 }
804
805 static int gcm_out_walk_done(struct gcm_sg_walk *gw, unsigned int bytesdone)
806 {
807         int i, n;
808
809         if (gw->ptr == NULL)
810                 return 0;
811
812         if (gw->ptr == gw->buf) {
813                 for (i = 0; i < bytesdone; i += n) {
814                         if (!_gcm_sg_clamp_and_map(gw))
815                                 return i;
816                         n = min(gw->walk_bytes, bytesdone - i);
817                         memcpy(gw->walk_ptr, gw->buf + i, n);
818                         _gcm_sg_unmap_and_advance(gw, n);
819                 }
820         } else
821                 _gcm_sg_unmap_and_advance(gw, bytesdone);
822
823         return bytesdone;
824 }
825
826 static int gcm_aes_crypt(struct aead_request *req, unsigned int flags)
827 {
828         struct crypto_aead *tfm = crypto_aead_reqtfm(req);
829         struct s390_aes_ctx *ctx = crypto_aead_ctx(tfm);
830         unsigned int ivsize = crypto_aead_ivsize(tfm);
831         unsigned int taglen = crypto_aead_authsize(tfm);
832         unsigned int aadlen = req->assoclen;
833         unsigned int pclen = req->cryptlen;
834         int ret = 0;
835
836         unsigned int n, len, in_bytes, out_bytes,
837                      min_bytes, bytes, aad_bytes, pc_bytes;
838         struct gcm_sg_walk gw_in, gw_out;
839         u8 tag[GHASH_DIGEST_SIZE];
840
841         struct {
842                 u32 _[3];               /* reserved */
843                 u32 cv;                 /* Counter Value */
844                 u8 t[GHASH_DIGEST_SIZE];/* Tag */
845                 u8 h[AES_BLOCK_SIZE];   /* Hash-subkey */
846                 u64 taadl;              /* Total AAD Length */
847                 u64 tpcl;               /* Total Plain-/Cipher-text Length */
848                 u8 j0[GHASH_BLOCK_SIZE];/* initial counter value */
849                 u8 k[AES_MAX_KEY_SIZE]; /* Key */
850         } param;
851
852         /*
853          * encrypt
854          *   req->src: aad||plaintext
855          *   req->dst: aad||ciphertext||tag
856          * decrypt
857          *   req->src: aad||ciphertext||tag
858          *   req->dst: aad||plaintext, return 0 or -EBADMSG
859          * aad, plaintext and ciphertext may be empty.
860          */
861         if (flags & CPACF_DECRYPT)
862                 pclen -= taglen;
863         len = aadlen + pclen;
864
865         memset(&param, 0, sizeof(param));
866         param.cv = 1;
867         param.taadl = aadlen * 8;
868         param.tpcl = pclen * 8;
869         memcpy(param.j0, req->iv, ivsize);
870         *(u32 *)(param.j0 + ivsize) = 1;
871         memcpy(param.k, ctx->key, ctx->key_len);
872
873         gcm_walk_start(&gw_in, req->src, len);
874         gcm_walk_start(&gw_out, req->dst, len);
875
876         do {
877                 min_bytes = min_t(unsigned int,
878                                   aadlen > 0 ? aadlen : pclen, AES_BLOCK_SIZE);
879                 in_bytes = gcm_in_walk_go(&gw_in, min_bytes);
880                 out_bytes = gcm_out_walk_go(&gw_out, min_bytes);
881                 bytes = min(in_bytes, out_bytes);
882
883                 if (aadlen + pclen <= bytes) {
884                         aad_bytes = aadlen;
885                         pc_bytes = pclen;
886                         flags |= CPACF_KMA_LAAD | CPACF_KMA_LPC;
887                 } else {
888                         if (aadlen <= bytes) {
889                                 aad_bytes = aadlen;
890                                 pc_bytes = (bytes - aadlen) &
891                                            ~(AES_BLOCK_SIZE - 1);
892                                 flags |= CPACF_KMA_LAAD;
893                         } else {
894                                 aad_bytes = bytes & ~(AES_BLOCK_SIZE - 1);
895                                 pc_bytes = 0;
896                         }
897                 }
898
899                 if (aad_bytes > 0)
900                         memcpy(gw_out.ptr, gw_in.ptr, aad_bytes);
901
902                 cpacf_kma(ctx->fc | flags, &param,
903                           gw_out.ptr + aad_bytes,
904                           gw_in.ptr + aad_bytes, pc_bytes,
905                           gw_in.ptr, aad_bytes);
906
907                 n = aad_bytes + pc_bytes;
908                 if (gcm_in_walk_done(&gw_in, n) != n)
909                         return -ENOMEM;
910                 if (gcm_out_walk_done(&gw_out, n) != n)
911                         return -ENOMEM;
912                 aadlen -= aad_bytes;
913                 pclen -= pc_bytes;
914         } while (aadlen + pclen > 0);
915
916         if (flags & CPACF_DECRYPT) {
917                 scatterwalk_map_and_copy(tag, req->src, len, taglen, 0);
918                 if (crypto_memneq(tag, param.t, taglen))
919                         ret = -EBADMSG;
920         } else
921                 scatterwalk_map_and_copy(param.t, req->dst, len, taglen, 1);
922
923         memzero_explicit(&param, sizeof(param));
924         return ret;
925 }
926
927 static int gcm_aes_encrypt(struct aead_request *req)
928 {
929         return gcm_aes_crypt(req, CPACF_ENCRYPT);
930 }
931
932 static int gcm_aes_decrypt(struct aead_request *req)
933 {
934         return gcm_aes_crypt(req, CPACF_DECRYPT);
935 }
936
937 static struct aead_alg gcm_aes_aead = {
938         .setkey                 = gcm_aes_setkey,
939         .setauthsize            = gcm_aes_setauthsize,
940         .encrypt                = gcm_aes_encrypt,
941         .decrypt                = gcm_aes_decrypt,
942
943         .ivsize                 = GHASH_BLOCK_SIZE - sizeof(u32),
944         .maxauthsize            = GHASH_DIGEST_SIZE,
945         .chunksize              = AES_BLOCK_SIZE,
946
947         .base                   = {
948                 .cra_blocksize          = 1,
949                 .cra_ctxsize            = sizeof(struct s390_aes_ctx),
950                 .cra_priority           = 900,
951                 .cra_name               = "gcm(aes)",
952                 .cra_driver_name        = "gcm-aes-s390",
953                 .cra_module             = THIS_MODULE,
954         },
955 };
956
/*
 * Registration bookkeeping.  aes_s390_init() fills these in as each
 * algorithm registers successfully; aes_s390_fini() reads them to undo
 * exactly what was registered.
 */
static struct crypto_alg *aes_s390_alg;		/* plain cipher, if registered */
static struct aead_alg *aes_s390_aead_alg;	/* GCM AEAD, if registered */

/*
 * skciphers in registration order; the first aes_s390_skciphers_num
 * entries are valid.  Sized for the four skciphers aes_s390_init() may
 * register (ECB, CBC, XTS, CTR).
 */
static int aes_s390_skciphers_num;
static struct skcipher_alg *aes_s390_skcipher_algs[4];
961
962 static int aes_s390_register_skcipher(struct skcipher_alg *alg)
963 {
964         int ret;
965
966         ret = crypto_register_skcipher(alg);
967         if (!ret)
968                 aes_s390_skcipher_algs[aes_s390_skciphers_num++] = alg;
969         return ret;
970 }
971
972 static void aes_s390_fini(void)
973 {
974         if (aes_s390_alg)
975                 crypto_unregister_alg(aes_s390_alg);
976         while (aes_s390_skciphers_num--)
977                 crypto_unregister_skcipher(aes_s390_skcipher_algs[aes_s390_skciphers_num]);
978         if (ctrblk)
979                 free_page((unsigned long) ctrblk);
980
981         if (aes_s390_aead_alg)
982                 crypto_unregister_aead(aes_s390_aead_alg);
983 }
984
985 static int __init aes_s390_init(void)
986 {
987         int ret;
988
989         /* Query available functions for KM, KMC, KMCTR and KMA */
990         cpacf_query(CPACF_KM, &km_functions);
991         cpacf_query(CPACF_KMC, &kmc_functions);
992         cpacf_query(CPACF_KMCTR, &kmctr_functions);
993         cpacf_query(CPACF_KMA, &kma_functions);
994
995         if (cpacf_test_func(&km_functions, CPACF_KM_AES_128) ||
996             cpacf_test_func(&km_functions, CPACF_KM_AES_192) ||
997             cpacf_test_func(&km_functions, CPACF_KM_AES_256)) {
998                 ret = crypto_register_alg(&aes_alg);
999                 if (ret)
1000                         goto out_err;
1001                 aes_s390_alg = &aes_alg;
1002                 ret = aes_s390_register_skcipher(&ecb_aes_alg);
1003                 if (ret)
1004                         goto out_err;
1005         }
1006
1007         if (cpacf_test_func(&kmc_functions, CPACF_KMC_AES_128) ||
1008             cpacf_test_func(&kmc_functions, CPACF_KMC_AES_192) ||
1009             cpacf_test_func(&kmc_functions, CPACF_KMC_AES_256)) {
1010                 ret = aes_s390_register_skcipher(&cbc_aes_alg);
1011                 if (ret)
1012                         goto out_err;
1013         }
1014
1015         if (cpacf_test_func(&km_functions, CPACF_KM_XTS_128) ||
1016             cpacf_test_func(&km_functions, CPACF_KM_XTS_256)) {
1017                 ret = aes_s390_register_skcipher(&xts_aes_alg);
1018                 if (ret)
1019                         goto out_err;
1020         }
1021
1022         if (cpacf_test_func(&kmctr_functions, CPACF_KMCTR_AES_128) ||
1023             cpacf_test_func(&kmctr_functions, CPACF_KMCTR_AES_192) ||
1024             cpacf_test_func(&kmctr_functions, CPACF_KMCTR_AES_256)) {
1025                 ctrblk = (u8 *) __get_free_page(GFP_KERNEL);
1026                 if (!ctrblk) {
1027                         ret = -ENOMEM;
1028                         goto out_err;
1029                 }
1030                 ret = aes_s390_register_skcipher(&ctr_aes_alg);
1031                 if (ret)
1032                         goto out_err;
1033         }
1034
1035         if (cpacf_test_func(&kma_functions, CPACF_KMA_GCM_AES_128) ||
1036             cpacf_test_func(&kma_functions, CPACF_KMA_GCM_AES_192) ||
1037             cpacf_test_func(&kma_functions, CPACF_KMA_GCM_AES_256)) {
1038                 ret = crypto_register_aead(&gcm_aes_aead);
1039                 if (ret)
1040                         goto out_err;
1041                 aes_s390_aead_alg = &gcm_aes_aead;
1042         }
1043
1044         return 0;
1045 out_err:
1046         aes_s390_fini();
1047         return ret;
1048 }
1049
/*
 * Bind module init to the S390_CPU_FEATURE_MSA CPU feature (presumably the
 * Message-Security Assist facility that provides the CPACF instructions used
 * here; confirm).  module_exit() reuses the same teardown as the init error
 * path.
 */
module_cpu_feature_match(S390_CPU_FEATURE_MSA, aes_s390_init);
module_exit(aes_s390_fini);

/* NOTE(review): alias presumably lets generic "aes-all" requests load this module; confirm. */
MODULE_ALIAS_CRYPTO("aes-all");

MODULE_DESCRIPTION("Rijndael (AES) Cipher Algorithm");
MODULE_LICENSE("GPL");
/* Required to use symbols exported in the CRYPTO_INTERNAL namespace. */
MODULE_IMPORT_NS(CRYPTO_INTERNAL);