Merge branch 'linus' of git://git.kernel.org/pub/scm/linux/kernel/git/herbert/crypto-2.6
[sfrench/cifs-2.6.git] / crypto / cbc.c
index e6f6273a7d3990589e2d6917a100e0204e7ad1ee..0d9509dff891d54439152c7378b3ac9983098104 100644 (file)
@@ -6,7 +6,6 @@
  */
 
 #include <crypto/algapi.h>
-#include <crypto/cbc.h>
 #include <crypto/internal/skcipher.h>
 #include <linux/err.h>
 #include <linux/init.h>
 #include <linux/log2.h>
 #include <linux/module.h>
 
-static inline void crypto_cbc_encrypt_one(struct crypto_skcipher *tfm,
-                                         const u8 *src, u8 *dst)
+static int crypto_cbc_encrypt_segment(struct skcipher_walk *walk,
+                                     struct crypto_skcipher *skcipher)
 {
-       crypto_cipher_encrypt_one(skcipher_cipher_simple(tfm), dst, src);
+       unsigned int bsize = crypto_skcipher_blocksize(skcipher);
+       void (*fn)(struct crypto_tfm *, u8 *, const u8 *);
+       unsigned int nbytes = walk->nbytes;
+       u8 *src = walk->src.virt.addr;
+       u8 *dst = walk->dst.virt.addr;
+       struct crypto_cipher *cipher;
+       struct crypto_tfm *tfm;
+       u8 *iv = walk->iv;
+
+       cipher = skcipher_cipher_simple(skcipher);
+       tfm = crypto_cipher_tfm(cipher);
+       fn = crypto_cipher_alg(cipher)->cia_encrypt;
+
+       do {
+               crypto_xor(iv, src, bsize);
+               fn(tfm, dst, iv);
+               memcpy(iv, dst, bsize);
+
+               src += bsize;
+               dst += bsize;
+       } while ((nbytes -= bsize) >= bsize);
+
+       return nbytes;
+}
+
+static int crypto_cbc_encrypt_inplace(struct skcipher_walk *walk,
+                                     struct crypto_skcipher *skcipher)
+{
+       unsigned int bsize = crypto_skcipher_blocksize(skcipher);
+       void (*fn)(struct crypto_tfm *, u8 *, const u8 *);
+       unsigned int nbytes = walk->nbytes;
+       u8 *src = walk->src.virt.addr;
+       struct crypto_cipher *cipher;
+       struct crypto_tfm *tfm;
+       u8 *iv = walk->iv;
+
+       cipher = skcipher_cipher_simple(skcipher);
+       tfm = crypto_cipher_tfm(cipher);
+       fn = crypto_cipher_alg(cipher)->cia_encrypt;
+
+       do {
+               crypto_xor(src, iv, bsize);
+               fn(tfm, src, src);
+               iv = src;
+
+               src += bsize;
+       } while ((nbytes -= bsize) >= bsize);
+
+       memcpy(walk->iv, iv, bsize);
+
+       return nbytes;
 }
 
 static int crypto_cbc_encrypt(struct skcipher_request *req)
 {
-       return crypto_cbc_encrypt_walk(req, crypto_cbc_encrypt_one);
+       struct crypto_skcipher *skcipher = crypto_skcipher_reqtfm(req);
+       struct skcipher_walk walk;
+       int err;
+
+       err = skcipher_walk_virt(&walk, req, false);
+
+       while (walk.nbytes) {
+               if (walk.src.virt.addr == walk.dst.virt.addr)
+                       err = crypto_cbc_encrypt_inplace(&walk, skcipher);
+               else
+                       err = crypto_cbc_encrypt_segment(&walk, skcipher);
+               err = skcipher_walk_done(&walk, err);
+       }
+
+       return err;
+}
+
+static int crypto_cbc_decrypt_segment(struct skcipher_walk *walk,
+                                     struct crypto_skcipher *skcipher)
+{
+       unsigned int bsize = crypto_skcipher_blocksize(skcipher);
+       void (*fn)(struct crypto_tfm *, u8 *, const u8 *);
+       unsigned int nbytes = walk->nbytes;
+       u8 *src = walk->src.virt.addr;
+       u8 *dst = walk->dst.virt.addr;
+       struct crypto_cipher *cipher;
+       struct crypto_tfm *tfm;
+       u8 *iv = walk->iv;
+
+       cipher = skcipher_cipher_simple(skcipher);
+       tfm = crypto_cipher_tfm(cipher);
+       fn = crypto_cipher_alg(cipher)->cia_decrypt;
+
+       do {
+               fn(tfm, dst, src);
+               crypto_xor(dst, iv, bsize);
+               iv = src;
+
+               src += bsize;
+               dst += bsize;
+       } while ((nbytes -= bsize) >= bsize);
+
+       memcpy(walk->iv, iv, bsize);
+
+       return nbytes;
 }
 
-static inline void crypto_cbc_decrypt_one(struct crypto_skcipher *tfm,
-                                         const u8 *src, u8 *dst)
+static int crypto_cbc_decrypt_inplace(struct skcipher_walk *walk,
+                                     struct crypto_skcipher *skcipher)
 {
-       crypto_cipher_decrypt_one(skcipher_cipher_simple(tfm), dst, src);
+       unsigned int bsize = crypto_skcipher_blocksize(skcipher);
+       void (*fn)(struct crypto_tfm *, u8 *, const u8 *);
+       unsigned int nbytes = walk->nbytes;
+       u8 *src = walk->src.virt.addr;
+       u8 last_iv[MAX_CIPHER_BLOCKSIZE];
+       struct crypto_cipher *cipher;
+       struct crypto_tfm *tfm;
+
+       cipher = skcipher_cipher_simple(skcipher);
+       tfm = crypto_cipher_tfm(cipher);
+       fn = crypto_cipher_alg(cipher)->cia_decrypt;
+
+       /* Start of the last block. */
+       src += nbytes - (nbytes & (bsize - 1)) - bsize;
+       memcpy(last_iv, src, bsize);
+
+       for (;;) {
+               fn(tfm, src, src);
+               if ((nbytes -= bsize) < bsize)
+                       break;
+               crypto_xor(src, src - bsize, bsize);
+               src -= bsize;
+       }
+
+       crypto_xor(src, walk->iv, bsize);
+       memcpy(walk->iv, last_iv, bsize);
+
+       return nbytes;
 }
 
 static int crypto_cbc_decrypt(struct skcipher_request *req)
 {
-       struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+       struct crypto_skcipher *skcipher = crypto_skcipher_reqtfm(req);
        struct skcipher_walk walk;
        int err;
 
        err = skcipher_walk_virt(&walk, req, false);
 
        while (walk.nbytes) {
-               err = crypto_cbc_decrypt_blocks(&walk, tfm,
-                                               crypto_cbc_decrypt_one);
+               if (walk.src.virt.addr == walk.dst.virt.addr)
+                       err = crypto_cbc_decrypt_inplace(&walk, skcipher);
+               else
+                       err = crypto_cbc_decrypt_segment(&walk, skcipher);
                err = skcipher_walk_done(&walk, err);
        }