Merge tag 'nfsd-5.3' of git://linux-nfs.org/~bfields/linux
[sfrench/cifs-2.6.git] / include / crypto / cbc.h
1 /* SPDX-License-Identifier: GPL-2.0-or-later */
2 /*
3  * CBC: Cipher Block Chaining mode
4  *
5  * Copyright (c) 2016 Herbert Xu <herbert@gondor.apana.org.au>
6  */
7
8 #ifndef _CRYPTO_CBC_H
9 #define _CRYPTO_CBC_H
10
11 #include <crypto/internal/skcipher.h>
12 #include <linux/string.h>
13 #include <linux/types.h>
14
15 static inline int crypto_cbc_encrypt_segment(
16         struct skcipher_walk *walk, struct crypto_skcipher *tfm,
17         void (*fn)(struct crypto_skcipher *, const u8 *, u8 *))
18 {
19         unsigned int bsize = crypto_skcipher_blocksize(tfm);
20         unsigned int nbytes = walk->nbytes;
21         u8 *src = walk->src.virt.addr;
22         u8 *dst = walk->dst.virt.addr;
23         u8 *iv = walk->iv;
24
25         do {
26                 crypto_xor(iv, src, bsize);
27                 fn(tfm, iv, dst);
28                 memcpy(iv, dst, bsize);
29
30                 src += bsize;
31                 dst += bsize;
32         } while ((nbytes -= bsize) >= bsize);
33
34         return nbytes;
35 }
36
37 static inline int crypto_cbc_encrypt_inplace(
38         struct skcipher_walk *walk, struct crypto_skcipher *tfm,
39         void (*fn)(struct crypto_skcipher *, const u8 *, u8 *))
40 {
41         unsigned int bsize = crypto_skcipher_blocksize(tfm);
42         unsigned int nbytes = walk->nbytes;
43         u8 *src = walk->src.virt.addr;
44         u8 *iv = walk->iv;
45
46         do {
47                 crypto_xor(src, iv, bsize);
48                 fn(tfm, src, src);
49                 iv = src;
50
51                 src += bsize;
52         } while ((nbytes -= bsize) >= bsize);
53
54         memcpy(walk->iv, iv, bsize);
55
56         return nbytes;
57 }
58
59 static inline int crypto_cbc_encrypt_walk(struct skcipher_request *req,
60                                           void (*fn)(struct crypto_skcipher *,
61                                                      const u8 *, u8 *))
62 {
63         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
64         struct skcipher_walk walk;
65         int err;
66
67         err = skcipher_walk_virt(&walk, req, false);
68
69         while (walk.nbytes) {
70                 if (walk.src.virt.addr == walk.dst.virt.addr)
71                         err = crypto_cbc_encrypt_inplace(&walk, tfm, fn);
72                 else
73                         err = crypto_cbc_encrypt_segment(&walk, tfm, fn);
74                 err = skcipher_walk_done(&walk, err);
75         }
76
77         return err;
78 }
79
80 static inline int crypto_cbc_decrypt_segment(
81         struct skcipher_walk *walk, struct crypto_skcipher *tfm,
82         void (*fn)(struct crypto_skcipher *, const u8 *, u8 *))
83 {
84         unsigned int bsize = crypto_skcipher_blocksize(tfm);
85         unsigned int nbytes = walk->nbytes;
86         u8 *src = walk->src.virt.addr;
87         u8 *dst = walk->dst.virt.addr;
88         u8 *iv = walk->iv;
89
90         do {
91                 fn(tfm, src, dst);
92                 crypto_xor(dst, iv, bsize);
93                 iv = src;
94
95                 src += bsize;
96                 dst += bsize;
97         } while ((nbytes -= bsize) >= bsize);
98
99         memcpy(walk->iv, iv, bsize);
100
101         return nbytes;
102 }
103
104 static inline int crypto_cbc_decrypt_inplace(
105         struct skcipher_walk *walk, struct crypto_skcipher *tfm,
106         void (*fn)(struct crypto_skcipher *, const u8 *, u8 *))
107 {
108         unsigned int bsize = crypto_skcipher_blocksize(tfm);
109         unsigned int nbytes = walk->nbytes;
110         u8 *src = walk->src.virt.addr;
111         u8 last_iv[MAX_CIPHER_BLOCKSIZE];
112
113         /* Start of the last block. */
114         src += nbytes - (nbytes & (bsize - 1)) - bsize;
115         memcpy(last_iv, src, bsize);
116
117         for (;;) {
118                 fn(tfm, src, src);
119                 if ((nbytes -= bsize) < bsize)
120                         break;
121                 crypto_xor(src, src - bsize, bsize);
122                 src -= bsize;
123         }
124
125         crypto_xor(src, walk->iv, bsize);
126         memcpy(walk->iv, last_iv, bsize);
127
128         return nbytes;
129 }
130
131 static inline int crypto_cbc_decrypt_blocks(
132         struct skcipher_walk *walk, struct crypto_skcipher *tfm,
133         void (*fn)(struct crypto_skcipher *, const u8 *, u8 *))
134 {
135         if (walk->src.virt.addr == walk->dst.virt.addr)
136                 return crypto_cbc_decrypt_inplace(walk, tfm, fn);
137         else
138                 return crypto_cbc_decrypt_segment(walk, tfm, fn);
139 }
140
141 #endif  /* _CRYPTO_CBC_H */