s4-python Ensure we add the Samba python path first.
[samba.git] / source4 / dsdb / tests / python / token_group.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 # test tokengroups attribute against internal token calculation
4
5 import optparse
6 import sys
7 import os
8
9 sys.path.insert(0, "bin/python")
10 import samba
11 samba.ensure_external_module("testtools", "testtools")
12 samba.ensure_external_module("subunit", "subunit/python")
13
14 import samba.getopt as options
15
16 from samba.auth import system_session
17 from samba import ldb
18 from samba.samdb import SamDB
19 from samba.auth import AuthContext
20 from samba.ndr import ndr_pack, ndr_unpack
21 from samba import gensec
22 from samba.credentials import Credentials
23
24 from subunit.run import SubunitTestRunner
25 import unittest
26 import samba.tests
27
28 from samba.dcerpc import security
29 from samba.auth import AUTH_SESSION_INFO_DEFAULT_GROUPS, AUTH_SESSION_INFO_AUTHENTICATED, AUTH_SESSION_INFO_SIMPLE_PRIVILEGES
30
31
# Command line handling: Samba's standard option groups plus exactly one
# positional argument naming the server (or database) to test against.
# "%prog" is expanded by optparse to the name of the running script, so the
# usage text no longer advertises the unrelated name "ldap.py".
parser = optparse.OptionParser("%prog [options] <host>")
sambaopts = options.SambaOptions(parser)
parser.add_option_group(sambaopts)
parser.add_option_group(options.VersionOptions(parser))
# use command line creds if available
credopts = options.CredentialsOptions(parser)
parser.add_option_group(credopts)
opts, args = parser.parse_args()

# The target is mandatory; without it print the usage line and exit with 1.
if len(args) < 1:
    parser.print_usage()
    sys.exit(1)

url = args[0]

# Loadparm context and credentials are shared by every test in this file.
lp = sambaopts.get_loadparm()
creds = credopts.get_credentials(lp)

49
class TokenTest(samba.tests.TestCase):
    """Check tokenGroups values against a locally calculated security token.

    setUp() builds the reference SID list (self.user_sids) from a session
    created with samba.auth.user_session(); each test then fetches a SID
    list from another source and verifies that it contains no SID that is
    missing from the reference list.
    """

    def setUp(self):
        """Connect to the shared SamDB and compute the reference SID list."""
        super(TokenTest, self).setUp()
        self.ldb = samdb
        self.base_dn = samdb.domain_dn()

        res = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        # The first rootDSE tokenGroups value is used to address the user
        # by SID in the later searches.
        self.user_sid_dn = "<SID=%s>" % str(ndr_unpack(samba.dcerpc.security.dom_sid, res[0]["tokenGroups"][0]))

        session_info_flags = (AUTH_SESSION_INFO_DEFAULT_GROUPS |
                              AUTH_SESSION_INFO_AUTHENTICATED |
                              AUTH_SESSION_INFO_SIMPLE_PRIVILEGES)
        session = samba.auth.user_session(self.ldb, lp_ctx=lp, dn=self.user_sid_dn,
                                          session_info_flags=session_info_flags)

        token = session.security_token
        self.user_sids = [str(s) for s in token.sids]

    def _unpack_sids(self, values):
        """Return the NDR-packed dom_sid blobs in values as SID strings."""
        return [str(ndr_unpack(samba.dcerpc.security.dom_sid, v)) for v in values]

    def _check_sids(self, sids, fail_msg):
        """Fail with fail_msg if sids holds a SID missing from self.user_sids.

        The check is one-directional: SIDs that appear only in
        self.user_sids are not reported. On a mismatch both lists and the
        offending SIDs are printed before the test is failed.
        """
        unexpected = set(sids).difference(set(self.user_sids))
        if unexpected:
            print("token sids don't match")
            print("tokengroups: %s" % sids)
            print("calculated : %s" % self.user_sids)
            print("difference : %s" % unexpected)
            self.fail(msg=fail_msg)

    def test_rootDSE_tokenGroups(self):
        """Testing rootDSE tokengroups against internal calculation"""
        if not url.startswith("ldap"):
            self.fail(msg="This test is only valid on ldap")

        res = self.ldb.search("", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        print("Geting tokenGroups from rootDSE")
        tokengroups = self._unpack_sids(res[0]['tokenGroups'])

        self._check_sids(tokengroups,
                         "calculated groups don't match against rootDSE tokenGroups")

    def test_dn_tokenGroups(self):
        # Compare tokenGroups read from the user's own entry (addressed by
        # SID) with the calculated token.
        print("Geting tokenGroups from user DN")
        res = self.ldb.search(self.user_sid_dn, scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
        self.assertEqual(len(res), 1)

        dn_tokengroups = self._unpack_sids(res[0]['tokenGroups'])

        self._check_sids(dn_tokengroups,
                         "calculated groups don't match against user DN tokenGroups")

    def test_pac_groups(self):
        # Run an in-process GSSAPI exchange between a gensec client (using
        # the command line credentials) and a gensec server (using the
        # machine account), then compare the SIDs of the server-side
        # session token with the calculated token.
        settings = {}
        settings["lp_ctx"] = lp
        settings["target_hostname"] = lp.get("netbios name")

        gensec_client = gensec.Security.start_client(settings)
        gensec_client.set_credentials(creds)
        gensec_client.want_feature(gensec.FEATURE_SEAL)
        gensec_client.start_mech_by_sasl_name("GSSAPI")

        auth_context = AuthContext(lp_ctx=lp, ldb=self.ldb, methods=[])

        gensec_server = gensec.Security.start_server(settings, auth_context)
        machine_creds = Credentials()
        machine_creds.guess(lp)
        machine_creds.set_machine_account(lp)
        gensec_server.set_credentials(machine_creds)

        gensec_server.want_feature(gensec.FEATURE_SEAL)
        gensec_server.start_mech_by_sasl_name("GSSAPI")

        client_finished = False
        server_finished = False
        server_to_client = ""

        # Run the actual call loop: pass tokens back and forth between the
        # two ends. The loop stops as soon as either side reports that it
        # has finished.
        while not client_finished and not server_finished:
            if not client_finished:
                print("running client gensec_update")
                (client_finished, client_to_server) = gensec_client.update(server_to_client)
            if not server_finished:
                print("running server gensec_update")
                (server_finished, server_to_client) = gensec_server.update(client_to_server)

        session = gensec_server.session_info()

        token = session.security_token
        pac_sids = [str(s) for s in token.sids]

        self._check_sids(pac_sids,
                         "calculated groups don't match against user PAC tokenGroups")

162
163
# A target given without a URL scheme is either a local database file
# (opened through tdb://) or a host name (contacted through ldap://).
if "://" not in url:
    url = ("tdb://%s" if os.path.isfile(url) else "ldap://%s") % url

# Database connection shared by every TokenTest case (see TokenTest.setUp).
samdb = SamDB(url, credentials=creds, session_info=system_session(lp), lp=lp)

# Run the suite through subunit; the exit status is 0 only if all tests pass.
runner = SubunitTestRunner()
rc = 0 if runner.run(unittest.makeSuite(TokenTest)).wasSuccessful() else 1
sys.exit(rc)