CVE-2023-4154 s4:dsdb:tests: Fix code spelling
[samba.git] / source4 / dsdb / tests / python / token_group.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 # test tokengroups attribute against internal token calculation
4
5 import optparse
6 import sys
7 import os
8
9 sys.path.insert(0, "bin/python")
10 import samba
11
12 from samba.tests.subunitrun import SubunitOptions, TestProgram
13
14 import samba.getopt as options
15
16 from samba.auth import system_session
17 from samba import ldb, dsdb
18 from samba.samdb import SamDB
19 from samba.auth import AuthContext
20 from samba.ndr import ndr_unpack
21 from samba import gensec
22 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS, AUTO_USE_KERBEROS
23 from samba.dsdb import GTYPE_SECURITY_GLOBAL_GROUP, GTYPE_SECURITY_UNIVERSAL_GROUP
24 import samba.tests
25 from samba.tests import delete_force
26 from samba.dcerpc import security
27 from samba.auth import AUTH_SESSION_INFO_DEFAULT_GROUPS, AUTH_SESSION_INFO_AUTHENTICATED, AUTH_SESSION_INFO_SIMPLE_PRIVILEGES, AUTH_SESSION_INFO_NTLM
28
29
parser = optparse.OptionParser("token_group.py [options] <host>")

# Build the option groups in the same order they are registered below.
sambaopts = options.SambaOptions(parser)
versionopts = options.VersionOptions(parser)
# use command line creds if available
credopts = options.CredentialsOptions(parser)
subunitopts = SubunitOptions(parser)

for option_group in (sambaopts, versionopts, credopts, subunitopts):
    parser.add_option_group(option_group)

opts, args = parser.parse_args()

# The single positional argument is the target host / URL.
if not args:
    parser.print_usage()
    sys.exit(1)

url = args[0]

lp = sambaopts.get_loadparm()
creds = credopts.get_credentials(lp)
# Request GENSEC sealing for every connection made with these credentials.
creds.set_gensec_features(creds.get_gensec_features() | gensec.FEATURE_SEAL)
50
51
def closure(vSet, wSet, aSet):
    """Extend ``wSet`` in place with every vertex reachable from it.

    Starting from the vertices already in ``wSet``, follow the directed
    edges ``(start, end)`` in ``aSet`` and add each ``end`` that is a member
    of ``vSet``. Only vertices in ``vSet`` are ever added, so ``vSet``
    restricts which vertices the walk may pass through. The initial
    members of ``wSet`` do not need to be in ``vSet``.

    :param vSet: set of vertices that may be added to ``wSet``
    :param wSet: set of start vertices; updated in place
    :param aSet: iterable of ``(start, end)`` edge pairs
    :returns: None (the result is left in ``wSet``)
    """
    # Index the edges by start vertex once, instead of rescanning every
    # edge for each vertex added (which made the old version O(V * E)).
    successors = {}
    for start, end in aSet:
        successors.setdefault(start, []).append(end)

    # Iterative walk with an explicit stack. The old recursive version
    # recursed once per newly reached vertex and could exceed Python's
    # recursion limit on long membership chains.
    pending = list(wSet)
    while pending:
        node = pending.pop()
        for end in successors.get(node, ()):
            if end not in wSet and end in vSet:
                wSet.add(end)
                pending.append(end)
