libndr: Avoid assigning duplicate versions to symbols
[amitay/samba.git] / python / samba / tests / krb5 / kcrypto.py
1 #!/usr/bin/env python3
2 #
3 # Copyright (C) 2013 by the Massachusetts Institute of Technology.
4 # All rights reserved.
5 #
6 # Redistribution and use in source and binary forms, with or without
7 # modification, are permitted provided that the following conditions
8 # are met:
9 #
10 # * Redistributions of source code must retain the above copyright
11 #   notice, this list of conditions and the following disclaimer.
12 #
13 # * Redistributions in binary form must reproduce the above copyright
14 #   notice, this list of conditions and the following disclaimer in
15 #   the documentation and/or other materials provided with the
16 #   distribution.
17 #
18 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
21 # FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
22 # COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
23 # INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24 # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
26 # HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
27 # STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28 # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
29 # OF THE POSSIBILITY OF SUCH DAMAGE.
30
# XXX current status:
# * Done and tested
#   - AES128/AES256 encryption, checksum, string2key, prf, cf2
#   - DES3 and RC4 encryption, checksum, string2key, cf2
#   - cf2 is needed for FAST
#   - Unkeyed MD5 checksum (unkeyed SHA1 and CRC32 are also implemented)
# * Still to do:
#   - DES enctypes and cksumtypes
#   - RC4 exported enctype (if we need it for anything)
#   - Special RC4, raw DES/DES3 operations for GSSAPI
# * Difficult or low priority:
#   - Camellia enctypes (not implemented here)
#   - Cipher state only needed for kcmd suite
#   - Nonstandard enctypes and cksumtypes like des-hmac-sha1
44
45 import sys
46 import os
47
# Put "bin/python" (relative to the working directory; presumably the
# in-tree build output -- confirm) at the front of the module search path,
# and export PYTHONUNBUFFERED=1 in this process's environment.
sys.path[0:0] = ["bin/python"]
os.environ.update(PYTHONUNBUFFERED="1")
50
51 from math import gcd
52 from functools import reduce
53 from struct import pack, unpack
54 from binascii import crc32
55 from cryptography.hazmat.primitives import hashes
56 from cryptography.hazmat.primitives import hmac
57 from cryptography.hazmat.primitives.ciphers import algorithms as ciphers
58 from cryptography.hazmat.primitives.ciphers import modes
59 from cryptography.hazmat.primitives.ciphers.base import Cipher
60 from cryptography.hazmat.backends import default_backend
61 from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
62 from samba.tests import TestCase
63 from samba.credentials import Credentials
64 from samba import generate_random_bytes as get_random_bytes
65 from samba.common import get_string, get_bytes
66
class Enctype:
    """Kerberos encryption type numbers referenced by this module."""
    # Single-DES family.
    DES_CRC = 1
    DES_MD4 = 2
    DES_MD5 = 3
    # Triple DES.
    DES3 = 16
    # AES with ciphertext stealing.
    AES128 = 17
    AES256 = 18
    # RC4 (RFC 4757).
    RC4 = 23
75
76
class Cksumtype:
    """Kerberos checksum type numbers referenced by this module."""
    # Unkeyed / DES-based checksums.
    CRC32 = 1
    MD4 = 2
    MD4_DES = 3
    MD5 = 7
    MD5_DES = 8
    SHA1 = 9
    # Keyed checksums of the simplified-profile enctypes.
    SHA1_DES3 = 12
    SHA1_AES128 = 15
    SHA1_AES256 = 16
    # RC4 keyed checksum (negative, as in the original numbering).
    HMAC_MD5 = -138
88
89
class InvalidChecksum(ValueError):
    """Raised when a checksum or ciphertext integrity check fails."""
92
93
94 def _zeropad(s, padsize):
95     # Return s padded with 0 bytes to a multiple of padsize.
96     padlen = (padsize - (len(s) % padsize)) % padsize
97     return s + bytes(padlen)
98
99
100 def _xorbytes(b1, b2):
101     # xor two strings together and return the resulting string.
102     assert len(b1) == len(b2)
103     return bytes([x ^ y for x, y in zip(b1, b2)])
104
105
106 def _mac_equal(mac1, mac2):
107     # Constant-time comparison function.  (We can't use HMAC.verify
108     # since we use truncated macs.)
109     assert len(mac1) == len(mac2)
110     res = 0
111     for x, y in zip(mac1, mac2):
112         res |= x ^ y
113     return res == 0
114
def SIMPLE_HASH(string, algo_cls):
    """Return the digest of *string* under hash class *algo_cls*.

    *algo_cls* is a cryptography hash class (e.g. hashes.SHA1). It is
    instantiated here.
    """
    digest = hashes.Hash(algo_cls(), default_backend())
    digest.update(string)
    result = digest.finalize()
    return result
119
def HMAC_HASH(key, string, algo_cls):
    """Return HMAC(*key*, *string*) using hash class *algo_cls*.

    *algo_cls* is a cryptography hash class (e.g. hashes.MD5). It is
    instantiated here.
    """
    mac = hmac.HMAC(key, algo_cls(), default_backend())
    mac.update(string)
    result = mac.finalize()
    return result
