PEP8: fix E201: whitespace after '('
[samba.git] / source4 / dsdb / tests / python / token_group.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 # test tokengroups attribute against internal token calculation
4
5 from __future__ import print_function
6 import optparse
7 import sys
8 import os
9
10 sys.path.insert(0, "bin/python")
11 import samba
12
13 from samba.tests.subunitrun import SubunitOptions, TestProgram
14
15 import samba.getopt as options
16
17 from samba.auth import system_session
18 from samba import ldb, dsdb
19 from samba.samdb import SamDB
20 from samba.auth import AuthContext
21 from samba.ndr import ndr_unpack
22 from samba import gensec
23 from samba.credentials import Credentials, DONT_USE_KERBEROS
24 from samba.dsdb import GTYPE_SECURITY_GLOBAL_GROUP, GTYPE_SECURITY_UNIVERSAL_GROUP
25 import samba.tests
26 from samba.tests import delete_force
27 from samba.dcerpc import samr, security
28 from samba.auth import AUTH_SESSION_INFO_DEFAULT_GROUPS, AUTH_SESSION_INFO_AUTHENTICATED, AUTH_SESSION_INFO_SIMPLE_PRIVILEGES, AUTH_SESSION_INFO_NTLM
29
30
# Command-line handling: Samba, version, credential and subunit option
# groups are registered on one parser, and a single positional <host>
# argument is required.
parser = optparse.OptionParser("token_group.py [options] <host>")
sambaopts = options.SambaOptions(parser)
parser.add_option_group(sambaopts)
parser.add_option_group(options.VersionOptions(parser))
# use command line creds if available
credopts = options.CredentialsOptions(parser)
parser.add_option_group(credopts)
subunitopts = SubunitOptions(parser)
parser.add_option_group(subunitopts)
opts, args = parser.parse_args()

# Without the positional argument there is no target: print usage and exit.
if len(args) < 1:
    parser.print_usage()
    sys.exit(1)

# Target of the tests; a scheme is prepended at the bottom of the file when
# the argument does not already contain "://".
url = args[0]

# Module-level load parameters and credentials shared by every test case.
lp = sambaopts.get_loadparm()
creds = credopts.get_credentials(lp)
# Add FEATURE_SEAL on top of whatever gensec features are already set.
creds.set_gensec_features(creds.get_gensec_features() | gensec.FEATURE_SEAL)
51
def closure(vSet, wSet, aSet):
    """Grow ``wSet`` in place to its closure under the edges in ``aSet``.

    An edge ``(start, end)`` is followed when ``start`` is already in
    ``wSet`` and ``end`` is a member of ``vSet``; ``end`` is then added to
    ``wSet`` and expanded in turn.

    The traversal is iterative (a worklist), so a long membership chain
    cannot exhaust the interpreter's recursion limit, and each edge is
    examined at most once instead of the edge set being rescanned after
    every addition.

    :param vSet: set of vertices that are allowed to be added to ``wSet``
    :param wSet: set of starting vertices; modified in place
    :param aSet: iterable of ``(start, end)`` pairs
    :return: None
    """
    # Index the edges by their start vertex so that expanding a vertex is a
    # dictionary lookup rather than a scan of every edge.
    successors = {}
    for start, end in aSet:
        successors.setdefault(start, []).append(end)

    pending = list(wSet)
    while pending:
        current = pending.pop()
        for end in successors.get(current, ()):
            if end not in wSet and end in vSet:
                wSet.add(end)
                pending.append(end)