59
60
class StaticTokenTest(samba.tests.TestCase):
    """Check the token of the command-line account.

    The SIDs computed locally with samba.auth.user_session() are compared
    with the tokenGroups returned over LDAP and, for Kerberos logons, with
    the SIDs in the server-side session built from the PAC.
    """

    @staticmethod
    def _sid_strings(values):
        """Decode NDR-packed dom_sid attribute values into SID strings."""
        return [str(ndr_unpack(samba.dcerpc.security.dom_sid, value))
                for value in values]

    def setUp(self):
        """Connect as the command-line account and compute its expected SIDs."""
        super().setUp()

        # The expected SID list depends on the authentication mechanism,
        # so the Kerberos state must be pinned on the command line.
        self.assertNotEqual(creds.get_kerberos_state(), AUTO_USE_KERBEROS)

        self.ldb = SamDB(url, credentials=creds, session_info=system_session(lp), lp=lp)
        self.base_dn = self.ldb.domain_dn()

        rootdse = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(rootdse), 1)

        # The first rootDSE tokenGroups value is used as the account's own SID.
        own_sid = ndr_unpack(samba.dcerpc.security.dom_sid, rootdse[0]["tokenGroups"][0])
        self.user_sid_dn = "<SID=%s>" % str(own_sid)

        flags = (AUTH_SESSION_INFO_DEFAULT_GROUPS |
                 AUTH_SESSION_INFO_AUTHENTICATED |
                 AUTH_SESSION_INFO_SIMPLE_PRIVILEGES)
        if creds.get_kerberos_state() == DONT_USE_KERBEROS:
            flags |= AUTH_SESSION_INFO_NTLM

        session = samba.auth.user_session(self.ldb, lp_ctx=lp, dn=self.user_sid_dn,
                                          session_info_flags=flags)
        self.user_sids = [str(sid) for sid in session.security_token.sids]

        # A Kerberos logon is expected to carry these two extra SIDs, so
        # add them to the locally calculated list to make it comparable.
        if creds.get_kerberos_state() == MUST_USE_KERBEROS:
            self.user_sids.extend([
                str(security.SID_AUTHENTICATION_AUTHORITY_ASSERTED_IDENTITY),
                str(security.SID_CLAIMS_VALID),
            ])

    def test_rootDSE_tokenGroups(self):
        """Testing rootDSE tokengroups against internal calculation"""
        if not url.startswith("ldap"):
            self.fail(msg="This test is only valid on ldap")

        res = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        print("Getting tokenGroups from rootDSE")
        tokengroups = self._sid_strings(res[0]['tokenGroups'])

        mismatch = set(tokengroups) ^ set(self.user_sids)
        if mismatch:
            print("token sids don't match")
            print("tokengroups: %s" % tokengroups)
            print("calculated : %s" % self.user_sids)
            print("difference : %s" % mismatch)
            self.fail(msg="calculated groups don't match against rootDSE tokenGroups")

    def test_dn_tokenGroups(self):
        # tokenGroups read from the user's own DN must be a subset of the
        # full token, and the SIDs it lacks must be exactly the ones
        # listed in expected_missing below.
        print("Getting tokenGroups from user DN")
        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        dn_tokengroups = self._sid_strings(res[0]['tokenGroups'])
        dn_set = set(dn_tokengroups)
        token_set = set(self.user_sids)

        # The tokenGroups is just a subset of the user_sids
        # so we don't check symmetric_difference() here.
        not_in_token = dn_set - token_set
        if not_in_token:
            print("dn token sids no subset of user token")
            print("tokengroups: %s" % dn_tokengroups)
            print("user sids : %s" % self.user_sids)
            print("difference : %s" % not_in_token)
            self.fail(msg="DN tokenGroups no subset of full user token")

        missing_sidset = token_set - dn_set

        expected_missing = [
            self.user_sids[0],
            security.SID_WORLD,
            security.SID_NT_NETWORK,
            security.SID_NT_AUTHENTICATED_USERS,
            security.SID_BUILTIN_PREW2K,
        ]
        kerberos_state = creds.get_kerberos_state()
        if kerberos_state == MUST_USE_KERBEROS:
            expected_missing += [security.SID_AUTHENTICATION_AUTHORITY_ASSERTED_IDENTITY,
                                 security.SID_CLAIMS_VALID]
        elif kerberos_state == DONT_USE_KERBEROS:
            expected_missing.append(security.SID_NT_NTLM_AUTHENTICATION)

        extra_sidset = set(expected_missing)

        unexpected = missing_sidset ^ extra_sidset
        if unexpected:
            print("dn token sids unexpected")
            print("tokengroups: %s" % dn_tokengroups)
            print("user sids: %s" % self.user_sids)
            print("actual difference: %s" % missing_sidset)
            print("expected difference: %s" % extra_sidset)
            print("unexpected difference : %s" % unexpected)
            self.fail(msg="DN tokenGroups unexpected difference to full user token")

    def test_pac_groups(self):
        # Run a GSSAPI exchange with the command-line credentials against a
        # local gensec server and compare its session token with user_sids.
        if creds.get_kerberos_state() != MUST_USE_KERBEROS:
            self.skipTest("Kerberos disabled, skipping PAC test")

        settings = {"lp_ctx": lp, "target_hostname": lp.get("netbios name")}

        gensec_client = gensec.Security.start_client(settings)
        gensec_client.set_credentials(creds)
        gensec_client.want_feature(gensec.FEATURE_SEAL)
        gensec_client.start_mech_by_sasl_name("GSSAPI")

        auth_context = AuthContext(lp_ctx=lp, ldb=self.ldb, methods=[])

        gensec_server = gensec.Security.start_server(settings, auth_context)
        machine_creds = Credentials()
        machine_creds.guess(lp)
        machine_creds.set_machine_account(lp)
        gensec_server.set_credentials(machine_creds)

        gensec_server.want_feature(gensec.FEATURE_SEAL)
        gensec_server.start_mech_by_sasl_name("GSSAPI")

        client_done = server_done = False
        server_to_client = b""

        # Exchange tokens until either side reports that it has finished.
        while not (client_done or server_done):
            if not client_done:
                print("running client gensec_update")
                client_done, client_to_server = gensec_client.update(server_to_client)
            if not server_done:
                print("running server gensec_update")
                server_done, server_to_client = gensec_server.update(client_to_server)

        session = gensec_server.session_info()
        pac_sids = [str(sid) for sid in session.security_token.sids]

        mismatch = set(pac_sids) ^ set(self.user_sids)
        if mismatch:
            print("token sids don't match")
            print("pac sids: %s" % pac_sids)
            print("user sids : %s" % self.user_sids)
            print("difference : %s" % mismatch)
            self.fail(msg="calculated groups don't match against user PAC tokenGroups")