124
125 def _nfold(str, nbytes):
126     # Convert str to a string of length nbytes using the RFC 3961 nfold
127     # operation.
128
129     # Rotate the bytes in str to the right by nbits bits.
130     def rotate_right(str, nbits):
131         nbytes, remain = (nbits//8) % len(str), nbits % 8
132         return bytes([(str[i-nbytes] >> remain) |
133                       (str[i-nbytes-1] << (8-remain) & 0xff)
134                       for i in range(len(str))])
135
136     # Add equal-length strings together with end-around carry.
137     def add_ones_complement(str1, str2):
138         n = len(str1)
139         v = [a + b for a, b in zip(str1, str2)]
140         # Propagate carry bits to the left until there aren't any left.
141         while any(x & ~0xff for x in v):
142             v = [(v[i-n+1]>>8) + (v[i]&0xff) for i in range(n)]
143         return bytes([x for x in v])
144
145     # Concatenate copies of str to produce the least common multiple
146     # of len(str) and nbytes, rotating each copy of str to the right
147     # by 13 bits times its list position.  Decompose the concatenation
148     # into slices of length nbytes, and add them together as
149     # big-endian ones' complement integers.
150     slen = len(str)
151     lcm = nbytes * slen // gcd(nbytes, slen)
152     bigstr = b''.join((rotate_right(str, 13 * i) for i in range(lcm // slen)))
153     slices = (bigstr[p:p+nbytes] for p in range(0, lcm, nbytes))
154     return reduce(add_ones_complement, slices)
155
156
157 def _is_weak_des_key(keybytes):
158     return keybytes in (b'\x01\x01\x01\x01\x01\x01\x01\x01',
159                         b'\xFE\xFE\xFE\xFE\xFE\xFE\xFE\xFE',
160                         b'\x1F\x1F\x1F\x1F\x0E\x0E\x0E\x0E',
161                         b'\xE0\xE0\xE0\xE0\xF1\xF1\xF1\xF1',
162                         b'\x01\xFE\x01\xFE\x01\xFE\x01\xFE',
163                         b'\xFE\x01\xFE\x01\xFE\x01\xFE\x01',
164                         b'\x1F\xE0\x1F\xE0\x0E\xF1\x0E\xF1',
165                         b'\xE0\x1F\xE0\x1F\xF1\x0E\xF1\x0E',
166                         b'\x01\xE0\x01\xE0\x01\xF1\x01\xF1',
167                         b'\xE0\x01\xE0\x01\xF1\x01\xF1\x01',
168                         b'\x1F\xFE\x1F\xFE\x0E\xFE\x0E\xFE',
169                         b'\xFE\x1F\xFE\x1F\xFE\x0E\xFE\x0E',
170                         b'\x01\x1F\x01\x1F\x01\x0E\x01\x0E',
171                         b'\x1F\x01\x1F\x01\x0E\x01\x0E\x01',
172                         b'\xE0\xFE\xE0\xFE\xF1\xFE\xF1\xFE',
173                         b'\xFE\xE0\xFE\xE0\xFE\xF1\xFE\xF1')
174
175
class _EnctypeProfile(object):
    # Common base for enctype profiles. A usable subclass defines:
    #   * enctype: the enctype number
    #   * keysize: protocol key size in bytes
    #   * seedsize: random_to_key input size in bytes
    #   * random_to_key, only when the keyspace is not dense
    #   * string_to_key, encrypt, decrypt and prf

    @classmethod
    def random_to_key(cls, seed):
        # Dense keyspace: the seed bytes are used directly as the key.
        if len(seed) == cls.seedsize:
            return Key(cls.enctype, seed)
        raise ValueError('Wrong seed length')
192
193
class _SimplifiedEnctype(_EnctypeProfile):
    # Base class for enctypes using the RFC 3961 simplified profile.
    # Defines the derive, encrypt, decrypt, and prf methods. Subclasses
    # must define:
    #   * blocksize: Underlying cipher block size in bytes
    #   * padsize: Underlying cipher padding multiple (1 or blocksize)
    #   * macsize: Size of integrity MAC in bytes
    #   * hashalgo: cryptography hash class for the underlying hash
    #     function (e.g. hashes.SHA1)
    #   * basic_encrypt, basic_decrypt: Underlying CBC/CTS cipher
    # in addition to the _EnctypeProfile requirements.

    @classmethod
    def derive(cls, key, constant):
        """Derive a key from *key* and the byte string *constant*.

        The n-folded constant is encrypted repeatedly with basic_encrypt,
        each output feeding the next, until seedsize bytes exist. The first
        seedsize bytes go through random_to_key.
        """
        # RFC 3961 only says to n-fold the constant only if it is
        # shorter than the cipher block size. But all Unix
        # implementations n-fold constants if their length is larger
        # than the block size as well, and n-folding when the length
        # is equal to the block size is a no-op.
        plaintext = _nfold(constant, cls.blocksize)
        rndseed = b''
        while len(rndseed) < cls.seedsize:
            ciphertext = cls.basic_encrypt(key, plaintext)
            rndseed += ciphertext
            plaintext = ciphertext
        return cls.random_to_key(rndseed[0:cls.seedsize])

    @classmethod
    def encrypt(cls, key, keyusage, plaintext, confounder):
        """Encrypt *plaintext* for *keyusage*.

        Returns basic_encrypt(Ke, confounder | padded plaintext) followed by
        the first macsize bytes of HMAC(Ki, confounder | padded plaintext).
        A random blocksize-byte confounder is generated when *confounder*
        is None.
        """
        ki = cls.derive(key, pack('>iB', keyusage, 0x55))
        ke = cls.derive(key, pack('>iB', keyusage, 0xAA))
        if confounder is None:
            confounder = get_random_bytes(cls.blocksize)
        basic_plaintext = confounder + _zeropad(plaintext, cls.padsize)
        # Named "mac" so the module-level cryptography "hmac" is not shadowed.
        mac = HMAC_HASH(ki.contents, basic_plaintext, cls.hashalgo)
        return cls.basic_encrypt(ke, basic_plaintext) + mac[:cls.macsize]

    @classmethod
    def decrypt(cls, key, keyusage, ciphertext):
        """Decrypt *ciphertext* produced by encrypt for *keyusage*.

        Returns the (still padded) plaintext without the confounder.
        Raises ValueError for a too-short or badly padded ciphertext, and
        InvalidChecksum when the integrity MAC does not match.
        """
        ki = cls.derive(key, pack('>iB', keyusage, 0x55))
        ke = cls.derive(key, pack('>iB', keyusage, 0xAA))
        if len(ciphertext) < cls.blocksize + cls.macsize:
            raise ValueError('ciphertext too short')
        basic_ctext, mac = ciphertext[:-cls.macsize], ciphertext[-cls.macsize:]
        if len(basic_ctext) % cls.padsize != 0:
            raise ValueError('ciphertext does not meet padding requirement')
        basic_plaintext = cls.basic_decrypt(ke, basic_ctext)
        full_mac = HMAC_HASH(ki.contents, basic_plaintext, cls.hashalgo)
        expmac = full_mac[:cls.macsize]
        if not _mac_equal(mac, expmac):
            raise InvalidChecksum('ciphertext integrity failure')
        # Discard the confounder.
        return basic_plaintext[cls.blocksize:]

    @classmethod
    def prf(cls, key, string):
        """Return the RFC 3961 simplified-profile PRF of *string*."""
        # Hash the input. RFC 3961 says to truncate to the padding
        # size, but implementations truncate to the block size.
        hashval = SIMPLE_HASH(string, cls.hashalgo)
        # Keep the largest whole number of blocks. The old slice
        # hashval[:-(len % blocksize)] became hashval[:0] (empty) whenever
        # the hash length was already a multiple of the block size.
        usable = len(hashval) - len(hashval) % cls.blocksize
        truncated = hashval[:usable]
        # Encrypt the hash with a derived key.
        kp = cls.derive(key, b'prf')
        return cls.basic_encrypt(kp, truncated)
255
256
class _DES3CBC(_SimplifiedEnctype):
    # Triple DES in CBC mode (zero IV) with an HMAC-SHA1 integrity MAC,
    # using the RFC 3961 simplified profile.
    enctype = Enctype.DES3
    keysize = 24
    seedsize = 21
    blocksize = 8
    padsize = 8
    macsize = 20
    hashalgo = hashes.SHA1

    @classmethod
    def random_to_key(cls, seed):
        """Expand a 21-byte seed into a 24-byte DES3 key.

        Each 7-byte third of the seed becomes an 8-byte DES key with odd
        parity. Weak DES keys get their last byte XORed with 0xF0.
        Raises ValueError if the seed is not 21 bytes long.
        """
        # XXX Maybe reframe as _DESEnctype.random_to_key and use that
        # way from DES3 random-to-key when DES is implemented, since
        # MIT does this instead of the RFC 3961 random-to-key.
        def expand(seed):
            def parity(b):
                # Return b with the low-order bit set to yield odd parity.
                b &= ~1
                return b if bin(b).count('1') % 2 else b | 1
            assert len(seed) == 7
            # The high seven bits of each seed byte pass straight through.
            # The seven low-order bits are collected into the eighth byte.
            firstbytes = [parity(b) for b in seed]
            lastbyte = parity(sum((seed[i] & 1) << (i + 1) for i in range(7)))
            keybytes = bytes(firstbytes + [lastbyte])
            if _is_weak_des_key(keybytes):
                # bytes is immutable, so build a new value. Assigning to
                # keybytes[7] raised TypeError for every weak key.
                keybytes = keybytes[:7] + bytes([keybytes[7] ^ 0xF0])
            return keybytes

        if len(seed) != 21:
            raise ValueError('Wrong seed length')
        k1, k2, k3 = expand(seed[:7]), expand(seed[7:14]), expand(seed[14:])
        return Key(cls.enctype, k1 + k2 + k3)

    @classmethod
    def string_to_key(cls, string, salt, params):
        """Derive a DES3 key from password bytes *string* and *salt*.

        Raises ValueError if non-empty *params* are supplied.
        """
        if params is not None and params != b'':
            raise ValueError('Invalid DES3 string-to-key parameters')
        k = cls.random_to_key(_nfold(string + salt, 21))
        return cls.derive(k, b'kerberos')

    @classmethod
    def basic_encrypt(cls, key, plaintext):
        """3DES-CBC encrypt *plaintext* (a multiple of 8 bytes), zero IV."""
        assert len(plaintext) % 8 == 0
        algo = ciphers.TripleDES(key.contents)
        cbc = modes.CBC(bytes(8))
        encryptor = Cipher(algo, cbc, default_backend()).encryptor()
        ciphertext = encryptor.update(plaintext)
        return ciphertext

    @classmethod
    def basic_decrypt(cls, key, ciphertext):
        """3DES-CBC decrypt *ciphertext* (a multiple of 8 bytes), zero IV."""
        assert len(ciphertext) % 8 == 0
        algo = ciphers.TripleDES(key.contents)
        cbc = modes.CBC(bytes(8))
        decryptor = Cipher(algo, cbc, default_backend()).decryptor()
        plaintext = decryptor.update(ciphertext)
        return plaintext
313
314
class _AESEnctype(_SimplifiedEnctype):
    # Shared base of aes128-cts and aes256-cts. Subclasses add enctype,
    # keysize and seedsize.
    blocksize = 16
    padsize = 1
    macsize = 12
    hashalgo = hashes.SHA1

    @classmethod
    def string_to_key(cls, string, salt, params):
        # params, if given, is a 4-byte big-endian PBKDF2 iteration count.
        # 0x1000 (4096) is used when params is None or empty.
        iterations = unpack('>L', params or b'\x00\x00\x10\x00')[0]
        password = get_bytes(string)
        seed = PBKDF2HMAC(algorithm=hashes.SHA1(),
                          length=cls.seedsize,
                          salt=salt,
                          iterations=iterations,
                          backend=default_backend()).derive(password)
        return cls.derive(cls.random_to_key(seed), b'kerberos')

    @classmethod
    def basic_encrypt(cls, key, plaintext):
        # AES-CBC with ciphertext stealing (zero IV). Input must be at
        # least one block long.
        assert len(plaintext) >= 16

        cipher = Cipher(ciphers.AES(key.contents), modes.CBC(bytes(16)),
                        default_backend())
        ctext = cipher.encryptor().update(_zeropad(plaintext, 16))
        if len(plaintext) <= 16:
            return ctext
        # Swap the final two ciphertext blocks, then cut the new last
        # block down to the length of the final plaintext chunk.
        tail = len(plaintext) % 16 or 16
        return ctext[:-32] + ctext[-16:] + ctext[-32:-16][:tail]

    @classmethod
    def basic_decrypt(cls, key, ciphertext):
        # Inverse of basic_encrypt.
        assert len(ciphertext) >= 16

        cipher = Cipher(ciphers.AES(key.contents), modes.CBC(bytes(16)),
                        default_backend())

        def decrypt_block(block):
            # CBC with a zero IV over a single block is plain AES decryption.
            return cipher.decryptor().update(block)

        if len(ciphertext) == 16:
            return decrypt_block(ciphertext)

        # Cut into 16-byte blocks. Only the final one may be short.
        blocks = [ciphertext[i:i + 16] for i in range(0, len(ciphertext), 16)]
        *leading, penult, last = blocks
        tail = len(last)

        # Ordinary CBC decryption for every block before the final two.
        chain = bytes(16)
        pieces = []
        for block in leading:
            pieces.append(_xorbytes(decrypt_block(block), chain))
            chain = block

        # The decrypted second-to-last block holds the final plaintext XOR
        # the short last block on its left, and the stolen ciphertext bytes
        # on its right.
        stolen = decrypt_block(penult)
        final_chunk = _xorbytes(stolen[:tail], last)
        # Reassembling the last block with the stolen bytes yields the
        # second-to-last plaintext block.
        pieces.append(_xorbytes(decrypt_block(last + stolen[tail:]), chain))
        pieces.append(final_chunk)
        return b''.join(pieces)
392
393
class _AES128CTS(_AESEnctype):
    # AES128-CTS: 16-byte keys taken directly from 16-byte seeds.
    seedsize = 16
    keysize = 16
    enctype = Enctype.AES128
398
399
class _AES256CTS(_AESEnctype):
    # AES256-CTS: 32-byte keys taken directly from 32-byte seeds.
    seedsize = 32
    keysize = 32
    enctype = Enctype.AES256
404
405
class _RC4(_EnctypeProfile):
    # RC4-HMAC (RFC 4757): MD5-based key derivation and checksum,
    # RC4 stream encryption.
    enctype = Enctype.RC4
    keysize = 16
    seedsize = 16

    @staticmethod
    def usage_str(keyusage):
        # Map an RFC 3961 keyusage to the 4-byte little-endian RFC 4757
        # usage. Per the errata, 9 is NOT mapped to 8.
        msusage = {3: 8, 23: 13}.get(keyusage, keyusage)
        return pack('<i', msusage)

    @classmethod
    def string_to_key(cls, string, salt, params):
        # salt and params are ignored: the key is the NT hash that
        # Credentials computes from the password.
        password = get_string(string)
        creds = Credentials()
        creds.set_anonymous()
        creds.set_password(password)
        return Key(cls.enctype, creds.get_nt_hash())

    @classmethod
    def encrypt(cls, key, keyusage, plaintext, confounder):
        # Output is HMAC-MD5 checksum (16 bytes) | RC4(confounder | plaintext).
        if confounder is None:
            confounder = get_random_bytes(8)
        ki = HMAC_HASH(key.contents, cls.usage_str(keyusage), hashes.MD5)
        data = confounder + plaintext
        cksum = HMAC_HASH(ki, data, hashes.MD5)
        ke = HMAC_HASH(ki, cksum, hashes.MD5)
        rc4 = Cipher(ciphers.ARC4(ke), None, default_backend()).encryptor()
        return cksum + rc4.update(data)

    @classmethod
    def decrypt(cls, key, keyusage, ciphertext):
        # Needs at least the 16-byte checksum plus the 8-byte confounder.
        if len(ciphertext) < 24:
            raise ValueError('ciphertext too short')
        cksum = ciphertext[:16]
        basic_ctext = ciphertext[16:]
        ki = HMAC_HASH(key.contents, cls.usage_str(keyusage), hashes.MD5)
        ke = HMAC_HASH(ki, cksum, hashes.MD5)

        rc4 = Cipher(ciphers.ARC4(ke), None, default_backend()).decryptor()
        basic_plaintext = rc4.update(basic_ctext)

        ok = _mac_equal(cksum, HMAC_HASH(ki, basic_plaintext, hashes.MD5))
        if not ok and keyusage == 9:
            # Accept usage 8 as well, because of the RFC 4757 errata.
            ki_alt = HMAC_HASH(key.contents, pack('<i', 8), hashes.MD5)
            ok = _mac_equal(cksum,
                            HMAC_HASH(ki_alt, basic_plaintext, hashes.MD5))
        if not ok:
            raise InvalidChecksum('ciphertext integrity failure')
        # Strip the 8-byte confounder.
        return basic_plaintext[8:]

    @classmethod
    def prf(cls, key, string):
        # HMAC-SHA1 keyed with the RC4 key.
        return HMAC_HASH(key.contents, string, hashes.SHA1)
467
468
class _ChecksumProfile(object):
    # Base for checksum profiles. Subclasses provide checksum(), and
    # override verify() when verification is more than
    # recompute-and-compare.
    @classmethod
    def verify(cls, key, keyusage, text, cksum):
        # Recompute the checksum and compare it using _mac_equal.
        if _mac_equal(cksum, cls.checksum(key, keyusage, text)):
            return
        raise InvalidChecksum('checksum verification failure')
479
480
class _SimplifiedChecksum(_ChecksumProfile):
    # RFC 3961 simplified-profile checksum: an HMAC under a key derived
    # with constant (usage, 0x99), truncated to macsize. Subclasses set:
    #   * macsize: checksum size in bytes
    #   * enc: the associated enctype profile

    @classmethod
    def checksum(cls, key, keyusage, text):
        kc = cls.enc.derive(key, pack('>iB', keyusage, 0x99))
        mac = HMAC_HASH(kc.contents, text, cls.enc.hashalgo)
        return mac[:cls.macsize]

    @classmethod
    def verify(cls, key, keyusage, text, cksum):
        # Only keys of the associated enctype are acceptable.
        if key.enctype != cls.enc.enctype:
            raise ValueError('Wrong key type for checksum')
        super().verify(key, keyusage, text, cksum)
499
500
class _SHA1AES128(_SimplifiedChecksum):
    # HMAC-SHA1 truncated to 12 bytes, keyed from an AES128 key.
    enc = _AES128CTS
    macsize = 12
504
505
class _SHA1AES256(_SimplifiedChecksum):
    # HMAC-SHA1 truncated to 12 bytes, keyed from an AES256 key.
    enc = _AES256CTS
    macsize = 12
509
510
class _SHA1DES3(_SimplifiedChecksum):
    # Full 20-byte HMAC-SHA1, keyed from a DES3 key.
    enc = _DES3CBC
    macsize = 20
514
515
class _HMACMD5(_ChecksumProfile):
    # RC4-HMAC keyed checksum:
    #   HMAC-MD5(HMAC-MD5(key, "signaturekey\0"), MD5(usage | text)).
    @classmethod
    def checksum(cls, key, keyusage, text):
        signing_key = HMAC_HASH(key.contents, b'signaturekey\0', hashes.MD5)
        inner = SIMPLE_HASH(_RC4.usage_str(keyusage) + text, hashes.MD5)
        return HMAC_HASH(signing_key, inner, hashes.MD5)

    @classmethod
    def verify(cls, key, keyusage, text, cksum):
        # Only RC4 keys are acceptable.
        if key.enctype != Enctype.RC4:
            raise ValueError('Wrong key type for checksum')
        super().verify(key, keyusage, text, cksum)
528
529
class _MD5(_ChecksumProfile):
    @classmethod
    def checksum(cls, key, keyusage, text):
        # Unkeyed: key and keyusage are ignored.
        digest = SIMPLE_HASH(text, hashes.MD5)
        return digest
535
536
class _SHA1(_ChecksumProfile):
    @classmethod
    def checksum(cls, key, keyusage, text):
        # Unkeyed: key and keyusage are ignored.
        digest = SIMPLE_HASH(text, hashes.SHA1)
        return digest
542
543
class _CRC32(_ChecksumProfile):
    @classmethod
    def checksum(cls, key, keyusage, text):
        # Unkeyed: key and keyusage are ignored. The CRC is seeded with
        # 0xffffffff and its result bit-inverted, then packed little-endian.
        inverted = crc32(text, 0xffffffff) ^ 0xffffffff
        return pack('<I', inverted)
550
551
# Registered enctype profiles, keyed by each profile's own enctype number.
_enctype_table = {
    profile.enctype: profile
    for profile in (_DES3CBC, _AES128CTS, _AES256CTS, _RC4)
}
558
559
# Registered checksum profiles, keyed by checksum type number.
_checksum_table = dict([
    (Cksumtype.SHA1_DES3, _SHA1DES3),
    (Cksumtype.SHA1_AES128, _SHA1AES128),
    (Cksumtype.SHA1_AES256, _SHA1AES256),
    (Cksumtype.HMAC_MD5, _HMACMD5),
    (Cksumtype.MD5, _MD5),
    (Cksumtype.SHA1, _SHA1),
    (Cksumtype.CRC32, _CRC32),
])
569
570
def _get_enctype_profile(enctype):
    """Return the registered profile class for *enctype*.

    Raises ValueError when no profile is registered for it.
    """
    try:
        return _enctype_table[enctype]
    except KeyError:
        raise ValueError('Invalid enctype %d' % enctype) from None
575
576
def _get_checksum_profile(cksumtype):
    """Return the registered checksum profile class for *cksumtype*.

    Raises ValueError when no profile is registered for it.
    """
    try:
        return _checksum_table[cksumtype]
    except KeyError:
        raise ValueError('Invalid cksumtype %d' % cksumtype) from None
581
582
class Key(object):
    """A raw key (*contents*) tagged with its *enctype*."""

    def __init__(self, enctype, contents):
        # Rejects unknown enctypes (via the profile lookup) and wrong sizes.
        profile = _get_enctype_profile(enctype)
        if len(contents) != profile.keysize:
            raise ValueError('Wrong key length')
        self.enctype, self.contents = enctype, contents
590
591
def seedsize(enctype):
    """Return the random_to_key input size, in bytes, for *enctype*."""
    return _get_enctype_profile(enctype).seedsize
595
596
def random_to_key(enctype, seed):
    """Build a Key of *enctype* from *seed* (exactly seedsize bytes)."""
    profile = _get_enctype_profile(enctype)
    expected = profile.seedsize
    if len(seed) == expected:
        return profile.random_to_key(seed)
    raise ValueError('Wrong crypto seed length')
602
603
def string_to_key(enctype, string, salt, params=None):
    """Derive a Key of *enctype* from a password, salt and optional params."""
    profile = _get_enctype_profile(enctype)
    key = profile.string_to_key(string, salt, params)
    return key
607
608
def encrypt(key, keyusage, plaintext, confounder=None):
    """Encrypt *plaintext* with *key* for *keyusage*.

    A random confounder is used when *confounder* is None.
    """
    profile = _get_enctype_profile(key.enctype)
    ciphertext = profile.encrypt(key, keyusage, plaintext, confounder)
    return ciphertext
612
613
def decrypt(key, keyusage, ciphertext):
    """Decrypt *ciphertext* with *key* for *keyusage*.

    Raises InvalidChecksum on integrity failure. Raises ValueError for an
    unsupported key enctype or malformed ciphertext.
    """
    profile = _get_enctype_profile(key.enctype)
    plaintext = profile.decrypt(key, keyusage, ciphertext)
    return plaintext
619
620
def prf(key, string):
    """Return the enctype-specific PRF of *string* under *key*."""
    profile = _get_enctype_profile(key.enctype)
    output = profile.prf(key, string)
    return output
624
625
def make_checksum(cksumtype, key, keyusage, text):
    """Compute a checksum of type *cksumtype* over *text*."""
    profile = _get_checksum_profile(cksumtype)
    digest = profile.checksum(key, keyusage, text)
    return digest
629
630
def verify_checksum(cksumtype, key, keyusage, text, cksum):
    """Check *cksum* against *text*. Returns None on success.

    Raises InvalidChecksum on mismatch. Raises ValueError for an invalid
    cksumtype or key enctype, or a malformed checksum.
    """
    profile = _get_checksum_profile(cksumtype)
    profile.verify(key, keyusage, text, cksum)
637
638
def prfplus(key, pepper, l):
    """Return *l* bytes of RFC 6113 PRF+ output for *key* and *pepper*.

    Output is PRF(key, 1 | pepper) | PRF(key, 2 | pepper) | ..., truncated
    to l bytes. The counter is a single byte.
    """
    pieces = []
    produced = 0
    counter = 1
    while produced < l:
        block = prf(key, bytes([counter]) + pepper)
        pieces.append(block)
        produced += len(block)
        counter += 1
    return b''.join(pieces)[:l]
647
648
def cf2(enctype, key1, key2, pepper1, pepper2):
    """Combine two keys into a Key of *enctype* (RFC 6113 KRB-FX-CF2).

    The result is random_to_key(PRF+(key1, pepper1) XOR PRF+(key2, pepper2)),
    with each PRF+ stream seedsize bytes long.
    """
    profile = _get_enctype_profile(enctype)
    width = profile.seedsize
    stream1 = prfplus(key1, pepper1, width)
    stream2 = prfplus(key2, pepper2, width)
    return profile.random_to_key(_xorbytes(stream1, stream2))
655
def h(hexstr):
    """Decode the hexadecimal string *hexstr* into bytes."""
    decoded = bytes.fromhex(hexstr)
    return decoded
658
659 class KcrytoTest(TestCase):
660     """kcrypto Test case."""
661
662     def test_aes128_crypr(self):
663         # AES128 encrypt and decrypt
664         kb = h('9062430C8CDA3388922E6D6A509F5B7A')
665         conf = h('94B491F481485B9A0678CD3C4EA386AD')
666         keyusage = 2
667         plain = b'9 bytesss'
668         ctxt = h('68FB9679601F45C78857B2BF820FD6E53ECA8D42FD4B1D7024A09205ABB7CD2E'
669                  'C26C355D2F')
670         k = Key(Enctype.AES128, kb)
671         self.assertEqual(encrypt(k, keyusage, plain, conf), ctxt)
672         self.assertEqual(decrypt(k, keyusage, ctxt), plain)
673
674     def test_aes256_crypt(self):
675         # AES256 encrypt and decrypt
676         kb = h('F1C795E9248A09338D82C3F8D5B567040B0110736845041347235B1404231398')
677         conf = h('E45CA518B42E266AD98E165E706FFB60')
678         keyusage = 4
679         plain = b'30 bytes bytes bytes bytes byt'
680         ctxt = h('D1137A4D634CFECE924DBC3BF6790648BD5CFF7DE0E7B99460211D0DAEF3D79A'
681                  '295C688858F3B34B9CBD6EEBAE81DAF6B734D4D498B6714F1C1D')
682         k = Key(Enctype.AES256, kb)
683         self.assertEqual(encrypt(k, keyusage, plain, conf), ctxt)
684         self.assertEqual(decrypt(k, keyusage, ctxt), plain)
685
686     def test_aes128_checksum(self):
687         # AES128 checksum
688         kb = h('9062430C8CDA3388922E6D6A509F5B7A')
689         keyusage = 3
690         plain = b'eight nine ten eleven twelve thirteen'
691         cksum = h('01A4B088D45628F6946614E3')
692         k = Key(Enctype.AES128, kb)
693         verify_checksum(Cksumtype.SHA1_AES128, k, keyusage, plain, cksum)
694
695     def test_aes256_checksum(self):
696         # AES256 checksum
697         kb = h('B1AE4CD8462AFF1677053CC9279AAC30B796FB81CE21474DD3DDBCFEA4EC76D7')
698         keyusage = 4
699         plain = b'fourteen'
700         cksum = h('E08739E3279E2903EC8E3836')
701         k = Key(Enctype.AES256, kb)
702         verify_checksum(Cksumtype.SHA1_AES256, k, keyusage, plain, cksum)
703
704     def test_aes128_string_to_key(self):
705         # AES128 string-to-key
706         string = b'password'
707         salt = b'ATHENA.MIT.EDUraeburn'
708         params = h('00000002')
709         kb = h('C651BF29E2300AC27FA469D693BDDA13')
710         k = string_to_key(Enctype.AES128, string, salt, params)
711         self.assertEqual(k.contents, kb)
712
713     def test_aes256_string_to_key(self):
714         # AES256 string-to-key
715         string = b'X' * 64
716         salt = b'pass phrase equals block size'
717         params = h('000004B0')
718         kb = h('89ADEE3608DB8BC71F1BFBFE459486B05618B70CBAE22092534E56C553BA4B34')
719         k = string_to_key(Enctype.AES256, string, salt, params)
720         self.assertEqual(k.contents, kb)
721
722     def test_aes128_prf(self):
723         # AES128 prf
724         kb = h('77B39A37A868920F2A51F9DD150C5717')
725         k = string_to_key(Enctype.AES128, b'key1', b'key1')
726         self.assertEqual(prf(k, b'\x01\x61'), kb)
727
728     def test_aes256_prf(self):
729         # AES256 prf
730         kb = h('0D674DD0F9A6806525A4D92E828BD15A')
731         k = string_to_key(Enctype.AES256, b'key2', b'key2')
732         self.assertEqual(prf(k, b'\x02\x62'), kb)
733
734     def test_aes128_cf2(self):
735         # AES128 cf2
736         kb = h('97DF97E4B798B29EB31ED7280287A92A')
737         k1 = string_to_key(Enctype.AES128, b'key1', b'key1')
738         k2 = string_to_key(Enctype.AES128, b'key2', b'key2')
739         k = cf2(Enctype.AES128, k1, k2, b'a', b'b')
740         self.assertEqual(k.contents, kb)
741
742     def test_aes256_cf2(self):
743         # AES256 cf2
744         kb = h('4D6CA4E629785C1F01BAF55E2E548566B9617AE3A96868C337CB93B5E72B1C7B')
745         k1 = string_to_key(Enctype.AES256, b'key1', b'key1')
746         k2 = string_to_key(Enctype.AES256, b'key2', b'key2')
747         k = cf2(Enctype.AES256, k1, k2, b'a', b'b')
748         self.assertEqual(k.contents, kb)
749
750     def test_des3_crypt(self):
751         # DES3 encrypt and decrypt
752         kb = h('0DD52094E0F41CECCB5BE510A764B35176E3981332F1E598')
753         conf = h('94690A17B2DA3C9B')
754         keyusage = 3
755         plain = b'13 bytes byte'
756         ctxt = h('839A17081ECBAFBCDC91B88C6955DD3C4514023CF177B77BF0D0177A16F705E8'
757                  '49CB7781D76A316B193F8D30')
758         k = Key(Enctype.DES3, kb)
759         self.assertEqual(encrypt(k, keyusage, plain, conf), ctxt)
760         self.assertEqual(decrypt(k, keyusage, ctxt), _zeropad(plain, 8))
761
762     def test_des3_string_to_key(self):
763         # DES3 string-to-key
764         string = b'password'
765         salt = b'ATHENA.MIT.EDUraeburn'
766         kb = h('850BB51358548CD05E86768C313E3BFEF7511937DCF72C3E')
767         k = string_to_key(Enctype.DES3, string, salt)
768         self.assertEqual(k.contents, kb)
769
770     def test_des3_checksum(self):
771         # DES3 checksum
772         kb = h('7A25DF8992296DCEDA0E135BC4046E2375B3C14C98FBC162')
773         keyusage = 2
774         plain = b'six seven'
775         cksum = h('0EEFC9C3E049AABC1BA5C401677D9AB699082BB4')
776         k = Key(Enctype.DES3, kb)
777         verify_checksum(Cksumtype.SHA1_DES3, k, keyusage, plain, cksum)
778
779     def test_des3_cf2(self):
780         # DES3 cf2
781         kb = h('E58F9EB643862C13AD38E529313462A7F73E62834FE54A01')
782         k1 = string_to_key(Enctype.DES3, b'key1', b'key1')
783         k2 = string_to_key(Enctype.DES3, b'key2', b'key2')
784         k = cf2(Enctype.DES3, k1, k2, b'a', b'b')
785         self.assertEqual(k.contents, kb)
786
787     def test_rc4_crypt(self):
788         # RC4 encrypt and decrypt
789         kb = h('68F263DB3FCE15D031C9EAB02D67107A')
790         conf = h('37245E73A45FBF72')
791         keyusage = 4
792         plain = b'30 bytes bytes bytes bytes byt'
793         ctxt = h('95F9047C3AD75891C2E9B04B16566DC8B6EB9CE4231AFB2542EF87A7B5A0F260'
794                  'A99F0460508DE0CECC632D07C354124E46C5D2234EB8')
795         k = Key(Enctype.RC4, kb)
796         self.assertEqual(encrypt(k, keyusage, plain, conf), ctxt)
797         self.assertEqual(decrypt(k, keyusage, ctxt), plain)
798
799     def test_rc4_string_to_key(self):
800         # RC4 string-to-key
801         string = b'foo'
802         kb = h('AC8E657F83DF82BEEA5D43BDAF7800CC')
803         k = string_to_key(Enctype.RC4, string, None)
804         self.assertEqual(k.contents, kb)
805
806     def test_rc4_checksum(self):
807         # RC4 checksum
808         kb = h('F7D3A155AF5E238A0B7A871A96BA2AB2')
809         keyusage = 6
810         plain = b'seventeen eighteen nineteen twenty'
811         cksum = h('EB38CC97E2230F59DA4117DC5859D7EC')
812         k = Key(Enctype.RC4, kb)
813         verify_checksum(Cksumtype.HMAC_MD5, k, keyusage, plain, cksum)
814
815     def test_rc4_cf2(self):
816         # RC4 cf2
817         kb = h('24D7F6B6BAE4E5C00D2082C5EBAB3672')
818         k1 = string_to_key(Enctype.RC4, b'key1', b'key1')
819         k2 = string_to_key(Enctype.RC4, b'key2', b'key2')
820         k = cf2(Enctype.RC4, k1, k2, b'a', b'b')
821         self.assertEqual(k.contents, kb)
822
823     def _test_md5_unkeyed_checksum(self, etype, usage):
824         # MD5 unkeyed checksum
825         pw = b'pwd'
826         salt = b'bytes'
827         key = string_to_key(etype, pw, salt)
828         plain = b'seventeen eighteen nineteen twenty'
829         cksum = h('9d9588cdef3a8cefc9d2c208d978f60c')
830         verify_checksum(Cksumtype.MD5, key, usage, plain, cksum)
831
832     def test_md5_unkeyed_checksum_des3_usage_40(self):
833         return self._test_md5_unkeyed_checksum(Enctype.DES3, 40)
834
835     def test_md5_unkeyed_checksum_des3_usage_50(self):
836         return self._test_md5_unkeyed_checksum(Enctype.DES3, 50)
837
838     def test_md5_unkeyed_checksum_rc4_usage_40(self):
839         return self._test_md5_unkeyed_checksum(Enctype.RC4, 40)
840
841     def test_md5_unkeyed_checksum_rc4_usage_50(self):
842         return self._test_md5_unkeyed_checksum(Enctype.RC4, 50)
843
844     def test_md5_unkeyed_checksum_aes128_usage_40(self):
845         return self._test_md5_unkeyed_checksum(Enctype.AES128, 40)
846
847     def test_md5_unkeyed_checksum_aes128_usage_50(self):
848         return self._test_md5_unkeyed_checksum(Enctype.AES128, 50)
849
850     def test_md5_unkeyed_checksum_aes256_usage_40(self):
851         return self._test_md5_unkeyed_checksum(Enctype.AES256, 40)
852
853     def test_md5_unkeyed_checksum_aes256_usage_50(self):
854         return self._test_md5_unkeyed_checksum(Enctype.AES256, 50)
855
856     def _test_sha1_unkeyed_checksum(self, etype, usage):
857         # SHA1 unkeyed checksum
858         pw = b'password'
859         salt = b'salt'
860         key = string_to_key(etype, pw, salt)
861         plain = b'twenty nineteen eighteen seventeen'
862         cksum = h('381c870d8875d1913555de19af5c885fd27b7da9')
863         verify_checksum(Cksumtype.SHA1, key, usage, plain, cksum)
864
865     def test_sha1_unkeyed_checksum_des3_usage_40(self):
866         return self._test_sha1_unkeyed_checksum(Enctype.DES3, 40)
867
868     def test_sha1_unkeyed_checksum_des3_usage_50(self):
869         return self._test_sha1_unkeyed_checksum(Enctype.DES3, 50)
870
871     def test_sha1_unkeyed_checksum_rc4_usage_40(self):
872         return self._test_sha1_unkeyed_checksum(Enctype.RC4, 40)
873
874     def test_sha1_unkeyed_checksum_rc4_usage_50(self):
875         return self._test_sha1_unkeyed_checksum(Enctype.RC4, 50)
876
877     def test_sha1_unkeyed_checksum_aes128_usage_40(self):
878         return self._test_sha1_unkeyed_checksum(Enctype.AES128, 40)
879
880     def test_sha1_unkeyed_checksum_aes128_usage_50(self):
881         return self._test_sha1_unkeyed_checksum(Enctype.AES128, 50)
882
883     def test_sha1_unkeyed_checksum_aes256_usage_40(self):
884         return self._test_sha1_unkeyed_checksum(Enctype.AES256, 40)
885
886     def test_sha1_unkeyed_checksum_aes256_usage_50(self):
887         return self._test_sha1_unkeyed_checksum(Enctype.AES256, 50)
888
889     def _test_crc32_unkeyed_checksum(self, etype, usage):
890         # CRC32 unkeyed checksum
891         pw = b'password'
892         salt = b'salt'
893         key = string_to_key(etype, pw, salt)
894         plain = b'africa america asia australia europe'
895         cksum = h('ce595a53')
896         verify_checksum(Cksumtype.CRC32, key, usage, plain, cksum)
897
898     def test_crc32_unkeyed_checksum_des3_usage_40(self):
899         return self._test_crc32_unkeyed_checksum(Enctype.DES3, 40)
900
901     def test_crc32_unkeyed_checksum_des3_usage_50(self):
902         return self._test_crc32_unkeyed_checksum(Enctype.DES3, 50)
903
904     def test_crc32_unkeyed_checksum_rc4_usage_40(self):
905         return self._test_crc32_unkeyed_checksum(Enctype.RC4, 40)
906
907     def test_crc32_unkeyed_checksum_rc4_usage_50(self):
908         return self._test_crc32_unkeyed_checksum(Enctype.RC4, 50)
909
910     def test_crc32_unkeyed_checksum_aes128_usage_40(self):
911         return self._test_crc32_unkeyed_checksum(Enctype.AES128, 40)
912
913     def test_crc32_unkeyed_checksum_aes128_usage_50(self):
914         return self._test_crc32_unkeyed_checksum(Enctype.AES128, 50)
915
916     def test_crc32_unkeyed_checksum_aes256_usage_40(self):
917         return self._test_crc32_unkeyed_checksum(Enctype.AES256, 40)
918
919     def test_crc32_unkeyed_checksum_aes256_usage_50(self):
920         return self._test_crc32_unkeyed_checksum(Enctype.AES256, 50)
921
922
if __name__ == "__main__":
    # Only when this file is executed directly (not imported): hand
    # control to unittest's command-line runner for this module's tests.
    import unittest as _unittest
    _unittest.main()