Merge tag 'media/v4.10-4' of git://git.kernel.org/pub/scm/linux/kernel/git/mchehab...
[sfrench/cifs-2.6.git] / include / crypto / cbc.h
1 /*
2  * CBC: Cipher Block Chaining mode
3  *
4  * Copyright (c) 2016 Herbert Xu <herbert@gondor.apana.org.au>
5  *
6  * This program is free software; you can redistribute it and/or modify it
7  * under the terms of the GNU General Public License as published by the Free
8  * Software Foundation; either version 2 of the License, or (at your option)
9  * any later version.
10  *
11  */
12
13 #ifndef _CRYPTO_CBC_H
14 #define _CRYPTO_CBC_H
15
16 #include <crypto/internal/skcipher.h>
17 #include <linux/string.h>
18 #include <linux/types.h>
19
20 static inline int crypto_cbc_encrypt_segment(
21         struct skcipher_walk *walk, struct crypto_skcipher *tfm,
22         void (*fn)(struct crypto_skcipher *, const u8 *, u8 *))
23 {
24         unsigned int bsize = crypto_skcipher_blocksize(tfm);
25         unsigned int nbytes = walk->nbytes;
26         u8 *src = walk->src.virt.addr;
27         u8 *dst = walk->dst.virt.addr;
28         u8 *iv = walk->iv;
29
30         do {
31                 crypto_xor(iv, src, bsize);
32                 fn(tfm, iv, dst);
33                 memcpy(iv, dst, bsize);
34
35                 src += bsize;
36                 dst += bsize;
37         } while ((nbytes -= bsize) >= bsize);
38
39         return nbytes;
40 }
41
42 static inline int crypto_cbc_encrypt_inplace(
43         struct skcipher_walk *walk, struct crypto_skcipher *tfm,
44         void (*fn)(struct crypto_skcipher *, const u8 *, u8 *))
45 {
46         unsigned int bsize = crypto_skcipher_blocksize(tfm);
47         unsigned int nbytes = walk->nbytes;
48         u8 *src = walk->src.virt.addr;
49         u8 *iv = walk->iv;
50
51         do {
52                 crypto_xor(src, iv, bsize);
53                 fn(tfm, src, src);
54                 iv = src;
55
56                 src += bsize;
57         } while ((nbytes -= bsize) >= bsize);
58
59         memcpy(walk->iv, iv, bsize);
60
61         return nbytes;
62 }
63
64 static inline int crypto_cbc_encrypt_walk(struct skcipher_request *req,
65                                           void (*fn)(struct crypto_skcipher *,
66                                                      const u8 *, u8 *))
67 {
68         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
69         struct skcipher_walk walk;
70         int err;
71
72         err = skcipher_walk_virt(&walk, req, false);
73
74         while (walk.nbytes) {
75                 if (walk.src.virt.addr == walk.dst.virt.addr)
76                         err = crypto_cbc_encrypt_inplace(&walk, tfm, fn);
77                 else
78                         err = crypto_cbc_encrypt_segment(&walk, tfm, fn);
79                 err = skcipher_walk_done(&walk, err);
80         }
81
82         return err;
83 }
84
85 static inline int crypto_cbc_decrypt_segment(
86         struct skcipher_walk *walk, struct crypto_skcipher *tfm,
87         void (*fn)(struct crypto_skcipher *, const u8 *, u8 *))
88 {
89         unsigned int bsize = crypto_skcipher_blocksize(tfm);
90         unsigned int nbytes = walk->nbytes;
91         u8 *src = walk->src.virt.addr;
92         u8 *dst = walk->dst.virt.addr;
93         u8 *iv = walk->iv;
94
95         do {
96                 fn(tfm, src, dst);
97                 crypto_xor(dst, iv, bsize);
98                 iv = src;
99
100                 src += bsize;
101                 dst += bsize;
102         } while ((nbytes -= bsize) >= bsize);
103
104         memcpy(walk->iv, iv, bsize);
105
106         return nbytes;
107 }
108
109 static inline int crypto_cbc_decrypt_inplace(
110         struct skcipher_walk *walk, struct crypto_skcipher *tfm,
111         void (*fn)(struct crypto_skcipher *, const u8 *, u8 *))
112 {
113         unsigned int bsize = crypto_skcipher_blocksize(tfm);
114         unsigned int nbytes = walk->nbytes;
115         u8 *src = walk->src.virt.addr;
116         u8 last_iv[bsize];
117
118         /* Start of the last block. */
119         src += nbytes - (nbytes & (bsize - 1)) - bsize;
120         memcpy(last_iv, src, bsize);
121
122         for (;;) {
123                 fn(tfm, src, src);
124                 if ((nbytes -= bsize) < bsize)
125                         break;
126                 crypto_xor(src, src - bsize, bsize);
127                 src -= bsize;
128         }
129
130         crypto_xor(src, walk->iv, bsize);
131         memcpy(walk->iv, last_iv, bsize);
132
133         return nbytes;
134 }
135
136 static inline int crypto_cbc_decrypt_blocks(
137         struct skcipher_walk *walk, struct crypto_skcipher *tfm,
138         void (*fn)(struct crypto_skcipher *, const u8 *, u8 *))
139 {
140         if (walk->src.virt.addr == walk->dst.virt.addr)
141                 return crypto_cbc_decrypt_inplace(walk, tfm, fn);
142         else
143                 return crypto_cbc_decrypt_segment(walk, tfm, fn);
144 }
145
146 #endif  /* _CRYPTO_CBC_H */