test samr: Extra tests for samr_EnumDomainGroups
[amitay/samba.git] / python / samba / tests / dcerpc / sam.py
1 # -*- coding: utf-8 -*-
2 #
3 # Unix SMB/CIFS implementation.
4 # Copyright © Jelmer Vernooij <jelmer@samba.org> 2008
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19
20 """Tests for samba.dcerpc.sam."""
21
22 from samba.dcerpc import samr, security, lsa
23 from samba.tests import RpcInterfaceTestCase
24 from samba.tests import env_loadparm, delete_force
25
26 from samba.credentials import Credentials
27 from samba.auth import system_session
28 from samba.samdb import SamDB
29 from samba.dsdb import (
30     ATYPE_NORMAL_ACCOUNT,
31     ATYPE_WORKSTATION_TRUST,
32     GTYPE_SECURITY_UNIVERSAL_GROUP,
33     GTYPE_SECURITY_GLOBAL_GROUP)
34 from samba import generate_random_password
35 from samba.ndr import ndr_unpack
36 import os
37
38
39 # FIXME: Pidl should be doing this for us
40 def toArray(handle, array, num_entries):
41     return [(entry.idx, entry.name) for entry in array.entries[:num_entries]]
42
43
44 # Extract the rid from an ldb message, assumes that the message has a
45 # objectSID attribute
46 #
47 def rid(msg):
48     sid = ndr_unpack(security.dom_sid, msg["objectSID"][0])
49     (_, rid) = sid.split()
50     return rid
51
52
53 # Calculate the request size for EnumDomainUsers and EnumDomainGroups calls
54 # to hold the specified number of entries.
55 # We use the w2k3 element size value of 54, code under test
56 # rounds this up i.e. (1+(max_size/SAMR_ENUM_USERS_MULTIPLIER))
57 #
58 def calc_max_size(num_entries):
59     return (num_entries - 1) * 54
60
61
62 class SamrTests(RpcInterfaceTestCase):
63
64     def setUp(self):
65         super(SamrTests, self).setUp()
66         self.conn = samr.samr("ncalrpc:", self.get_loadparm())
67         self.open_samdb()
68         self.open_domain_handle()
69
70     #
71     # Open the samba database
72     #
73     def open_samdb(self):
74         self.lp = env_loadparm()
75         self.domain = os.environ["DOMAIN"]
76         self.creds = Credentials()
77         self.creds.guess(self.lp)
78         self.session = system_session()
79         self.samdb = SamDB(
80             session_info=self.session, credentials=self.creds, lp=self.lp)
81
82     #
83     # Open a SAMR Domain handle
84     def open_domain_handle(self):
85         self.handle = self.conn.Connect2(
86             None, security.SEC_FLAG_MAXIMUM_ALLOWED)
87
88         self.domain_sid = self.conn.LookupDomain(
89             self.handle, lsa.String(self.domain))
90
91         self.domain_handle = self.conn.OpenDomain(
92             self.handle, security.SEC_FLAG_MAXIMUM_ALLOWED, self.domain_sid)
93
94     # Filter a list of records, removing those that are not part of the
95     # current domain.
96     #
97     def filter_domain(self, unfiltered):
98         def sid(msg):
99             sid = ndr_unpack(security.dom_sid, msg["objectSID"][0])
100             (x, _) = sid.split()
101             return x
102
103         dom_sid = security.dom_sid(self.samdb.get_domain_sid())
104         return [x for x in unfiltered if sid(x) == dom_sid]
105
106     def test_connect5(self):
107         (level, info, handle) =\
108             self.conn.Connect5(None, 0, 1, samr.ConnectInfo1())
109
110     def test_connect2(self):
111         handle = self.conn.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
112         self.assertTrue(handle is not None)
113
114     def test_EnumDomains(self):
115         handle = self.conn.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
116         toArray(*self.conn.EnumDomains(handle, 0, 4294967295))
117         self.conn.Close(handle)
118
119     # Create groups based on the id list supplied, the id is used to
120     # form a unique name and description.
121     #
122     # returns a list of the created dn's, which can be passed to delete_dns
123     # to clean up after the test has run.
124     def create_groups(self, ids):
125         dns = []
126         for i in ids:
127             name = "SAMR_GRP%d" % i
128             dn = "cn=%s,cn=Users,%s" % (name, self.samdb.domain_dn())
129             delete_force(self.samdb, dn)
130
131             self.samdb.newgroup(name)
132             dns.append(dn)
133         return dns
134
135     # Create user accounts based on the id list supplied, the id is used to
136     # form a unique name and description.
137     #
138     # returns a list of the created dn's, which can be passed to delete_dns
139     # to clean up after the test has run.
140     def create_users(self, ids):
141         dns = []
142         for i in ids:
143             name = "SAMR_USER%d" % i
144             dn = "cn=%s,CN=USERS,%s" % (name, self.samdb.domain_dn())
145             delete_force(self.samdb, dn)
146
147             # We only need the user to exist, we don't need a password
148             self.samdb.newuser(
149                 name,
150                 password=None,
151                 setpassword=False,
152                 description="Description for " + name,
153                 givenname="given%dname" % i,
154                 surname="surname%d" % i)
155             dns.append(dn)
156         return dns
157
158     # Create computer accounts based on the id list supplied, the id is used to
159     # form a unique name and description.
160     #
161     # returns a list of the created dn's, which can be passed to delete_dns
162     # to clean up after the test has run.
163     def create_computers(self, ids):
164         dns = []
165         for i in ids:
166             name = "SAMR_CMP%d" % i
167             dn = "cn=%s,cn=COMPUTERS,%s" % (name, self.samdb.domain_dn())
168             delete_force(self.samdb, dn)
169
170             self.samdb.newcomputer(name, description="Description of " + name)
171             dns.append(dn)
172         return dns
173
174     # Delete the specified dn's.
175     #
176     # Used to clean up entries created by individual tests.
177     #
178     def delete_dns(self, dns):
179         for dn in dns:
180             delete_force(self.samdb, dn)
181
182     # Common tests for QueryDisplayInfo
183     #
184     def _test_QueryDisplayInfo(
185             self, level, check_results, select, attributes, add_elements):
186         #
187         # Get the expected results by querying the samdb database directly.
188         # We do this rather than use a list of expected results as this runs
189         # with other tests so we do not have a known fixed list of elements
190         expected = self.samdb.search(expression=select, attrs=attributes)
191         self.assertTrue(len(expected) > 0)
192
193         #
194         # Perform QueryDisplayInfo with max results greater than the expected
195         # number of results.
196         (ts, rs, actual) = self.conn.QueryDisplayInfo(
197             self.domain_handle, level, 0, 1024, 4294967295)
198
199         self.assertEquals(len(expected), ts)
200         self.assertEquals(len(expected), rs)
201         check_results(expected, actual.entries)
202
203         #
204         # Perform QueryDisplayInfo with max results set to the number of
205         # results returned from the first query, should return the same results
206         (ts1, rs1, actual1) = self.conn.QueryDisplayInfo(
207             self.domain_handle, level, 0, rs, 4294967295)
208         self.assertEquals(ts, ts1)
209         self.assertEquals(rs, rs1)
210         check_results(expected, actual1.entries)
211
212         #
213         # Perform QueryDisplayInfo and get the last two results.
214         # Note: We are assuming there are at least three entries
215         self.assertTrue(ts > 2)
216         (ts2, rs2, actual2) = self.conn.QueryDisplayInfo(
217             self.domain_handle, level, (ts - 2), 2, 4294967295)
218         self.assertEquals(ts, ts2)
219         self.assertEquals(2, rs2)
220         check_results(list(expected)[-2:], actual2.entries)
221
222         #
223         # Perform QueryDisplayInfo and get the first two results.
224         # Note: We are assuming there are at least three entries
225         self.assertTrue(ts > 2)
226         (ts2, rs2, actual2) = self.conn.QueryDisplayInfo(
227             self.domain_handle, level, 0, 2, 4294967295)
228         self.assertEquals(ts, ts2)
229         self.assertEquals(2, rs2)
230         check_results(list(expected)[:2], actual2.entries)
231
232         #
233         # Perform QueryDisplayInfo and get two results in the middle of the
234         # list i.e. not the first or the last entry.
235         # Note: We are assuming there are at least four entries
236         self.assertTrue(ts > 3)
237         (ts2, rs2, actual2) = self.conn.QueryDisplayInfo(
238             self.domain_handle, level, 1, 2, 4294967295)
239         self.assertEquals(ts, ts2)
240         self.assertEquals(2, rs2)
241         check_results(list(expected)[1:2], actual2.entries)
242
243         #
244         # To check that cached values are being returned rather than the
245         # results being re-read from disk we add elements, and request all
246         # but the first result.
247         #
248         dns = add_elements([1000, 1002, 1003, 1004])
249
250         #
251         # Perform QueryDisplayInfo and get all but the first result.
252         # We should be using the cached results so the entries we just added
253         # should not be present
254         (ts3, rs3, actual3) = self.conn.QueryDisplayInfo(
255             self.domain_handle, level, 1, 1024, 4294967295)
256         self.assertEquals(ts, ts3)
257         self.assertEquals(len(expected) - 1, rs3)
258         check_results(list(expected)[1:], actual3.entries)
259
260         #
261         # Perform QueryDisplayInfo and get all the results.
262         # As the start index is zero we should reread the data from disk and
263         # the added entries should be there
264         new = self.samdb.search(expression=select, attrs=attributes)
265         (ts4, rs4, actual4) = self.conn.QueryDisplayInfo(
266             self.domain_handle, level, 0, 1024, 4294967295)
267         self.assertEquals(len(expected) + len(dns), ts4)
268         self.assertEquals(len(expected) + len(dns), rs4)
269         check_results(new, actual4.entries)
270
271         # Delete the added DN's and query all but the first entry.
272         # This should ensure the cached results are used and that the
273         # missing entry code is triggered.
274         self.delete_dns(dns)
275         (ts5, rs5, actual5) = self.conn.QueryDisplayInfo(
276             self.domain_handle, level, 1, 1024, 4294967295)
277         self.assertEquals(len(expected) + len(dns), ts5)
278         # The deleted results will be filtered from the result set so should
279         # be missing from the returned results.
280         # Note: depending on the GUID order, the first result in the cache may
281         #       be a deleted entry, in which case the results will contain all
282         #       the expected elements, otherwise the first expected result will
283         #       be missing.
284         if rs5 == len(expected):
285             check_results(expected, actual5.entries)
286         elif rs5 == (len(expected) - 1):
287             check_results(list(expected)[1:], actual5.entries)
288         else:
289             self.fail("Incorrect number of entries {0}".format(rs5))
290
291         #
292         # Perform QueryDisplayInfo specifying an index past the end of the
293         # available data.
294         # Should return no data.
295         (ts6, rs6, actual6) = self.conn.QueryDisplayInfo(
296             self.domain_handle, level, ts5, 1, 4294967295)
297         self.assertEquals(ts5, ts6)
298         self.assertEquals(0, rs6)
299
300         self.conn.Close(self.handle)
301
302     # Test for QueryDisplayInfo, Level 1
303     # Returns the sAMAccountName, displayName and description for all
304     # the user accounts.
305     #
306     def test_QueryDisplayInfo_level_1(self):
307         def check_results(expected, actual):
308             # Assume the QueryDisplayInfo and ldb.search return their results
309             # in the same order
310             for (e, a) in zip(expected, actual):
311                 self.assertTrue(isinstance(a, samr.DispEntryGeneral))
312                 self.assertEquals(str(e["sAMAccountName"]),
313                                   str(a.account_name))
314
315                 # The displayName and description are optional.
316                 # In the expected results they will be missing, in
317                 # samr.DispEntryGeneral the corresponding attribute will have a
318                 # length of zero.
319                 #
320                 if a.full_name.length == 0:
321                     self.assertFalse("displayName" in e)
322                 else:
323                     self.assertEquals(str(e["displayName"]), str(a.full_name))
324
325                 if a.description.length == 0:
326                     self.assertFalse("description" in e)
327                 else:
328                     self.assertEquals(str(e["description"]),
329                                       str(a.description))
330         # Create four user accounts
331         # to ensure that we have the minimum needed for the tests.
332         dns = self.create_users([1, 2, 3, 4])
333
334         select = "(&(objectclass=user)(sAMAccountType={0}))".format(
335             ATYPE_NORMAL_ACCOUNT)
336         attributes = ["sAMAccountName", "displayName", "description"]
337         self._test_QueryDisplayInfo(
338             1, check_results, select, attributes, self.create_users)
339
340         self.delete_dns(dns)
341
342     # Test for QueryDisplayInfo, Level 2
343     # Returns the sAMAccountName and description for all
344     # the computer accounts.
345     #
346     def test_QueryDisplayInfo_level_2(self):
347         def check_results(expected, actual):
348             # Assume the QueryDisplayInfo and ldb.search return their results
349             # in the same order
350             for (e, a) in zip(expected, actual):
351                 self.assertTrue(isinstance(a, samr.DispEntryFull))
352                 self.assertEquals(str(e["sAMAccountName"]),
353                                   str(a.account_name))
354
355                 # The description is optional.
356                 # In the expected results they will be missing, in
357                 # samr.DispEntryGeneral the corresponding attribute will have a
358                 # length of zero.
359                 #
360                 if a.description.length == 0:
361                     self.assertFalse("description" in e)
362                 else:
363                     self.assertEquals(str(e["description"]),
364                                       str(a.description))
365
366         # Create four computer accounts
367         # to ensure that we have the minimum needed for the tests.
368         dns = self.create_computers([1, 2, 3, 4])
369
370         select = "(&(objectclass=user)(sAMAccountType={0}))".format(
371             ATYPE_WORKSTATION_TRUST)
372         attributes = ["sAMAccountName", "description"]
373         self._test_QueryDisplayInfo(
374             2, check_results, select, attributes, self.create_computers)
375
376         self.delete_dns(dns)
377
378     # Test for QueryDisplayInfo, Level 3
379     # Returns the sAMAccountName and description for all
380     # the groups.
381     #
382     def test_QueryDisplayInfo_level_3(self):
383         def check_results(expected, actual):
384             # Assume the QueryDisplayInfo and ldb.search return their results
385             # in the same order
386             for (e, a) in zip(expected, actual):
387                 self.assertTrue(isinstance(a, samr.DispEntryFullGroup))
388                 self.assertEquals(str(e["sAMAccountName"]),
389                                   str(a.account_name))
390
391                 # The description is optional.
392                 # In the expected results they will be missing, in
393                 # samr.DispEntryGeneral the corresponding attribute will have a
394                 # length of zero.
395                 #
396                 if a.description.length == 0:
397                     self.assertFalse("description" in e)
398                 else:
399                     self.assertEquals(str(e["description"]),
400                                       str(a.description))
401
402         # Create four groups
403         # to ensure that we have the minimum needed for the tests.
404         dns = self.create_groups([1, 2, 3, 4])
405
406         select = "(&(|(groupType=%d)(groupType=%d))(objectClass=group))" % (
407             GTYPE_SECURITY_UNIVERSAL_GROUP,
408             GTYPE_SECURITY_GLOBAL_GROUP)
409         attributes = ["sAMAccountName", "description"]
410         self._test_QueryDisplayInfo(
411             3, check_results, select, attributes, self.create_groups)
412
413         self.delete_dns(dns)
414
415     # Test for QueryDisplayInfo, Level 4
416     # Returns the sAMAccountName (as an ASCII string)
417     # for all the user accounts.
418     #
419     def test_QueryDisplayInfo_level_4(self):
420         def check_results(expected, actual):
421             # Assume the QueryDisplayInfo and ldb.search return their results
422             # in the same order
423             for (e, a) in zip(expected, actual):
424                 self.assertTrue(isinstance(a, samr.DispEntryAscii))
425                 self.assertTrue(
426                     isinstance(a.account_name, lsa.AsciiStringLarge))
427                 self.assertEquals(
428                     str(e["sAMAccountName"]), str(a.account_name.string))
429
430         # Create four user accounts
431         # to ensure that we have the minimum needed for the tests.
432         dns = self.create_users([1, 2, 3, 4])
433
434         select = "(&(objectclass=user)(sAMAccountType={0}))".format(
435             ATYPE_NORMAL_ACCOUNT)
436         attributes = ["sAMAccountName", "displayName", "description"]
437         self._test_QueryDisplayInfo(
438             4, check_results, select, attributes, self.create_users)
439
440         self.delete_dns(dns)
441
442     # Test for QueryDisplayInfo, Level 5
443     # Returns the sAMAccountName (as an ASCII string)
444     # for all the groups.
445     #
446     def test_QueryDisplayInfo_level_5(self):
447         def check_results(expected, actual):
448             # Assume the QueryDisplayInfo and ldb.search return their results
449             # in the same order
450             for (e, a) in zip(expected, actual):
451                 self.assertTrue(isinstance(a, samr.DispEntryAscii))
452                 self.assertTrue(
453                     isinstance(a.account_name, lsa.AsciiStringLarge))
454                 self.assertEquals(
455                     str(e["sAMAccountName"]), str(a.account_name.string))
456
457         # Create four groups
458         # to ensure that we have the minimum needed for the tests.
459         dns = self.create_groups([1, 2, 3, 4])
460
461         select = "(&(|(groupType=%d)(groupType=%d))(objectClass=group))" % (
462             GTYPE_SECURITY_UNIVERSAL_GROUP,
463             GTYPE_SECURITY_GLOBAL_GROUP)
464         attributes = ["sAMAccountName", "description"]
465         self._test_QueryDisplayInfo(
466             5, check_results, select, attributes, self.create_groups)
467
468         self.delete_dns(dns)
469
470     def test_EnumDomainGroups(self):
471         def check_results(expected, actual):
472             for (e, a) in zip(expected, actual):
473                 self.assertTrue(isinstance(a, samr.SamEntry))
474                 self.assertEquals(
475                     str(e["sAMAccountName"]), str(a.name.string))
476
477         # Create four groups
478         # to ensure that we have the minimum needed for the tests.
479         dns = self.create_groups([1, 2, 3, 4])
480
481         #
482         # Get the expected results by querying the samdb database directly.
483         # We do this rather than use a list of expected results as this runs
484         # with other tests so we do not have a known fixed list of elements
485         select = "(&(|(groupType=%d)(groupType=%d))(objectClass=group))" % (
486             GTYPE_SECURITY_UNIVERSAL_GROUP,
487             GTYPE_SECURITY_GLOBAL_GROUP)
488         attributes = ["sAMAccountName", "objectSID"]
489         unfiltered = self.samdb.search(expression=select, attrs=attributes)
490         filtered = self.filter_domain(unfiltered)
491         self.assertTrue(len(filtered) > 4)
492
493         # Sort the expected results by rid
494         expected = sorted(list(filtered), key=rid)
495
496         #
497         # Perform EnumDomainGroups with max size greater than the expected
498         # number of results. Allow for an extra 10 entries
499         #
500         max_size = calc_max_size(len(expected) + 10)
501         (resume_handle, actual, num_entries) = self.conn.EnumDomainGroups(
502             self.domain_handle, 0, max_size)
503         self.assertEquals(len(expected), num_entries)
504         check_results(expected, actual.entries)
505
506         #
507         # Perform EnumDomainGroups with size set to so that it contains
508         # 4 entries.
509         #
510         max_size = calc_max_size(4)
511         (resume_handle, actual, num_entries) = self.conn.EnumDomainGroups(
512             self.domain_handle, 0, max_size)
513         self.assertEquals(4, num_entries)
514         check_results(expected[:4], actual.entries)
515
516         #
517         # Try calling with resume_handle greater than number of entries
518         # Should return no results and a resume handle of 0
519         max_size = calc_max_size(1)
520         rh = len(expected)
521         self.conn.Close(self.handle)
522         (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
523             self.domain_handle, rh, max_size)
524
525         self.assertEquals(0, num_entries)
526         self.assertEquals(0, resume_handle)
527
528         #
529         # Enumerate through the domain groups one element at a time.
530         #
531         max_size = calc_max_size(1)
532         actual = []
533         (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
534             self.domain_handle, 0, max_size)
535         while resume_handle:
536             self.assertEquals(1, num_entries)
537             actual.append(a.entries[0])
538             (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
539                 self.domain_handle, resume_handle, max_size)
540         if num_entries:
541             actual.append(a.entries[0])
542
543         #
544         # Check that the cached results are being returned.
545         # Obtain a new resume_handle and insert new entries into the
546         # into the DB
547         #
548         actual = []
549         max_size = calc_max_size(1)
550         (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
551             self.domain_handle, 0, max_size)
552         extra_dns = self.create_groups([1000, 1002, 1003, 1004])
553         while resume_handle:
554             self.assertEquals(1, num_entries)
555             actual.append(a.entries[0])
556             (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
557                 self.domain_handle, resume_handle, max_size)
558         if num_entries:
559             actual.append(a.entries[0])
560
561         self.assertEquals(len(expected), len(actual))
562         check_results(expected, actual)
563
564         #
565         # Perform EnumDomainGroups, we should read the newly added domains
566         #
567         max_size = calc_max_size(len(expected) + len(extra_dns) + 10)
568         (resume_handle, actual, num_entries) = self.conn.EnumDomainGroups(
569             self.domain_handle, 0, max_size)
570         self.assertEquals(len(expected) + len(extra_dns), num_entries)
571
572         #
573         # Get a new expected result set by querying the database directly
574         unfiltered01 = self.samdb.search(expression=select, attrs=attributes)
575         filtered01 = self.filter_domain(unfiltered01)
576         self.assertTrue(len(filtered01) > len(expected))
577
578         # Sort the expected results by rid
579         expected01 = sorted(list(filtered01), key=rid)
580
581         #
582         # Now check that we read the new entries.
583         #
584         check_results(expected01, actual.entries)
585
586         #
587         # Check that deleted results are handled correctly.
588         # Obtain a new resume_handle and delete entries from the DB.
589         #
590         actual = []
591         max_size = calc_max_size(1)
592         (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
593             self.domain_handle, 0, max_size)
594         self.delete_dns(extra_dns)
595         while resume_handle and num_entries:
596             self.assertEquals(1, num_entries)
597             actual.append(a.entries[0])
598             (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
599                 self.domain_handle, resume_handle, max_size)
600         if num_entries:
601             actual.append(a.entries[0])
602
603         self.assertEquals(len(expected), len(actual))
604         check_results(expected, actual)
605
606         self.delete_dns(dns)