f709ba489508ba31f1f9c88918e93616933df5b4
[amitay/samba.git] / python / samba / tests / dcerpc / sam.py
1 # -*- coding: utf-8 -*-
2 #
3 # Unix SMB/CIFS implementation.
4 # Copyright © Jelmer Vernooij <jelmer@samba.org> 2008
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19
20 """Tests for samba.dcerpc.sam."""
21
22 from samba.dcerpc import samr, security, lsa
23 from samba.tests import RpcInterfaceTestCase
24 from samba.tests import env_loadparm, delete_force
25
26 from samba.credentials import Credentials
27 from samba.auth import system_session
28 from samba.samdb import SamDB
29 from samba.dsdb import (
30     ATYPE_NORMAL_ACCOUNT,
31     ATYPE_WORKSTATION_TRUST,
32     GTYPE_SECURITY_UNIVERSAL_GROUP,
33     GTYPE_SECURITY_GLOBAL_GROUP)
34 from samba import generate_random_password
35 import os
36
37
38 # FIXME: Pidl should be doing this for us
39 def toArray(handle, array, num_entries):
40     return [(entry.idx, entry.name) for entry in array.entries[:num_entries]]
41
42
43 class SamrTests(RpcInterfaceTestCase):
44
45     def setUp(self):
46         super(SamrTests, self).setUp()
47         self.conn = samr.samr("ncalrpc:", self.get_loadparm())
48         self.open_samdb()
49         self.open_domain_handle()
50
51     #
52     # Open the samba database
53     #
54     def open_samdb(self):
55         self.lp = env_loadparm()
56         self.domain = os.environ["DOMAIN"]
57         self.creds = Credentials()
58         self.creds.guess(self.lp)
59         self.session = system_session()
60         self.samdb = SamDB(
61             session_info=self.session, credentials=self.creds, lp=self.lp)
62
63     #
64     # Open a SAMR Domain handle
65     def open_domain_handle(self):
66         self.handle = self.conn.Connect2(
67             None, security.SEC_FLAG_MAXIMUM_ALLOWED)
68
69         self.domain_sid = self.conn.LookupDomain(
70             self.handle, lsa.String(self.domain))
71
72         self.domain_handle = self.conn.OpenDomain(
73             self.handle, security.SEC_FLAG_MAXIMUM_ALLOWED, self.domain_sid)
74
75     def test_connect5(self):
76         (level, info, handle) =\
77             self.conn.Connect5(None, 0, 1, samr.ConnectInfo1())
78
79     def test_connect2(self):
80         handle = self.conn.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
81         self.assertTrue(handle is not None)
82
83     def test_EnumDomains(self):
84         handle = self.conn.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
85         toArray(*self.conn.EnumDomains(handle, 0, 4294967295))
86         self.conn.Close(handle)
87
88     # Create groups based on the id list supplied, the id is used to
89     # form a unique name and description.
90     #
91     # returns a list of the created dn's, which can be passed to delete_dns
92     # to clean up after the test has run.
93     def create_groups(self, ids):
94         dns = []
95         for i in ids:
96             name = "SAMR_GRP%d" % i
97             dn = "cn=%s,cn=Users,%s" % (name, self.samdb.domain_dn())
98             delete_force(self.samdb, dn)
99
100             self.samdb.newgroup(name)
101             dns.append(dn)
102         return dns
103
104     # Create user accounts based on the id list supplied, the id is used to
105     # form a unique name and description.
106     #
107     # returns a list of the created dn's, which can be passed to delete_dns
108     # to clean up after the test has run.
109     def create_users(self, ids):
110         dns = []
111         for i in ids:
112             name = "SAMR_USER%d" % i
113             dn = "cn=%s,CN=USERS,%s" % (name, self.samdb.domain_dn())
114             delete_force(self.samdb, dn)
115
116             # We only need the user to exist, we don't need a password
117             self.samdb.newuser(
118                 name,
119                 password=None,
120                 setpassword=False,
121                 description="Description for " + name,
122                 givenname="given%dname" % i,
123                 surname="surname%d" % i)
124             dns.append(dn)
125         return dns
126
127     # Create computer accounts based on the id list supplied, the id is used to
128     # form a unique name and description.
129     #
130     # returns a list of the created dn's, which can be passed to delete_dns
131     # to clean up after the test has run.
132     def create_computers(self, ids):
133         dns = []
134         for i in ids:
135             name = "SAMR_CMP%d" % i
136             dn = "cn=%s,cn=COMPUTERS,%s" % (name, self.samdb.domain_dn())
137             delete_force(self.samdb, dn)
138
139             self.samdb.newcomputer(name, description="Description of " + name)
140             dns.append(dn)
141         return dns
142
143     # Delete the specified dn's.
144     #
145     # Used to clean up entries created by individual tests.
146     #
147     def delete_dns(self, dns):
148         for dn in dns:
149             delete_force(self.samdb, dn)
150
151     # Common tests for QueryDisplayInfo
152     #
153     def _test_QueryDisplayInfo(
154             self, level, check_results, select, attributes, add_elements):
155         #
156         # Get the expected results by querying the samdb database directly.
157         # We do this rather than use a list of expected results as this runs
158         # with other tests so we do not have a known fixed list of elements
159         expected = self.samdb.search(expression=select, attrs=attributes)
160         self.assertTrue(len(expected) > 0)
161
162         #
163         # Perform QueryDisplayInfo with max results greater than the expected
164         # number of results.
165         (ts, rs, actual) = self.conn.QueryDisplayInfo(
166             self.domain_handle, level, 0, 1024, 4294967295)
167
168         self.assertEquals(len(expected), ts)
169         self.assertEquals(len(expected), rs)
170         check_results(expected, actual.entries)
171
172         #
173         # Perform QueryDisplayInfo with max results set to the number of
174         # results returned from the first query, should return the same results
175         (ts1, rs1, actual1) = self.conn.QueryDisplayInfo(
176             self.domain_handle, level, 0, rs, 4294967295)
177         self.assertEquals(ts, ts1)
178         self.assertEquals(rs, rs1)
179         check_results(expected, actual1.entries)
180
181         #
182         # Perform QueryDisplayInfo and get the last two results.
183         # Note: We are assuming there are at least three entries
184         self.assertTrue(ts > 2)
185         (ts2, rs2, actual2) = self.conn.QueryDisplayInfo(
186             self.domain_handle, level, (ts - 2), 2, 4294967295)
187         self.assertEquals(ts, ts2)
188         self.assertEquals(2, rs2)
189         check_results(list(expected)[-2:], actual2.entries)
190
191         #
192         # Perform QueryDisplayInfo and get the first two results.
193         # Note: We are assuming there are at least three entries
194         self.assertTrue(ts > 2)
195         (ts2, rs2, actual2) = self.conn.QueryDisplayInfo(
196             self.domain_handle, level, 0, 2, 4294967295)
197         self.assertEquals(ts, ts2)
198         self.assertEquals(2, rs2)
199         check_results(list(expected)[:2], actual2.entries)
200
201         #
202         # Perform QueryDisplayInfo and get two results in the middle of the
203         # list i.e. not the first or the last entry.
204         # Note: We are assuming there are at least four entries
205         self.assertTrue(ts > 3)
206         (ts2, rs2, actual2) = self.conn.QueryDisplayInfo(
207             self.domain_handle, level, 1, 2, 4294967295)
208         self.assertEquals(ts, ts2)
209         self.assertEquals(2, rs2)
210         check_results(list(expected)[1:2], actual2.entries)
211
212         #
213         # To check that cached values are being returned rather than the
214         # results being re-read from disk we add elements, and request all
215         # but the first result.
216         #
217         dns = add_elements([1000, 1002, 1003, 1004])
218
219         #
220         # Perform QueryDisplayInfo and get all but the first result.
221         # We should be using the cached results so the entries we just added
222         # should not be present
223         (ts3, rs3, actual3) = self.conn.QueryDisplayInfo(
224             self.domain_handle, level, 1, 1024, 4294967295)
225         self.assertEquals(ts, ts3)
226         self.assertEquals(len(expected) - 1, rs3)
227         check_results(list(expected)[1:], actual3.entries)
228
229         #
230         # Perform QueryDisplayInfo and get all the results.
231         # As the start index is zero we should reread the data from disk and
232         # the added entries should be there
233         new = self.samdb.search(expression=select, attrs=attributes)
234         (ts4, rs4, actual4) = self.conn.QueryDisplayInfo(
235             self.domain_handle, level, 0, 1024, 4294967295)
236         self.assertEquals(len(expected) + len(dns), ts4)
237         self.assertEquals(len(expected) + len(dns), rs4)
238         check_results(new, actual4.entries)
239
240         # Delete the added DN's and query all but the first entry.
241         # This should ensure the cached results are used and that the
242         # missing entry code is triggered.
243         self.delete_dns(dns)
244         (ts5, rs5, actual5) = self.conn.QueryDisplayInfo(
245             self.domain_handle, level, 1, 1024, 4294967295)
246         self.assertEquals(len(expected) + len(dns), ts5)
247         # The deleted results will be filtered from the result set so should
248         # be missing from the returned results.
249         # Note: depending on the GUID order, the first result in the cache may
250         #       be a deleted entry, in which case the results will contain all
251         #       the expected elements, otherwise the first expected result will
252         #       be missing.
253         if rs5 == len(expected):
254             check_results(expected, actual5.entries)
255         elif rs5 == (len(expected) - 1):
256             check_results(list(expected)[1:], actual5.entries)
257         else:
258             self.fail("Incorrect number of entries {0}".format(rs5))
259
260         #
261         # Perform QueryDisplayInfo specifying an index past the end of the
262         # available data.
263         # Should return no data.
264         (ts6, rs6, actual6) = self.conn.QueryDisplayInfo(
265             self.domain_handle, level, ts5, 1, 4294967295)
266         self.assertEquals(ts5, ts6)
267         self.assertEquals(0, rs6)
268
269         self.conn.Close(self.handle)
270
271     # Test for QueryDisplayInfo, Level 1
272     # Returns the sAMAccountName, displayName and description for all
273     # the user accounts.
274     #
275     def test_QueryDisplayInfo_level_1(self):
276         def check_results(expected, actual):
277             # Assume the QueryDisplayInfo and ldb.search return their results
278             # in the same order
279             for (e, a) in zip(expected, actual):
280                 self.assertTrue(isinstance(a, samr.DispEntryGeneral))
281                 self.assertEquals(str(e["sAMAccountName"]),
282                                   str(a.account_name))
283
284                 # The displayName and description are optional.
285                 # In the expected results they will be missing, in
286                 # samr.DispEntryGeneral the corresponding attribute will have a
287                 # length of zero.
288                 #
289                 if a.full_name.length == 0:
290                     self.assertFalse("displayName" in e)
291                 else:
292                     self.assertEquals(str(e["displayName"]), str(a.full_name))
293
294                 if a.description.length == 0:
295                     self.assertFalse("description" in e)
296                 else:
297                     self.assertEquals(str(e["description"]),
298                                       str(a.description))
299         # Create four user accounts
300         # to ensure that we have the minimum needed for the tests.
301         dns = self.create_users([1, 2, 3, 4])
302
303         select = "(&(objectclass=user)(sAMAccountType={0}))".format(
304             ATYPE_NORMAL_ACCOUNT)
305         attributes = ["sAMAccountName", "displayName", "description"]
306         self._test_QueryDisplayInfo(
307             1, check_results, select, attributes, self.create_users)
308
309         self.delete_dns(dns)
310
311     # Test for QueryDisplayInfo, Level 2
312     # Returns the sAMAccountName and description for all
313     # the computer accounts.
314     #
315     def test_QueryDisplayInfo_level_2(self):
316         def check_results(expected, actual):
317             # Assume the QueryDisplayInfo and ldb.search return their results
318             # in the same order
319             for (e, a) in zip(expected, actual):
320                 self.assertTrue(isinstance(a, samr.DispEntryFull))
321                 self.assertEquals(str(e["sAMAccountName"]),
322                                   str(a.account_name))
323
324                 # The description is optional.
325                 # In the expected results they will be missing, in
326                 # samr.DispEntryGeneral the corresponding attribute will have a
327                 # length of zero.
328                 #
329                 if a.description.length == 0:
330                     self.assertFalse("description" in e)
331                 else:
332                     self.assertEquals(str(e["description"]),
333                                       str(a.description))
334
335         # Create four computer accounts
336         # to ensure that we have the minimum needed for the tests.
337         dns = self.create_computers([1, 2, 3, 4])
338
339         select = "(&(objectclass=user)(sAMAccountType={0}))".format(
340             ATYPE_WORKSTATION_TRUST)
341         attributes = ["sAMAccountName", "description"]
342         self._test_QueryDisplayInfo(
343             2, check_results, select, attributes, self.create_computers)
344
345         self.delete_dns(dns)
346
347     # Test for QueryDisplayInfo, Level 3
348     # Returns the sAMAccountName and description for all
349     # the groups.
350     #
351     def test_QueryDisplayInfo_level_3(self):
352         def check_results(expected, actual):
353             # Assume the QueryDisplayInfo and ldb.search return their results
354             # in the same order
355             for (e, a) in zip(expected, actual):
356                 self.assertTrue(isinstance(a, samr.DispEntryFullGroup))
357                 self.assertEquals(str(e["sAMAccountName"]),
358                                   str(a.account_name))
359
360                 # The description is optional.
361                 # In the expected results they will be missing, in
362                 # samr.DispEntryGeneral the corresponding attribute will have a
363                 # length of zero.
364                 #
365                 if a.description.length == 0:
366                     self.assertFalse("description" in e)
367                 else:
368                     self.assertEquals(str(e["description"]),
369                                       str(a.description))
370
371         # Create four groups
372         # to ensure that we have the minimum needed for the tests.
373         dns = self.create_groups([1, 2, 3, 4])
374
375         select = "(&(|(groupType=%d)(groupType=%d))(objectClass=group))" % (
376             GTYPE_SECURITY_UNIVERSAL_GROUP,
377             GTYPE_SECURITY_GLOBAL_GROUP)
378         attributes = ["sAMAccountName", "description"]
379         self._test_QueryDisplayInfo(
380             3, check_results, select, attributes, self.create_groups)
381
382         self.delete_dns(dns)
383
384     # Test for QueryDisplayInfo, Level 4
385     # Returns the sAMAccountName (as an ASCII string)
386     # for all the user accounts.
387     #
388     def test_QueryDisplayInfo_level_4(self):
389         def check_results(expected, actual):
390             # Assume the QueryDisplayInfo and ldb.search return their results
391             # in the same order
392             for (e, a) in zip(expected, actual):
393                 self.assertTrue(isinstance(a, samr.DispEntryAscii))
394                 self.assertTrue(
395                     isinstance(a.account_name, lsa.AsciiStringLarge))
396                 self.assertEquals(
397                     str(e["sAMAccountName"]), str(a.account_name.string))
398
399         # Create four user accounts
400         # to ensure that we have the minimum needed for the tests.
401         dns = self.create_users([1, 2, 3, 4])
402
403         select = "(&(objectclass=user)(sAMAccountType={0}))".format(
404             ATYPE_NORMAL_ACCOUNT)
405         attributes = ["sAMAccountName", "displayName", "description"]
406         self._test_QueryDisplayInfo(
407             4, check_results, select, attributes, self.create_users)
408
409         self.delete_dns(dns)
410
411     # Test for QueryDisplayInfo, Level 5
412     # Returns the sAMAccountName (as an ASCII string)
413     # for all the groups.
414     #
415     def test_QueryDisplayInfo_level_5(self):
416         def check_results(expected, actual):
417             # Assume the QueryDisplayInfo and ldb.search return their results
418             # in the same order
419             for (e, a) in zip(expected, actual):
420                 self.assertTrue(isinstance(a, samr.DispEntryAscii))
421                 self.assertTrue(
422                     isinstance(a.account_name, lsa.AsciiStringLarge))
423                 self.assertEquals(
424                     str(e["sAMAccountName"]), str(a.account_name.string))
425
426         # Create four groups
427         # to ensure that we have the minimum needed for the tests.
428         dns = self.create_groups([1, 2, 3, 4])
429
430         select = "(&(|(groupType=%d)(groupType=%d))(objectClass=group))" % (
431             GTYPE_SECURITY_UNIVERSAL_GROUP,
432             GTYPE_SECURITY_GLOBAL_GROUP)
433         attributes = ["sAMAccountName", "description"]
434         self._test_QueryDisplayInfo(
435             5, check_results, select, attributes, self.create_groups)
436
437         self.delete_dns(dns)