PEP8: fix E226: missing whitespace around arithmetic operator
[samba.git] / python / samba / tests / password_hash.py
1 # Tests for source4/dsdb/samdb/ldb_modules/password_hash.c
2 #
3 # Copyright (C) Catalyst IT Ltd. 2017
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18
19 """
20 Base class for tests for source4/dsdb/samdb/ldb_modules/password_hash.c
21 """
22
23 from samba.credentials import Credentials
24 from samba.samdb import SamDB
25 from samba.auth import system_session
26 from samba.tests import TestCase
27 from samba.ndr import ndr_unpack
28 from samba.dcerpc import drsblobs
29 from samba.dcerpc.samr import DOMAIN_PASSWORD_STORE_CLEARTEXT
30 from samba.dsdb import UF_ENCRYPTED_TEXT_PASSWORD_ALLOWED
31 from samba.tests import delete_force
32 from samba.tests.password_test import PasswordCommon
33 import ldb
34 import samba
35 import binascii
36 from hashlib import md5
37 import crypt
38 from samba.compat import text_type
39
40
# sAMAccountName (and CN) of the account the tests (re)create.
USER_NAME = "PasswordHashTestUser"
# Random password for that account; both length arguments are set to 32.
USER_PASS = samba.generate_random_password(32, 32)
# userPrincipalName given to the test account.
UPN = "PWHash@User.Principle"
44
45 # Get named package from the passed supplemental credentials
46 #
47 # returns the package and its position within the supplemental credentials
def get_package(sc, name):
    """Look up a named package in an unpacked supplementalCredentials blob.

    :param sc: object exposing ``sc.sub.packages``, or None
    :param name: package name to search for
    :return: None when ``sc`` is None or no package carries that name,
             otherwise a ``(position, package)`` tuple where position
             counts from 1
    """
    if sc is None:
        return None

    # Positions are 1-based, so start the enumeration at 1.
    for position, package in enumerate(sc.sub.packages, 1):
        if name == package.name:
            return (position, package)

    return None
59
60 # Calculate the MD5 password digest from the supplied user, realm and password
61 #
def calc_digest(user, realm, password):
    """Return the hex MD5 digest of ``user:realm:password``.

    The three values are joined with ':' and, when the joined value is a
    text string, encoded as UTF-8 before hashing.
    """
    joined = "%s:%s:%s" % (user, realm, password)
    # md5() needs bytes; only text needs encoding.
    payload = joined.encode('utf8') if isinstance(joined, text_type) else joined
    return md5(payload).hexdigest()