59
60 class StaticTokenTest(samba.tests.TestCase):
61
62     def setUp(self):
63         super(StaticTokenTest, self).setUp()
64         self.ldb = SamDB(url, credentials=creds, session_info=system_session(lp), lp=lp)
65         self.base_dn = self.ldb.domain_dn()
66
67         res = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
68         self.assertEquals(len(res), 1)
69
70         self.user_sid_dn = "<SID=%s>" % str(ndr_unpack(samba.dcerpc.security.dom_sid, res[0]["tokenGroups"][0]))
71
72         session_info_flags = (AUTH_SESSION_INFO_DEFAULT_GROUPS |
73                                AUTH_SESSION_INFO_AUTHENTICATED |
74                                AUTH_SESSION_INFO_SIMPLE_PRIVILEGES)
75         if creds.get_kerberos_state() == DONT_USE_KERBEROS:
76             session_info_flags |= AUTH_SESSION_INFO_NTLM
77
78         session = samba.auth.user_session(self.ldb, lp_ctx=lp, dn=self.user_sid_dn,
79                                           session_info_flags=session_info_flags)
80
81         token = session.security_token
82         self.user_sids = []
83         for s in token.sids:
84             self.user_sids.append(str(s))
85
86     def test_rootDSE_tokenGroups(self):
87         """Testing rootDSE tokengroups against internal calculation"""
88         if not url.startswith("ldap"):
89             self.fail(msg="This test is only valid on ldap")
90
91         res = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
92         self.assertEquals(len(res), 1)
93
94         print("Getting tokenGroups from rootDSE")
95         tokengroups = []
96         for sid in res[0]['tokenGroups']:
97             tokengroups.append(str(ndr_unpack(samba.dcerpc.security.dom_sid, sid)))
98
99         sidset1 = set(tokengroups)
100         sidset2 = set(self.user_sids)
101         if len(sidset1.difference(sidset2)):
102             print("token sids don't match")
103             print("tokengroups: %s" % tokengroups)
104             print("calculated : %s" % self.user_sids)
105             print("difference : %s" % sidset1.difference(sidset2))
106             self.fail(msg="calculated groups don't match against rootDSE tokenGroups")
107
108     def test_dn_tokenGroups(self):
109         print("Getting tokenGroups from user DN")
110         res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
111         self.assertEquals(len(res), 1)
112
113         dn_tokengroups = []
114         for sid in res[0]['tokenGroups']:
115             dn_tokengroups.append(str(ndr_unpack(samba.dcerpc.security.dom_sid, sid)))
116
117         sidset1 = set(dn_tokengroups)
118         sidset2 = set(self.user_sids)
119         if len(sidset1.difference(sidset2)):
120             print("token sids don't match")
121             print("difference : %s" % sidset1.difference(sidset2))
122             self.fail(msg="calculated groups don't match against user DN tokenGroups")
123
124     def test_pac_groups(self):
125         if creds.get_kerberos_state() == DONT_USE_KERBEROS:
126             self.skipTest("Kerberos disabled, skipping PAC test")
127
128         settings = {}
129         settings["lp_ctx"] = lp
130         settings["target_hostname"] = lp.get("netbios name")
131
132         gensec_client = gensec.Security.start_client(settings)
133         gensec_client.set_credentials(creds)
134         gensec_client.want_feature(gensec.FEATURE_SEAL)
135         gensec_client.start_mech_by_sasl_name("GSSAPI")
136
137         auth_context = AuthContext(lp_ctx=lp, ldb=self.ldb, methods=[])
138
139         gensec_server = gensec.Security.start_server(settings, auth_context)
140         machine_creds = Credentials()
141         machine_creds.guess(lp)
142         machine_creds.set_machine_account(lp)
143         gensec_server.set_credentials(machine_creds)
144
145         gensec_server.want_feature(gensec.FEATURE_SEAL)
146         gensec_server.start_mech_by_sasl_name("GSSAPI")
147
148         client_finished = False
149         server_finished = False
150         server_to_client = ""
151
152         # Run the actual call loop.
153         while client_finished == False and server_finished == False:
154             if not client_finished:
155                 print("running client gensec_update")
156                 (client_finished, client_to_server) = gensec_client.update(server_to_client)
157             if not server_finished:
158                 print("running server gensec_update")
159                 (server_finished, server_to_client) = gensec_server.update(client_to_server)
160
161         session = gensec_server.session_info()
162
163         token = session.security_token
164         pac_sids = []
165         for s in token.sids:
166             pac_sids.append(str(s))
167
168         sidset1 = set(pac_sids)
169         sidset2 = set(self.user_sids)
170         if len(sidset1.difference(sidset2)):
171             print("token sids don't match")
172             print("difference : %s" % sidset1.difference(sidset2))
173             self.fail(msg="calculated groups don't match against user PAC tokenGroups")
174
175 class DynamicTokenTest(samba.tests.TestCase):
176
177     def get_creds(self, target_username, target_password):
178         creds_tmp = Credentials()
179         creds_tmp.set_username(target_username)
180         creds_tmp.set_password(target_password)
181         creds_tmp.set_domain(creds.get_domain())
182         creds_tmp.set_realm(creds.get_realm())
183         creds_tmp.set_workstation(creds.get_workstation())
184         creds_tmp.set_gensec_features(creds_tmp.get_gensec_features()
185                                       | gensec.FEATURE_SEAL)
186         return creds_tmp
187
188     def get_ldb_connection(self, target_username, target_password):
189         creds_tmp = self.get_creds(target_username, target_password)
190         ldb_target = SamDB(url=url, credentials=creds_tmp, lp=lp)
191         return ldb_target
192
193     def setUp(self):
194         super(DynamicTokenTest, self).setUp()
195         self.admin_ldb = SamDB(url, credentials=creds, session_info=system_session(lp), lp=lp)
196
197         self.base_dn = self.admin_ldb.domain_dn()
198
199         self.test_user = "tokengroups_user1"
200         self.test_user_pass = "samba123@"
201         self.admin_ldb.newuser(self.test_user, self.test_user_pass)
202         self.test_group0 = "tokengroups_group0"
203         self.admin_ldb.newgroup(self.test_group0, grouptype=dsdb.GTYPE_SECURITY_DOMAIN_LOCAL_GROUP)
204         res = self.admin_ldb.search(base="cn=%s,cn=users,%s" % (self.test_group0, self.base_dn),
205                                     attrs=["objectSid"], scope=ldb.SCOPE_BASE)
206         self.test_group0_sid = ndr_unpack(samba.dcerpc.security.dom_sid, res[0]["objectSid"][0])
207
208         self.admin_ldb.add_remove_group_members(self.test_group0, [self.test_user],
209                                                 add_members_operation=True)
210
211         self.test_group1 = "tokengroups_group1"
212         self.admin_ldb.newgroup(self.test_group1, grouptype=dsdb.GTYPE_SECURITY_GLOBAL_GROUP)
213         res = self.admin_ldb.search(base="cn=%s,cn=users,%s" % (self.test_group1, self.base_dn),
214                                     attrs=["objectSid"], scope=ldb.SCOPE_BASE)
215         self.test_group1_sid = ndr_unpack(samba.dcerpc.security.dom_sid, res[0]["objectSid"][0])
216
217         self.admin_ldb.add_remove_group_members(self.test_group1, [self.test_user],
218                                                 add_members_operation=True)
219
220         self.test_group2 = "tokengroups_group2"
221         self.admin_ldb.newgroup(self.test_group2, grouptype=dsdb.GTYPE_SECURITY_UNIVERSAL_GROUP)
222
223         res = self.admin_ldb.search(base="cn=%s,cn=users,%s" % (self.test_group2, self.base_dn),
224                                     attrs=["objectSid"], scope=ldb.SCOPE_BASE)
225         self.test_group2_sid = ndr_unpack(samba.dcerpc.security.dom_sid, res[0]["objectSid"][0])
226
227         self.admin_ldb.add_remove_group_members(self.test_group2, [self.test_user],
228                                                 add_members_operation=True)
229
230         self.test_group3 = "tokengroups_group3"
231         self.admin_ldb.newgroup(self.test_group3, grouptype=dsdb.GTYPE_SECURITY_UNIVERSAL_GROUP)
232
233         res = self.admin_ldb.search(base="cn=%s,cn=users,%s" % (self.test_group3, self.base_dn),
234                                     attrs=["objectSid"], scope=ldb.SCOPE_BASE)
235         self.test_group3_sid = ndr_unpack(samba.dcerpc.security.dom_sid, res[0]["objectSid"][0])
236
237         self.admin_ldb.add_remove_group_members(self.test_group3, [self.test_group1],
238                                                 add_members_operation=True)
239
240         self.test_group4 = "tokengroups_group4"
241         self.admin_ldb.newgroup(self.test_group4, grouptype=dsdb.GTYPE_SECURITY_UNIVERSAL_GROUP)
242
243         res = self.admin_ldb.search(base="cn=%s,cn=users,%s" % (self.test_group4, self.base_dn),
244                                     attrs=["objectSid"], scope=ldb.SCOPE_BASE)
245         self.test_group4_sid = ndr_unpack(samba.dcerpc.security.dom_sid, res[0]["objectSid"][0])
246
247         self.admin_ldb.add_remove_group_members(self.test_group4, [self.test_group3],
248                                                 add_members_operation=True)
249
250         self.test_group5 = "tokengroups_group5"
251         self.admin_ldb.newgroup(self.test_group5, grouptype=dsdb.GTYPE_SECURITY_DOMAIN_LOCAL_GROUP)
252
253         res = self.admin_ldb.search(base="cn=%s,cn=users,%s" % (self.test_group5, self.base_dn),
254                                     attrs=["objectSid"], scope=ldb.SCOPE_BASE)
255         self.test_group5_sid = ndr_unpack(samba.dcerpc.security.dom_sid, res[0]["objectSid"][0])
256
257         self.admin_ldb.add_remove_group_members(self.test_group5, [self.test_group4],
258                                                 add_members_operation=True)
259
260         self.test_group6 = "tokengroups_group6"
261         self.admin_ldb.newgroup(self.test_group6, grouptype=dsdb.GTYPE_SECURITY_DOMAIN_LOCAL_GROUP)
262
263         res = self.admin_ldb.search(base="cn=%s,cn=users,%s" % (self.test_group6, self.base_dn),
264                                     attrs=["objectSid"], scope=ldb.SCOPE_BASE)
265         self.test_group6_sid = ndr_unpack(samba.dcerpc.security.dom_sid, res[0]["objectSid"][0])
266
267         self.admin_ldb.add_remove_group_members(self.test_group6, [self.test_user],
268                                                 add_members_operation=True)
269
270         self.ldb = self.get_ldb_connection(self.test_user, self.test_user_pass)
271
272         res = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
273         self.assertEquals(len(res), 1)
274
275         self.user_sid = ndr_unpack(samba.dcerpc.security.dom_sid, res[0]["tokenGroups"][0])
276         self.user_sid_dn = "<SID=%s>" % str(self.user_sid)
277
278         res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=[])
279         self.assertEquals(len(res), 1)
280
281         self.test_user_dn = res[0].dn
282
283         session_info_flags = (AUTH_SESSION_INFO_DEFAULT_GROUPS |
284                                AUTH_SESSION_INFO_AUTHENTICATED |
285                                AUTH_SESSION_INFO_SIMPLE_PRIVILEGES)
286
287         if creds.get_kerberos_state() == DONT_USE_KERBEROS:
288             session_info_flags |= AUTH_SESSION_INFO_NTLM
289
290         session = samba.auth.user_session(self.ldb, lp_ctx=lp, dn=self.user_sid_dn,
291                                           session_info_flags=session_info_flags)
292
293         token = session.security_token
294         self.user_sids = []
295         for s in token.sids:
296             self.user_sids.append(str(s))
297
298     def tearDown(self):
299         super(DynamicTokenTest, self).tearDown()
300         delete_force(self.admin_ldb, "CN=%s,%s,%s" %
301                      (self.test_user, "cn=users", self.base_dn))
302         delete_force(self.admin_ldb, "CN=%s,%s,%s" %
303                      (self.test_group0, "cn=users", self.base_dn))
304         delete_force(self.admin_ldb, "CN=%s,%s,%s" %
305                      (self.test_group1, "cn=users", self.base_dn))
306         delete_force(self.admin_ldb, "CN=%s,%s,%s" %
307                      (self.test_group2, "cn=users", self.base_dn))
308         delete_force(self.admin_ldb, "CN=%s,%s,%s" %
309                      (self.test_group3, "cn=users", self.base_dn))
310         delete_force(self.admin_ldb, "CN=%s,%s,%s" %
311                      (self.test_group4, "cn=users", self.base_dn))
312         delete_force(self.admin_ldb, "CN=%s,%s,%s" %
313                      (self.test_group5, "cn=users", self.base_dn))
314         delete_force(self.admin_ldb, "CN=%s,%s,%s" %
315                      (self.test_group6, "cn=users", self.base_dn))
316
317     def test_rootDSE_tokenGroups(self):
318         """Testing rootDSE tokengroups against internal calculation"""
319         if not url.startswith("ldap"):
320             self.fail(msg="This test is only valid on ldap")
321
322         res = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
323         self.assertEquals(len(res), 1)
324
325         print("Getting tokenGroups from rootDSE")
326         tokengroups = []
327         for sid in res[0]['tokenGroups']:
328             tokengroups.append(str(ndr_unpack(samba.dcerpc.security.dom_sid, sid)))
329
330         sidset1 = set(tokengroups)
331         sidset2 = set(self.user_sids)
332         if len(sidset1.difference(sidset2)):
333             print("token sids don't match")
334             print("tokengroups: %s" % tokengroups)
335             print("calculated : %s" % self.user_sids)
336             print("difference : %s" % sidset1.difference(sidset2))
337             self.fail(msg="calculated groups don't match against rootDSE tokenGroups")
338
339     def test_dn_tokenGroups(self):
340         print("Getting tokenGroups from user DN")
341         res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
342         self.assertEquals(len(res), 1)
343
344         dn_tokengroups = []
345         for sid in res[0]['tokenGroups']:
346             dn_tokengroups.append(str(ndr_unpack(samba.dcerpc.security.dom_sid, sid)))
347
348         sidset1 = set(dn_tokengroups)
349         sidset2 = set(self.user_sids)
350
351         # The SIDs on the DN do not include the NTLM authentication SID
352         sidset2.discard(samba.dcerpc.security.SID_NT_NTLM_AUTHENTICATION)
353
354         if len(sidset1.difference(sidset2)):
355             print("token sids don't match")
356             print("difference : %s" % sidset1.difference(sidset2))
357             self.fail(msg="calculated groups don't match against user DN tokenGroups")
358
359     def test_pac_groups(self):
360         settings = {}
361         settings["lp_ctx"] = lp
362         settings["target_hostname"] = lp.get("netbios name")
363
364         gensec_client = gensec.Security.start_client(settings)
365         gensec_client.set_credentials(self.get_creds(self.test_user, self.test_user_pass))
366         gensec_client.want_feature(gensec.FEATURE_SEAL)
367         gensec_client.start_mech_by_sasl_name("GSSAPI")
368
369         auth_context = AuthContext(lp_ctx=lp, ldb=self.ldb, methods=[])
370
371         gensec_server = gensec.Security.start_server(settings, auth_context)
372         machine_creds = Credentials()
373         machine_creds.guess(lp)
374         machine_creds.set_machine_account(lp)
375         gensec_server.set_credentials(machine_creds)
376
377         gensec_server.want_feature(gensec.FEATURE_SEAL)
378         gensec_server.start_mech_by_sasl_name("GSSAPI")
379
380         client_finished = False
381         server_finished = False
382         server_to_client = ""
383
384         # Run the actual call loop.
385         while client_finished == False and server_finished == False:
386             if not client_finished:
387                 print("running client gensec_update")
388                 (client_finished, client_to_server) = gensec_client.update(server_to_client)
389             if not server_finished:
390                 print("running server gensec_update")
391                 (server_finished, server_to_client) = gensec_server.update(client_to_server)
392
393         session = gensec_server.session_info()
394
395         token = session.security_token
396         pac_sids = []
397         for s in token.sids:
398             pac_sids.append(str(s))
399
400         sidset1 = set(pac_sids)
401         sidset2 = set(self.user_sids)
402         if len(sidset1.difference(sidset2)):
403             print("token sids don't match")
404             print("difference : %s" % sidset1.difference(sidset2))
405             self.fail(msg="calculated groups don't match against user PAC tokenGroups")
406
407
408     def test_tokenGroups_manual(self):
409         # Manually run the tokenGroups algorithm from MS-ADTS 3.1.1.4.5.19 and MS-DRSR 4.1.8.3
410         # and compare the result
411         res = self.admin_ldb.search(base=self.base_dn, scope=ldb.SCOPE_SUBTREE,
412                                     expression="(|(objectclass=user)(objectclass=group))",
413                                     attrs=["memberOf"])
414         aSet = set()
415         aSetR = set()
416         vSet = set()
417         for obj in res:
418             if "memberOf" in obj:
419                 for dn in obj["memberOf"]:
420                     first = obj.dn.get_casefold()
421                     second = ldb.Dn(self.admin_ldb, dn.decode('utf8')).get_casefold()
422                     aSet.add((first, second))
423                     aSetR.add((second, first))
424                     vSet.add(first)
425                     vSet.add(second)
426
427         res = self.admin_ldb.search(base=self.base_dn, scope=ldb.SCOPE_SUBTREE,
428                                     expression="(objectclass=user)",
429                                     attrs=["primaryGroupID"])
430         for obj in res:
431             if "primaryGroupID" in obj:
432                 sid = "%s-%d" % (self.admin_ldb.get_domain_sid(), int(obj["primaryGroupID"][0]))
433                 res2 = self.admin_ldb.search(base="<SID=%s>" % sid, scope=ldb.SCOPE_BASE,
434                                              attrs=[])
435                 first = obj.dn.get_casefold()
436                 second = res2[0].dn.get_casefold()
437
438                 aSet.add((first, second))
439                 aSetR.add((second, first))
440                 vSet.add(first)
441                 vSet.add(second)
442
443         wSet = set()
444         wSet.add(self.test_user_dn.get_casefold())
445         closure(vSet, wSet, aSet)
446         wSet.remove(self.test_user_dn.get_casefold())
447
448         tokenGroupsSet = set()
449
450         res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
451         self.assertEquals(len(res), 1)
452
453         dn_tokengroups = []
454         for sid in res[0]['tokenGroups']:
455             sid = ndr_unpack(samba.dcerpc.security.dom_sid, sid)
456             res3 = self.admin_ldb.search(base="<SID=%s>" % sid, scope=ldb.SCOPE_BASE,
457                                          attrs=[])
458             tokenGroupsSet.add(res3[0].dn.get_casefold())
459
460         if len(wSet.difference(tokenGroupsSet)):
461             self.fail(msg="additional calculated: %s" % wSet.difference(tokenGroupsSet))
462
463         if len(tokenGroupsSet.difference(wSet)):
464             self.fail(msg="additional tokenGroups: %s" % tokenGroupsSet.difference(wSet))
465
466
467     def filtered_closure(self, wSet, filter_grouptype):
468         res = self.admin_ldb.search(base=self.base_dn, scope=ldb.SCOPE_SUBTREE,
469                                     expression="(|(objectclass=user)(objectclass=group))",
470                                     attrs=["memberOf"])
471         aSet = set()
472         aSetR = set()
473         vSet = set()
474         for obj in res:
475             vSet.add(obj.dn.get_casefold())
476             if "memberOf" in obj:
477                 for dn in obj["memberOf"]:
478                     first = obj.dn.get_casefold()
479                     second = ldb.Dn(self.admin_ldb, dn.decode('utf8')).get_casefold()
480                     aSet.add((first, second))
481                     aSetR.add((second, first))
482                     vSet.add(first)
483                     vSet.add(second)
484
485         res = self.admin_ldb.search(base=self.base_dn, scope=ldb.SCOPE_SUBTREE,
486                                     expression="(objectclass=user)",
487                                     attrs=["primaryGroupID"])
488         for obj in res:
489             if "primaryGroupID" in obj:
490                 sid = "%s-%d" % (self.admin_ldb.get_domain_sid(), int(obj["primaryGroupID"][0]))
491                 res2 = self.admin_ldb.search(base="<SID=%s>" % sid, scope=ldb.SCOPE_BASE,
492                                              attrs=[])
493                 first = obj.dn.get_casefold()
494                 second = res2[0].dn.get_casefold()
495
496                 aSet.add((first, second))
497                 aSetR.add((second, first))
498                 vSet.add(first)
499                 vSet.add(second)
500
501         uSet = set()
502         for v in vSet:
503             res_group = self.admin_ldb.search(base=v, scope=ldb.SCOPE_BASE,
504                                               attrs=["groupType"],
505                                               expression="objectClass=group")
506             if len(res_group) == 1:
507                 if hex(int(res_group[0]["groupType"][0]) & 0x00000000FFFFFFFF) == hex(filter_grouptype):
508                     uSet.add(v)
509             else:
510                 uSet.add(v)
511
512         closure(uSet, wSet, aSet)
513
514
515     def test_tokenGroupsGlobalAndUniversal_manual(self):
516         # Manually run the tokenGroups algorithm from MS-ADTS 3.1.1.4.5.19 and MS-DRSR 4.1.8.3
517         # and compare the result
518
519         # The variable names come from MS-ADTS May 15, 2014
520
521         S = set()
522         S.add(self.test_user_dn.get_casefold())
523
524         self.filtered_closure(S, GTYPE_SECURITY_GLOBAL_GROUP)
525
526         T = set()
527         # Not really a SID, we do this on DNs...
528         for sid in S:
529             X = set()
530             X.add(sid)
531             self.filtered_closure(X, GTYPE_SECURITY_UNIVERSAL_GROUP)
532
533             T = T.union(X)
534
535         T.remove(self.test_user_dn.get_casefold())
536
537         tokenGroupsSet = set()
538
539         res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["tokenGroupsGlobalAndUniversal"])
540         self.assertEquals(len(res), 1)
541
542         dn_tokengroups = []
543         for sid in res[0]['tokenGroupsGlobalAndUniversal']:
544             sid = ndr_unpack(samba.dcerpc.security.dom_sid, sid)
545             res3 = self.admin_ldb.search(base="<SID=%s>" % sid, scope=ldb.SCOPE_BASE,
546                                          attrs=[])
547             tokenGroupsSet.add(res3[0].dn.get_casefold())
548
549         if len(T.difference(tokenGroupsSet)):
550             self.fail(msg="additional calculated: %s" % T.difference(tokenGroupsSet))
551
552         if len(tokenGroupsSet.difference(T)):
553             self.fail(msg="additional tokenGroupsGlobalAndUniversal: %s" % tokenGroupsSet.difference(T))
554
555     def test_samr_GetGroupsForUser(self):
556         # Confirm that we get the correct results against SAMR also
557         if not url.startswith("ldap://"):
558             self.fail(msg="This test is only valid on ldap (so we an find the hostname and use SAMR)")
559         host = url.split("://")[1]
560         (domain_sid, user_rid) = self.user_sid.split()
561         samr_conn = samba.dcerpc.samr.samr("ncacn_ip_tcp:%s[seal]" % host, lp, creds)
562         samr_handle = samr_conn.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
563         samr_domain = samr_conn.OpenDomain(samr_handle, security.SEC_FLAG_MAXIMUM_ALLOWED,
564                                            domain_sid)
565         user_handle = samr_conn.OpenUser(samr_domain, security.SEC_FLAG_MAXIMUM_ALLOWED, user_rid)
566         rids = samr_conn.GetGroupsForUser(user_handle)
567         samr_dns = set()
568         for rid in rids.rids:
569             self.assertEqual(rid.attributes, samr.SE_GROUP_MANDATORY | samr.SE_GROUP_ENABLED_BY_DEFAULT| samr.SE_GROUP_ENABLED)
570             sid = "%s-%d" % (domain_sid, rid.rid)
571             res = self.admin_ldb.search(base="<SID=%s>" % sid, scope=ldb.SCOPE_BASE,
572                                         attrs=[])
573             samr_dns.add(res[0].dn.get_casefold())
574
575         user_info = samr_conn.QueryUserInfo(user_handle, 1)
576         self.assertEqual(rids.rids[0].rid, user_info.primary_gid)
577
578         tokenGroupsSet = set()
579         res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["tokenGroupsGlobalAndUniversal"])
580         for sid in res[0]['tokenGroupsGlobalAndUniversal']:
581             sid = ndr_unpack(samba.dcerpc.security.dom_sid, sid)
582             res3 = self.admin_ldb.search(base="<SID=%s>" % sid, scope=ldb.SCOPE_BASE,
583                                          attrs=[],
584                                          expression="(&(|(grouptype=%d)(grouptype=%d))(objectclass=group))"
585                                          % (GTYPE_SECURITY_GLOBAL_GROUP, GTYPE_SECURITY_UNIVERSAL_GROUP))
586             if len(res) == 1:
587                 tokenGroupsSet.add(res3[0].dn.get_casefold())
588
589         if len(samr_dns.difference(tokenGroupsSet)):
590             self.fail(msg="additional samr_GetUserGroups over tokenGroups: %s" % samr_dns.difference(tokenGroupsSet))
591
592         memberOf = set()
593         # Add the primary group
594         primary_group_sid = "%s-%d" % (domain_sid, user_info.primary_gid)
595         res2 = self.admin_ldb.search(base="<SID=%s>" % sid, scope=ldb.SCOPE_BASE,
596                                      attrs=[])
597
598         memberOf.add(res2[0].dn.get_casefold())
599         res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["memberOf"])
600         for dn in res[0]['memberOf']:
601             res3 = self.admin_ldb.search(base=dn, scope=ldb.SCOPE_BASE,
602                                          attrs=[],
603                                          expression="(&(|(grouptype=%d)(grouptype=%d))(objectclass=group))"
604                                          % (GTYPE_SECURITY_GLOBAL_GROUP, GTYPE_SECURITY_UNIVERSAL_GROUP))
605             if len(res3) == 1:
606                 memberOf.add(res3[0].dn.get_casefold())
607
608         if len(memberOf.difference(samr_dns)):
609             self.fail(msg="additional memberOf over samr_GetUserGroups: %s" % memberOf.difference(samr_dns))
610
611         if len(samr_dns.difference(memberOf)):
612             self.fail(msg="additional samr_GetUserGroups over memberOf: %s" % samr_dns.difference(memberOf))
613
614         S = set()
615         S.add(self.test_user_dn.get_casefold())
616
617         self.filtered_closure(S, GTYPE_SECURITY_GLOBAL_GROUP)
618         self.filtered_closure(S, GTYPE_SECURITY_UNIVERSAL_GROUP)
619
620         # Now remove the user DN and primary group
621         S.remove(self.test_user_dn.get_casefold())
622
623         if len(samr_dns.difference(S)):
624             self.fail(msg="additional samr_GetUserGroups over filtered_closure: %s" % samr_dns.difference(S))
625
626     def test_samr_GetGroupsForUser_nomember(self):
627         # Confirm that we get the correct results against SAMR also
628         if not url.startswith("ldap://"):
629             self.fail(msg="This test is only valid on ldap (so we an find the hostname and use SAMR)")
630         host = url.split("://")[1]
631
632         test_user = "tokengroups_user2"
633         self.admin_ldb.newuser(test_user, self.test_user_pass)
634         res = self.admin_ldb.search(base="cn=%s,cn=users,%s" % (test_user, self.base_dn),
635                                     attrs=["objectSid"], scope=ldb.SCOPE_BASE)
636         user_sid = ndr_unpack(samba.dcerpc.security.dom_sid, res[0]["objectSid"][0])
637
638         (domain_sid, user_rid) = user_sid.split()
639         samr_conn = samba.dcerpc.samr.samr("ncacn_ip_tcp:%s[seal]" % host, lp, creds)
640         samr_handle = samr_conn.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
641         samr_domain = samr_conn.OpenDomain(samr_handle, security.SEC_FLAG_MAXIMUM_ALLOWED,
642                                            domain_sid)
643         user_handle = samr_conn.OpenUser(samr_domain, security.SEC_FLAG_MAXIMUM_ALLOWED, user_rid)
644         rids = samr_conn.GetGroupsForUser(user_handle)
645         user_info = samr_conn.QueryUserInfo(user_handle, 1)
646         delete_force(self.admin_ldb, "CN=%s,%s,%s" %
647                      (test_user, "cn=users", self.base_dn))
648         self.assertEqual(len(rids.rids), 1)
649         self.assertEqual(rids.rids[0].rid, user_info.primary_gid)
650
# A bare argument carries no scheme: an existing file is addressed as a tdb
# database, anything else as an LDAP host.
if "://" not in url:
    url = ("tdb://%s" if os.path.isfile(url) else "ldap://%s") % url

# Hand control to the subunit-aware test runner for this module.
TestProgram(module=__name__, opts=subunitopts)