217
218
class DynamicTokenTest(samba.tests.TestCase):
    """tokenGroups tests against a freshly created user and group tree.

    setUp() creates tokengroups_user1 and seven groups under CN=Users with
    these memberships (each arrow points from member to group):

        user   -> group0 (domain local)
        user   -> group1 (global)
        group1 -> group3 (universal)
        group3 -> group4 (universal)
        group4 -> group5 (domain local)
        user   -> group2 (universal)
        user   -> group6 (domain local)

    It then computes the user's expected token SIDs locally with
    samba.auth.user_session(). The tests compare these against what the
    server reports over LDAP, SAMR and (for Kerberos) the PAC.
    """

    def get_creds(self, target_username, target_password):
        """Build Credentials for another account.

        Domain, realm, Kerberos state and workstation are copied from the
        command-line credentials, and GENSEC sealing is requested.
        """
        creds_tmp = Credentials()
        creds_tmp.set_username(target_username)
        creds_tmp.set_password(target_password)
        creds_tmp.set_domain(creds.get_domain())
        creds_tmp.set_realm(creds.get_realm())
        creds_tmp.set_kerberos_state(creds.get_kerberos_state())
        creds_tmp.set_workstation(creds.get_workstation())
        creds_tmp.set_gensec_features(creds_tmp.get_gensec_features()
                                      | gensec.FEATURE_SEAL)
        return creds_tmp

    def get_ldb_connection(self, target_username, target_password):
        """Return a SamDB connection to ``url`` authenticated as the given user."""
        creds_tmp = self.get_creds(target_username, target_password)
        ldb_target = SamDB(url=url, credentials=creds_tmp, lp=lp)
        return ldb_target

    @staticmethod
    def _sid_strings(values):
        """Decode NDR-packed dom_sid attribute values into SID strings."""
        return [str(ndr_unpack(samba.dcerpc.security.dom_sid, value))
                for value in values]

    def _casefold_dn_of_sid(self, sid):
        """Return the casefolded DN of the object whose SID is ``sid``.

        The lookup is done over the admin connection.
        """
        res = self.admin_ldb.search(base="<SID=%s>" % sid, scope=ldb.SCOPE_BASE,
                                    attrs=[])
        return res[0].dn.get_casefold()

    def _create_group(self, name, grouptype, members):
        """Create group ``name`` under CN=Users, add ``members`` and return its SID.

        The returned value is the group's objectSid as a dom_sid.
        """
        self.admin_ldb.newgroup(name, grouptype=grouptype)
        res = self.admin_ldb.search(base="cn=%s,cn=users,%s" % (name, self.base_dn),
                                    attrs=["objectSid"], scope=ldb.SCOPE_BASE)
        group_sid = ndr_unpack(samba.dcerpc.security.dom_sid, res[0]["objectSid"][0])
        self.admin_ldb.add_remove_group_members(name, members,
                                                add_members_operation=True)
        return group_sid

    def _delete_test_objects(self):
        """Remove the test user and groups that setUp() creates."""
        # delete_force() is expected to tolerate objects that were never
        # created, e.g. when setUp() failed part-way. TODO confirm.
        for name in (self.test_user,
                     self.test_group0, self.test_group1, self.test_group2,
                     self.test_group3, self.test_group4, self.test_group5,
                     self.test_group6):
            delete_force(self.admin_ldb, "CN=%s,%s,%s" %
                         (name, "cn=users", self.base_dn))

    def setUp(self):
        """Create the test user and groups, then compute the expected SIDs."""
        super().setUp()

        # The expected SID list depends on the authentication mechanism,
        # so the Kerberos state must be pinned on the command line.
        self.assertNotEqual(creds.get_kerberos_state(), AUTO_USE_KERBEROS)

        self.admin_ldb = SamDB(url, credentials=creds, session_info=system_session(lp), lp=lp)

        self.base_dn = self.admin_ldb.domain_dn()

        self.test_user = "tokengroups_user1"
        self.test_user_pass = "samba123@"
        self.test_group0 = "tokengroups_group0"
        self.test_group1 = "tokengroups_group1"
        self.test_group2 = "tokengroups_group2"
        self.test_group3 = "tokengroups_group3"
        self.test_group4 = "tokengroups_group4"
        self.test_group5 = "tokengroups_group5"
        self.test_group6 = "tokengroups_group6"

        # Register cleanup before creating anything. unittest runs cleanups
        # even when setUp() itself fails, whereas tearDown() would be
        # skipped and the leftover objects would break the next run.
        self.addCleanup(self._delete_test_objects)

        self.admin_ldb.newuser(self.test_user, self.test_user_pass)

        self.test_group0_sid = self._create_group(
            self.test_group0, dsdb.GTYPE_SECURITY_DOMAIN_LOCAL_GROUP, [self.test_user])
        self.test_group1_sid = self._create_group(
            self.test_group1, dsdb.GTYPE_SECURITY_GLOBAL_GROUP, [self.test_user])
        self.test_group2_sid = self._create_group(
            self.test_group2, dsdb.GTYPE_SECURITY_UNIVERSAL_GROUP, [self.test_user])
        # group3 -> group4 -> group5 nest on top of group1.
        self.test_group3_sid = self._create_group(
            self.test_group3, dsdb.GTYPE_SECURITY_UNIVERSAL_GROUP, [self.test_group1])
        self.test_group4_sid = self._create_group(
            self.test_group4, dsdb.GTYPE_SECURITY_UNIVERSAL_GROUP, [self.test_group3])
        self.test_group5_sid = self._create_group(
            self.test_group5, dsdb.GTYPE_SECURITY_DOMAIN_LOCAL_GROUP, [self.test_group4])
        self.test_group6_sid = self._create_group(
            self.test_group6, dsdb.GTYPE_SECURITY_DOMAIN_LOCAL_GROUP, [self.test_user])

        self.ldb = self.get_ldb_connection(self.test_user, self.test_user_pass)

        res = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        # The first rootDSE tokenGroups value is used as the user's own SID.
        self.user_sid = ndr_unpack(samba.dcerpc.security.dom_sid, res[0]["tokenGroups"][0])
        self.user_sid_dn = "<SID=%s>" % str(self.user_sid)

        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=[])
        self.assertEqual(len(res), 1)

        self.test_user_dn = res[0].dn

        session_info_flags = (AUTH_SESSION_INFO_DEFAULT_GROUPS |
                              AUTH_SESSION_INFO_AUTHENTICATED |
                              AUTH_SESSION_INFO_SIMPLE_PRIVILEGES)

        if creds.get_kerberos_state() == DONT_USE_KERBEROS:
            session_info_flags |= AUTH_SESSION_INFO_NTLM

        session = samba.auth.user_session(self.ldb, lp_ctx=lp, dn=self.user_sid_dn,
                                          session_info_flags=session_info_flags)

        self.user_sids = [str(sid) for sid in session.security_token.sids]

        # A Kerberos logon is expected to carry these two extra SIDs, so
        # add them to the locally calculated list to make it comparable.
        if creds.get_kerberos_state() == MUST_USE_KERBEROS:
            self.user_sids.append(str(security.SID_AUTHENTICATION_AUTHORITY_ASSERTED_IDENTITY))
            self.user_sids.append(str(security.SID_CLAIMS_VALID))

    def test_rootDSE_tokenGroups(self):
        """Testing rootDSE tokengroups against internal calculation"""
        if not url.startswith("ldap"):
            self.fail(msg="This test is only valid on ldap")

        res = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        print("Getting tokenGroups from rootDSE")
        tokengroups = self._sid_strings(res[0]['tokenGroups'])

        sidset1 = set(tokengroups)
        sidset2 = set(self.user_sids)
        if sidset1.symmetric_difference(sidset2):
            print("token sids don't match")
            print("tokengroups: %s" % tokengroups)
            print("calculated : %s" % self.user_sids)
            print("difference : %s" % sidset1.symmetric_difference(sidset2))
            self.fail(msg="calculated groups don't match against rootDSE tokenGroups")

    def test_dn_tokenGroups(self):
        # tokenGroups read from the user's own DN must be a subset of the
        # full token, and the SIDs it lacks must be exactly the ones in
        # extra_sids below.
        print("Getting tokenGroups from user DN")
        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        dn_tokengroups = self._sid_strings(res[0]['tokenGroups'])

        sidset1 = set(dn_tokengroups)
        sidset2 = set(self.user_sids)

        # The tokenGroups is just a subset of the user_sids
        # so we don't check symmetric_difference() here.
        if sidset1.difference(sidset2):
            print("dn token sids no subset of user token")
            print("tokengroups: %s" % dn_tokengroups)
            print("user sids : %s" % self.user_sids)
            print("difference : %s" % sidset1.difference(sidset2))
            self.fail(msg="DN tokenGroups no subset of full user token")

        missing_sidset = sidset2.difference(sidset1)

        extra_sids = [self.user_sids[0],
                      security.SID_WORLD,
                      security.SID_NT_NETWORK,
                      security.SID_NT_AUTHENTICATED_USERS,
                      security.SID_BUILTIN_PREW2K]
        if creds.get_kerberos_state() == MUST_USE_KERBEROS:
            extra_sids.append(security.SID_AUTHENTICATION_AUTHORITY_ASSERTED_IDENTITY)
            extra_sids.append(security.SID_CLAIMS_VALID)
        if creds.get_kerberos_state() == DONT_USE_KERBEROS:
            extra_sids.append(security.SID_NT_NTLM_AUTHENTICATION)

        extra_sidset = set(extra_sids)

        if missing_sidset.symmetric_difference(extra_sidset):
            print("dn token sids unexpected")
            print("tokengroups: %s" % dn_tokengroups)
            print("user sids: %s" % self.user_sids)
            print("actual difference: %s" % missing_sidset)
            print("expected difference: %s" % extra_sidset)
            print("unexpected difference : %s" %
                  missing_sidset.symmetric_difference(extra_sidset))
            self.fail(msg="DN tokenGroups unexpected difference to full user token")

    def test_pac_groups(self):
        # Run a GSSAPI exchange as the test user against a local gensec
        # server and compare the resulting session token with user_sids.
        if creds.get_kerberos_state() != MUST_USE_KERBEROS:
            self.skipTest("Kerberos disabled, skipping PAC test")

        settings = {"lp_ctx": lp, "target_hostname": lp.get("netbios name")}

        gensec_client = gensec.Security.start_client(settings)
        gensec_client.set_credentials(self.get_creds(self.test_user, self.test_user_pass))
        gensec_client.want_feature(gensec.FEATURE_SEAL)
        gensec_client.start_mech_by_sasl_name("GSSAPI")

        auth_context = AuthContext(lp_ctx=lp, ldb=self.ldb, methods=[])

        gensec_server = gensec.Security.start_server(settings, auth_context)
        machine_creds = Credentials()
        machine_creds.guess(lp)
        machine_creds.set_machine_account(lp)
        gensec_server.set_credentials(machine_creds)

        gensec_server.want_feature(gensec.FEATURE_SEAL)
        gensec_server.start_mech_by_sasl_name("GSSAPI")

        client_finished = False
        server_finished = False
        server_to_client = b""

        # Run the actual call loop; it stops as soon as either side
        # reports that it has finished.
        while not client_finished and not server_finished:
            if not client_finished:
                print("running client gensec_update")
                (client_finished, client_to_server) = gensec_client.update(server_to_client)
            if not server_finished:
                print("running server gensec_update")
                (server_finished, server_to_client) = gensec_server.update(client_to_server)

        session = gensec_server.session_info()

        pac_sids = [str(sid) for sid in session.security_token.sids]

        sidset1 = set(pac_sids)
        sidset2 = set(self.user_sids)
        if sidset1.symmetric_difference(sidset2):
            print("token sids don't match")
            print("pac sids: %s" % pac_sids)
            print("user sids : %s" % self.user_sids)
            print("difference : %s" % sidset1.symmetric_difference(sidset2))
            self.fail(msg="calculated groups don't match against user PAC tokenGroups")

    def _membership_graph(self):
        """Return ``(vertices, edges)`` describing group membership in the domain.

        Edges are ``(member, group)`` pairs of casefolded DN strings. They
        come from every user's and group's memberOf values plus each user's
        primaryGroupID. Vertices are the endpoints of those edges.
        """
        vertices = set()
        edges = set()

        res = self.admin_ldb.search(base=self.base_dn, scope=ldb.SCOPE_SUBTREE,
                                    expression="(|(objectclass=user)(objectclass=group))",
                                    attrs=["memberOf"])
        for obj in res:
            if "memberOf" in obj:
                member = obj.dn.get_casefold()
                for dn in obj["memberOf"]:
                    group = ldb.Dn(self.admin_ldb, dn.decode('utf8')).get_casefold()
                    edges.add((member, group))
                    vertices.update((member, group))

        # The primary group is not listed in memberOf; add it via primaryGroupID.
        domain_sid = self.admin_ldb.get_domain_sid()
        res = self.admin_ldb.search(base=self.base_dn, scope=ldb.SCOPE_SUBTREE,
                                    expression="(objectclass=user)",
                                    attrs=["primaryGroupID"])
        for obj in res:
            if "primaryGroupID" in obj:
                sid = "%s-%d" % (domain_sid, int(obj["primaryGroupID"][0]))
                member = obj.dn.get_casefold()
                group = self._casefold_dn_of_sid(sid)
                edges.add((member, group))
                vertices.update((member, group))

        return vertices, edges

    def test_tokenGroups_manual(self):
        # Manually run the tokenGroups algorithm from MS-ADTS 3.1.1.4.5.19 and MS-DRSR 4.1.8.3
        # and compare the result
        vSet, aSet = self._membership_graph()

        user_dn = self.test_user_dn.get_casefold()
        wSet = {user_dn}
        closure(vSet, wSet, aSet)
        wSet.remove(user_dn)

        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        tokenGroupsSet = {
            self._casefold_dn_of_sid(ndr_unpack(samba.dcerpc.security.dom_sid, sid))
            for sid in res[0]['tokenGroups']
        }

        if wSet.difference(tokenGroupsSet):
            self.fail(msg="additional calculated: %s" % wSet.difference(tokenGroupsSet))

        if tokenGroupsSet.difference(wSet):
            self.fail(msg="additional tokenGroups: %s" % tokenGroupsSet.difference(wSet))

    def filtered_closure(self, wSet, filter_grouptype):
        """Extend ``wSet`` in place along membership edges, filtered by group type.

        The walk only steps onto non-group objects and onto groups whose
        groupType, masked to 32 bits, equals ``filter_grouptype``.
        """
        vertices, edges = self._membership_graph()

        allowed = set()
        for v in vertices:
            res_group = self.admin_ldb.search(base=v, scope=ldb.SCOPE_BASE,
                                              attrs=["groupType"],
                                              expression="objectClass=group")
            if len(res_group) == 1:
                # groupType may come back as a negative (signed 32-bit)
                # number; mask it so it compares equal to the constant.
                group_type = int(res_group[0]["groupType"][0]) & 0xFFFFFFFF
                if group_type == filter_grouptype:
                    allowed.add(v)
            else:
                allowed.add(v)

        closure(allowed, wSet, edges)

    def test_tokenGroupsGlobalAndUniversal_manual(self):
        # Manually run the tokenGroups algorithm from MS-ADTS 3.1.1.4.5.19 and MS-DRSR 4.1.8.3
        # and compare the result

        # The variable names come from MS-ADTS May 15, 2014

        user_dn = self.test_user_dn.get_casefold()
        S = {user_dn}

        self.filtered_closure(S, GTYPE_SECURITY_GLOBAL_GROUP)

        T = set()
        # Not really a SID, we do this on DNs...
        for sid in S:
            X = {sid}
            self.filtered_closure(X, GTYPE_SECURITY_UNIVERSAL_GROUP)
            T |= X

        T.remove(user_dn)

        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["tokenGroupsGlobalAndUniversal"])
        self.assertEqual(len(res), 1)

        tokenGroupsSet = {
            self._casefold_dn_of_sid(ndr_unpack(samba.dcerpc.security.dom_sid, sid))
            for sid in res[0]['tokenGroupsGlobalAndUniversal']
        }

        if T.difference(tokenGroupsSet):
            self.fail(msg="additional calculated: %s" % T.difference(tokenGroupsSet))

        if tokenGroupsSet.difference(T):
            self.fail(msg="additional tokenGroupsGlobalAndUniversal: %s" % tokenGroupsSet.difference(T))

    def test_samr_GetGroupsForUser(self):
        # Confirm that we get the correct results against SAMR also
        if not url.startswith("ldap://"):
            self.fail(msg="This test is only valid on ldap (so we can find the hostname and use SAMR)")
        # Import explicitly rather than relying on some other module having
        # loaded samba.dcerpc.samr as a side effect.
        from samba.dcerpc import samr

        host = url.split("://")[1]
        (domain_sid, user_rid) = self.user_sid.split()
        samr_conn = samr.samr("ncacn_ip_tcp:%s[seal]" % host, lp, creds)
        samr_handle = samr_conn.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
        samr_domain = samr_conn.OpenDomain(samr_handle, security.SEC_FLAG_MAXIMUM_ALLOWED,
                                           domain_sid)
        user_handle = samr_conn.OpenUser(samr_domain, security.SEC_FLAG_MAXIMUM_ALLOWED, user_rid)
        rids = samr_conn.GetGroupsForUser(user_handle)
        samr_dns = set()
        for rid in rids.rids:
            self.assertEqual(rid.attributes, security.SE_GROUP_DEFAULT_FLAGS)
            samr_dns.add(self._casefold_dn_of_sid("%s-%d" % (domain_sid, rid.rid)))

        user_info = samr_conn.QueryUserInfo(user_handle, 1)
        # The primary group is expected to be the first entry SAMR returns.
        self.assertEqual(rids.rids[0].rid, user_info.primary_gid)

        group_filter = ("(&(|(grouptype=%d)(grouptype=%d))(objectclass=group))"
                        % (GTYPE_SECURITY_GLOBAL_GROUP, GTYPE_SECURITY_UNIVERSAL_GROUP))

        tokenGroupsSet = set()
        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["tokenGroupsGlobalAndUniversal"])
        for sid in res[0]['tokenGroupsGlobalAndUniversal']:
            sid = ndr_unpack(samba.dcerpc.security.dom_sid, sid)
            res3 = self.admin_ldb.search(base="<SID=%s>" % sid, scope=ldb.SCOPE_BASE,
                                         attrs=[], expression=group_filter)
            # Count only SIDs that resolve to a global or universal group;
            # res3 is empty for anything else.
            if len(res3) == 1:
                tokenGroupsSet.add(res3[0].dn.get_casefold())

        if samr_dns.difference(tokenGroupsSet):
            self.fail(msg="additional samr_GetUserGroups over tokenGroups: %s" % samr_dns.difference(tokenGroupsSet))

        memberOf = set()
        # Add the primary group
        primary_group_sid = "%s-%d" % (domain_sid, user_info.primary_gid)
        memberOf.add(self._casefold_dn_of_sid(primary_group_sid))

        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["memberOf"])
        for dn in res[0]['memberOf']:
            res3 = self.admin_ldb.search(base=dn, scope=ldb.SCOPE_BASE,
                                         attrs=[], expression=group_filter)
            if len(res3) == 1:
                memberOf.add(res3[0].dn.get_casefold())

        if memberOf.difference(samr_dns):
            self.fail(msg="additional memberOf over samr_GetUserGroups: %s" % memberOf.difference(samr_dns))

        if samr_dns.difference(memberOf):
            self.fail(msg="additional samr_GetUserGroups over memberOf: %s" % samr_dns.difference(memberOf))

        user_dn = self.test_user_dn.get_casefold()
        S = {user_dn}

        self.filtered_closure(S, GTYPE_SECURITY_GLOBAL_GROUP)
        self.filtered_closure(S, GTYPE_SECURITY_UNIVERSAL_GROUP)

        # Remove only the user DN. The primary group stays in S because
        # samr_dns includes it too.
        S.remove(user_dn)

        if samr_dns.difference(S):
            self.fail(msg="additional samr_GetUserGroups over filtered_closure: %s" % samr_dns.difference(S))

    def test_samr_GetGroupsForUser_nomember(self):
        # Confirm that we get the correct results against SAMR also
        if not url.startswith("ldap://"):
            self.fail(msg="This test is only valid on ldap (so we can find the hostname and use SAMR)")
        # Import explicitly rather than relying on some other module having
        # loaded samba.dcerpc.samr as a side effect.
        from samba.dcerpc import samr

        host = url.split("://")[1]

        test_user = "tokengroups_user2"
        test_user_dn = "CN=%s,%s,%s" % (test_user, "cn=users", self.base_dn)
        self.admin_ldb.newuser(test_user, self.test_user_pass)
        try:
            res = self.admin_ldb.search(base="cn=%s,cn=users,%s" % (test_user, self.base_dn),
                                        attrs=["objectSid"], scope=ldb.SCOPE_BASE)
            user_sid = ndr_unpack(samba.dcerpc.security.dom_sid, res[0]["objectSid"][0])

            (domain_sid, user_rid) = user_sid.split()
            samr_conn = samr.samr("ncacn_ip_tcp:%s[seal]" % host, lp, creds)
            samr_handle = samr_conn.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
            samr_domain = samr_conn.OpenDomain(samr_handle, security.SEC_FLAG_MAXIMUM_ALLOWED,
                                               domain_sid)
            user_handle = samr_conn.OpenUser(samr_domain, security.SEC_FLAG_MAXIMUM_ALLOWED, user_rid)
            rids = samr_conn.GetGroupsForUser(user_handle)
            user_info = samr_conn.QueryUserInfo(user_handle, 1)
        finally:
            # Remove the extra user even if one of the lookups above fails.
            delete_force(self.admin_ldb, test_user_dn)

        # A user with no explicit memberships belongs only to its primary group.
        self.assertEqual(len(rids.rids), 1)
        self.assertEqual(rids.rids[0].rid, user_info.primary_gid)
730
731
# Normalise a bare host/path argument into a URL. An existing local file
# is opened as a tdb database; anything else is treated as an LDAP host.
# This runs before TestProgram starts the tests, so they all see the final url.
if "://" not in url:
    url = ("tdb://%s" if os.path.isfile(url) else "ldap://%s") % url

TestProgram(module=__name__, opts=subunitopts)