69
70
class PassWordHashTests(TestCase):
    """Base class for the password_hash.c tests.

    Provides the fixture that (re)creates the test user and the assertion
    helpers used to validate the stored credentials.
    """

    def setUp(self):
        """Load the test loadparm context, then run the base class setUp."""
        self.lp = samba.tests.env_loadparm()
        super(PassWordHashTests, self).setUp()

    def set_store_cleartext(self, cleartext):
        """Set or clear DOMAIN_PASSWORD_STORE_CLEARTEXT in pwdProperties.

        :param cleartext: truthy to set the flag, falsy to clear it
        """
        # Read-modify-write so the other pwdProperties bits are preserved.
        props = int(self.ldb.get_pwdProperties())
        if cleartext:
            props |= DOMAIN_PASSWORD_STORE_CLEARTEXT
        else:
            props &= ~DOMAIN_PASSWORD_STORE_CLEARTEXT
        self.ldb.set_pwdProperties(str(props))

    def add_user(self, options=None, clear_text=False, ldb=None):
        """(Re)create the test user, exercising the password_hash code.

        Sets self.ldb, self.netbios_domain, self.dns_domain and
        self.base_dn as a side effect; self.creds and self.session are
        only set when a new connection is opened.

        :param options: optional iterable of (name, value) pairs applied
                        to the loadparm context before connecting
        :param clear_text: when truthy, set the domain's store-cleartext
                           flag (restored on cleanup) and create the user
                           with UF_ENCRYPTED_TEXT_PASSWORD_ALLOWED
        :param ldb: existing SamDB connection to use; when None a new one
                    is opened with the system session.  Note that inside
                    this method the name shadows the ``ldb`` module.
        """
        # set any needed options
        if options is not None:
            for (option, value) in options:
                self.lp.set(option, value)

        if ldb is None:
            self.creds = Credentials()
            self.creds.guess(self.lp)
            self.session = system_session()
            self.ldb = SamDB(session_info=self.session,
                             credentials=self.creds,
                             lp=self.lp)
        else:
            self.ldb = ldb

        # The NetBIOS domain name is stored on the crossRef object whose
        # ncName is the default naming context.
        expression = "ncName=%s" % self.ldb.get_default_basedn()
        res = self.ldb.search(base=self.ldb.get_config_basedn(),
                              expression=expression,
                              attrs=["nETBIOSName"])
        self.netbios_domain = res[0]["nETBIOSName"][0]
        self.dns_domain = self.ldb.domain_dns_name()

        # permit password changes during this test
        PasswordCommon.allow_password_changes(self, self.ldb)

        self.base_dn = self.ldb.domain_dn()

        account_control = 0
        if clear_text:
            # Restore the current domain setting on exit.
            pwdProperties = self.ldb.get_pwdProperties()
            self.addCleanup(self.ldb.set_pwdProperties, pwdProperties)
            # Update the domain setting
            self.set_store_cleartext(clear_text)
            account_control |= UF_ENCRYPTED_TEXT_PASSWORD_ALLOWED

        # (Re)adds the test user USER_NAME with password USER_PASS
        # and userPrincipalName UPN
        user_dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        delete_force(self.ldb, user_dn)
        self.ldb.add({
            "dn": user_dn,
            "objectclass": "user",
            "sAMAccountName": USER_NAME,
            "userPassword": USER_PASS,
            "userPrincipalName": UPN,
            "userAccountControl": str(account_control)
        })

    def get_supplemental_creds(self):
        """Return the unpacked supplementalCredentials of the test user.

        Asserts that the user object was found by the base-scope search.
        """
        base = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        res = self.ldb.search(scope=ldb.SCOPE_BASE,
                              base=base,
                              attrs=["supplementalCredentials"])
        self.assertTrue(len(res) > 0)
        sc_blob = res[0]["supplementalCredentials"][0]
        return ndr_unpack(drsblobs.supplementalCredentialsBlob, sc_blob)

    def check_digest(self, user, realm, password, digest):
        """Calculate a WDigest value and compare it with the stored one.

        :param user: user name component of the digest
        :param realm: realm component of the digest
        :param password: password component of the digest
        :param digest: stored digest as a sequence of byte values
        """
        expected = calc_digest(user, realm, password)
        # hexlify() returns bytes on Python 3 while hexdigest() returns
        # str, so decode to compare like with like.
        actual = binascii.hexlify(bytearray(digest)).decode('utf8')
        error = "Digest expected[%s], actual[%s], " \
                "user[%s], realm[%s], pass[%s]" % \
                (expected, actual, user, realm, password)
        self.assertEqual(expected, actual, error)

    def check_wdigests(self, digests):
        """Check all of the 29 expected WDigest values.

        :param digests: object exposing ``num_hashes`` and ``hashes``,
                        where each entry has a ``hash`` attribute
        """
        self.assertEqual(29, digests.num_hashes)

        user = USER_NAME
        nb_domain = self.netbios_domain
        dns_domain = self.dns_domain

        # Down-level logon name (DOMAIN\user) in its three case variants.
        nb_user = "%s\\%s" % (nb_domain, user)
        nb_user_lower = "%s\\%s" % (nb_domain.lower(), user.lower())
        nb_user_upper = "%s\\%s" % (nb_domain.upper(), user.upper())

        # (user, realm) inputs, listed in hash order 1..29 so they can be
        # checked against the spec and the samba-tool user tests.
        inputs = [
            (user, nb_domain),                        # 1
            (user.lower(), nb_domain.lower()),        # 2
            (user.upper(), nb_domain.upper()),        # 3
            (user, nb_domain.upper()),                # 4
            (user, nb_domain.lower()),                # 5
            (user.upper(), nb_domain.lower()),        # 6
            (user.lower(), nb_domain.upper()),        # 7
            (user, dns_domain),                       # 8
            (user.lower(), dns_domain.lower()),       # 9
            (user.upper(), dns_domain.upper()),       # 10
            (user, dns_domain.upper()),               # 11
            (user, dns_domain.lower()),               # 12
            (user.upper(), dns_domain.lower()),       # 13
            (user.lower(), dns_domain.upper()),       # 14
            (UPN, ""),                                # 15
            (UPN.lower(), ""),                        # 16
            (UPN.upper(), ""),                        # 17
            (nb_user, ""),                            # 18
            (nb_user_lower, ""),                      # 19
            (nb_user_upper, ""),                      # 20
            (user, "Digest"),                         # 21
            (user.lower(), "Digest"),                 # 22
            (user.upper(), "Digest"),                 # 23
            (UPN, "Digest"),                          # 24
            (UPN.lower(), "Digest"),                  # 25
            (UPN.upper(), "Digest"),                  # 26
            (nb_user, "Digest"),                      # 27
            (nb_user_lower, "Digest"),                # 28
            (nb_user_upper, "Digest"),                # 29
        ]

        # Spec numbering is 1-based, the hashes array is 0-based.
        for number, (name, realm) in enumerate(inputs, 1):
            self.check_digest(name,
                              realm,
                              USER_PASS,
                              digests.hashes[number - 1].hash)

    def checkUserPassword(self, up, expected):
        """Check the crypt() hashes stored for userPassword.

        :param up: object exposing ``num_hashes`` and ``hashes``, where
                   each entry has ``scheme`` and ``value`` attributes
        :param expected: sequence of (scheme tag, crypt algorithm id,
                         rounds or None) tuples, in stored order
        """
        # Check we've received the correct number of hashes
        self.assertEqual(len(expected), up.num_hashes)

        for i, (tag, alg, rounds) in enumerate(expected):
            stored = up.hashes[i]
            self.assertEqual(tag, stored.scheme)

            # "$alg$salt$hash" or "$alg$rounds=N$salt$hash"; the leading
            # "$" makes data[0] an empty string.
            data = stored.value.split("$")
            # Check we got the expected crypt algorithm
            self.assertEqual(alg, data[1])

            if rounds is None:
                cmd = "$%s$%s" % (alg, data[2])
            else:
                cmd = "$%s$rounds=%d$%s" % (alg, rounds, data[3])

            # Recalculate the hash using the stored salt and compare
            expected_hash = crypt.crypt(USER_PASS, cmd)
            self.assertEqual(expected_hash, stored.value)

    def checkNtHash(self, password, nt_hash):
        """Check that the correct nt_hash was stored for userPassword.

        :param password: clear text password the hash was derived from
        :param nt_hash: stored NT hash as a sequence of byte values
        """
        creds = Credentials()
        creds.set_anonymous()
        creds.set_password(password)
        expected = creds.get_nt_hash()
        actual = bytearray(nt_hash)
        self.assertEqual(expected, actual)