python: tests: update all super calls to python 3 style in tests
[samba.git] / python / samba / tests / dcerpc / sam.py
1 # -*- coding: utf-8 -*-
2 #
3 # Unix SMB/CIFS implementation.
4 # Copyright © Jelmer Vernooij <jelmer@samba.org> 2008
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19
20 """Tests for samba.dcerpc.sam."""
21
22 from samba.dcerpc import samr, security, lsa
23 from samba.dcerpc.samr import DomainGeneralInformation
24 from samba.tests import RpcInterfaceTestCase
25 from samba.tests import env_loadparm, delete_force
26
27 from samba.credentials import Credentials
28 from samba.auth import system_session
29 from samba.samdb import SamDB
30 from samba.dsdb import (
31     ATYPE_NORMAL_ACCOUNT,
32     ATYPE_WORKSTATION_TRUST,
33     GTYPE_SECURITY_UNIVERSAL_GROUP,
34     GTYPE_SECURITY_GLOBAL_GROUP)
35 from samba import generate_random_password
36 from samba.ndr import ndr_unpack
37 import os
38
39
40 # FIXME: Pidl should be doing this for us
41 def toArray(handle, array, num_entries):
42     return [(entry.idx, entry.name) for entry in array.entries[:num_entries]]
43
44
45 # Extract the rid from an ldb message, assumes that the message has a
46 # objectSID attribute
47 #
48 def rid(msg):
49     sid = ndr_unpack(security.dom_sid, msg["objectSID"][0])
50     (_, rid) = sid.split()
51     return rid
52
53
54 # Calculate the request size for EnumDomainUsers and EnumDomainGroups calls
55 # to hold the specified number of entries.
56 # We use the w2k3 element size value of 54, code under test
57 # rounds this up i.e. (1+(max_size/SAMR_ENUM_USERS_MULTIPLIER))
58 #
59 def calc_max_size(num_entries):
60     return (num_entries - 1) * 54
61
62
63 class SamrTests(RpcInterfaceTestCase):
64
65     def setUp(self):
66         super().setUp()
67         self.conn = samr.samr("ncalrpc:", self.get_loadparm())
68         self.open_samdb()
69         self.open_domain_handle()
70
71     #
72     # Open the samba database
73     #
74     def open_samdb(self):
75         self.lp = env_loadparm()
76         self.domain = os.environ["DOMAIN"]
77         self.creds = Credentials()
78         self.creds.guess(self.lp)
79         self.session = system_session()
80         self.samdb = SamDB(
81             session_info=self.session, credentials=self.creds, lp=self.lp)
82
83     #
84     # Open a SAMR Domain handle
85     def open_domain_handle(self):
86         self.handle = self.conn.Connect2(
87             None, security.SEC_FLAG_MAXIMUM_ALLOWED)
88
89         self.domain_sid = self.conn.LookupDomain(
90             self.handle, lsa.String(self.domain))
91
92         self.domain_handle = self.conn.OpenDomain(
93             self.handle, security.SEC_FLAG_MAXIMUM_ALLOWED, self.domain_sid)
94
95     # Filter a list of records, removing those that are not part of the
96     # current domain.
97     #
98     def filter_domain(self, unfiltered):
99         def sid(msg):
100             sid = ndr_unpack(security.dom_sid, msg["objectSID"][0])
101             (x, _) = sid.split()
102             return x
103
104         dom_sid = security.dom_sid(self.samdb.get_domain_sid())
105         return [x for x in unfiltered if sid(x) == dom_sid]
106
107     def test_connect5(self):
108         (level, info, handle) =\
109             self.conn.Connect5(None, 0, 1, samr.ConnectInfo1())
110
111     def test_connect2(self):
112         handle = self.conn.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
113         self.assertTrue(handle is not None)
114
115     def test_EnumDomains(self):
116         handle = self.conn.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
117         toArray(*self.conn.EnumDomains(handle, 0, 4294967295))
118         self.conn.Close(handle)
119
120     # Create groups based on the id list supplied, the id is used to
121     # form a unique name and description.
122     #
123     # returns a list of the created dn's, which can be passed to delete_dns
124     # to clean up after the test has run.
125     def create_groups(self, ids):
126         dns = []
127         for i in ids:
128             name = "SAMR_GRP%d" % i
129             dn = "cn=%s,cn=Users,%s" % (name, self.samdb.domain_dn())
130             delete_force(self.samdb, dn)
131
132             self.samdb.newgroup(name)
133             dns.append(dn)
134         return dns
135
136     # Create user accounts based on the id list supplied, the id is used to
137     # form a unique name and description.
138     #
139     # returns a list of the created dn's, which can be passed to delete_dns
140     # to clean up after the test has run.
141     def create_users(self, ids):
142         dns = []
143         for i in ids:
144             name = "SAMR_USER%d" % i
145             dn = "cn=%s,CN=USERS,%s" % (name, self.samdb.domain_dn())
146             delete_force(self.samdb, dn)
147
148             # We only need the user to exist, we don't need a password
149             self.samdb.newuser(
150                 name,
151                 password=None,
152                 setpassword=False,
153                 description="Description for " + name,
154                 givenname="given%dname" % i,
155                 surname="surname%d" % i)
156             dns.append(dn)
157         return dns
158
159     # Create computer accounts based on the id list supplied, the id is used to
160     # form a unique name and description.
161     #
162     # returns a list of the created dn's, which can be passed to delete_dns
163     # to clean up after the test has run.
164     def create_computers(self, ids):
165         dns = []
166         for i in ids:
167             name = "SAMR_CMP%d" % i
168             dn = "cn=%s,cn=COMPUTERS,%s" % (name, self.samdb.domain_dn())
169             delete_force(self.samdb, dn)
170
171             self.samdb.newcomputer(name, description="Description of " + name)
172             dns.append(dn)
173         return dns
174
175     # Delete the specified dn's.
176     #
177     # Used to clean up entries created by individual tests.
178     #
179     def delete_dns(self, dns):
180         for dn in dns:
181             delete_force(self.samdb, dn)
182
183     # Common tests for QueryDisplayInfo
184     #
185     def _test_QueryDisplayInfo(
186             self, level, check_results, select, attributes, add_elements):
187         #
188         # Get the expected results by querying the samdb database directly.
189         # We do this rather than use a list of expected results as this runs
190         # with other tests so we do not have a known fixed list of elements
191         expected = self.samdb.search(expression=select, attrs=attributes)
192         self.assertTrue(len(expected) > 0)
193
194         #
195         # Perform QueryDisplayInfo with max results greater than the expected
196         # number of results.
197         (ts, rs, actual) = self.conn.QueryDisplayInfo(
198             self.domain_handle, level, 0, 1024, 4294967295)
199
200         self.assertEqual(len(expected), ts)
201         self.assertEqual(len(expected), rs)
202         check_results(expected, actual.entries)
203
204         #
205         # Perform QueryDisplayInfo with max results set to the number of
206         # results returned from the first query, should return the same results
207         (ts1, rs1, actual1) = self.conn.QueryDisplayInfo(
208             self.domain_handle, level, 0, rs, 4294967295)
209         self.assertEqual(ts, ts1)
210         self.assertEqual(rs, rs1)
211         check_results(expected, actual1.entries)
212
213         #
214         # Perform QueryDisplayInfo and get the last two results.
215         # Note: We are assuming there are at least three entries
216         self.assertTrue(ts > 2)
217         (ts2, rs2, actual2) = self.conn.QueryDisplayInfo(
218             self.domain_handle, level, (ts - 2), 2, 4294967295)
219         self.assertEqual(ts, ts2)
220         self.assertEqual(2, rs2)
221         check_results(list(expected)[-2:], actual2.entries)
222
223         #
224         # Perform QueryDisplayInfo and get the first two results.
225         # Note: We are assuming there are at least three entries
226         self.assertTrue(ts > 2)
227         (ts2, rs2, actual2) = self.conn.QueryDisplayInfo(
228             self.domain_handle, level, 0, 2, 4294967295)
229         self.assertEqual(ts, ts2)
230         self.assertEqual(2, rs2)
231         check_results(list(expected)[:2], actual2.entries)
232
233         #
234         # Perform QueryDisplayInfo and get two results in the middle of the
235         # list i.e. not the first or the last entry.
236         # Note: We are assuming there are at least four entries
237         self.assertTrue(ts > 3)
238         (ts2, rs2, actual2) = self.conn.QueryDisplayInfo(
239             self.domain_handle, level, 1, 2, 4294967295)
240         self.assertEqual(ts, ts2)
241         self.assertEqual(2, rs2)
242         check_results(list(expected)[1:2], actual2.entries)
243
244         #
245         # To check that cached values are being returned rather than the
246         # results being re-read from disk we add elements, and request all
247         # but the first result.
248         #
249         dns = add_elements([1000, 1002, 1003, 1004])
250
251         #
252         # Perform QueryDisplayInfo and get all but the first result.
253         # We should be using the cached results so the entries we just added
254         # should not be present
255         (ts3, rs3, actual3) = self.conn.QueryDisplayInfo(
256             self.domain_handle, level, 1, 1024, 4294967295)
257         self.assertEqual(ts, ts3)
258         self.assertEqual(len(expected) - 1, rs3)
259         check_results(list(expected)[1:], actual3.entries)
260
261         #
262         # Perform QueryDisplayInfo and get all the results.
263         # As the start index is zero we should reread the data from disk and
264         # the added entries should be there
265         new = self.samdb.search(expression=select, attrs=attributes)
266         (ts4, rs4, actual4) = self.conn.QueryDisplayInfo(
267             self.domain_handle, level, 0, 1024, 4294967295)
268         self.assertEqual(len(expected) + len(dns), ts4)
269         self.assertEqual(len(expected) + len(dns), rs4)
270         check_results(new, actual4.entries)
271
272         # Delete the added DN's and query all but the first entry.
273         # This should ensure the cached results are used and that the
274         # missing entry code is triggered.
275         self.delete_dns(dns)
276         (ts5, rs5, actual5) = self.conn.QueryDisplayInfo(
277             self.domain_handle, level, 1, 1024, 4294967295)
278         self.assertEqual(len(expected) + len(dns), ts5)
279         # The deleted results will be filtered from the result set so should
280         # be missing from the returned results.
281         # Note: depending on the GUID order, the first result in the cache may
282         #       be a deleted entry, in which case the results will contain all
283         #       the expected elements, otherwise the first expected result will
284         #       be missing.
285         if rs5 == len(expected):
286             check_results(expected, actual5.entries)
287         elif rs5 == (len(expected) - 1):
288             check_results(list(expected)[1:], actual5.entries)
289         else:
290             self.fail("Incorrect number of entries {0}".format(rs5))
291
292         #
293         # Perform QueryDisplayInfo specifying an index past the end of the
294         # available data.
295         # Should return no data.
296         (ts6, rs6, actual6) = self.conn.QueryDisplayInfo(
297             self.domain_handle, level, ts5, 1, 4294967295)
298         self.assertEqual(ts5, ts6)
299         self.assertEqual(0, rs6)
300
301         self.conn.Close(self.handle)
302
303     # Test for QueryDisplayInfo, Level 1
304     # Returns the sAMAccountName, displayName and description for all
305     # the user accounts.
306     #
307     def test_QueryDisplayInfo_level_1(self):
308         def check_results(expected, actual):
309             # Assume the QueryDisplayInfo and ldb.search return their results
310             # in the same order
311             for (e, a) in zip(expected, actual):
312                 self.assertTrue(isinstance(a, samr.DispEntryGeneral))
313                 self.assertEqual(str(e["sAMAccountName"]),
314                                   str(a.account_name))
315
316                 # The displayName and description are optional.
317                 # In the expected results they will be missing, in
318                 # samr.DispEntryGeneral the corresponding attribute will have a
319                 # length of zero.
320                 #
321                 if a.full_name.length == 0:
322                     self.assertFalse("displayName" in e)
323                 else:
324                     self.assertEqual(str(e["displayName"]), str(a.full_name))
325
326                 if a.description.length == 0:
327                     self.assertFalse("description" in e)
328                 else:
329                     self.assertEqual(str(e["description"]),
330                                       str(a.description))
331         # Create four user accounts
332         # to ensure that we have the minimum needed for the tests.
333         dns = self.create_users([1, 2, 3, 4])
334
335         select = "(&(objectclass=user)(sAMAccountType={0}))".format(
336             ATYPE_NORMAL_ACCOUNT)
337         attributes = ["sAMAccountName", "displayName", "description"]
338         self._test_QueryDisplayInfo(
339             1, check_results, select, attributes, self.create_users)
340
341         self.delete_dns(dns)
342
343     # Test for QueryDisplayInfo, Level 2
344     # Returns the sAMAccountName and description for all
345     # the computer accounts.
346     #
347     def test_QueryDisplayInfo_level_2(self):
348         def check_results(expected, actual):
349             # Assume the QueryDisplayInfo and ldb.search return their results
350             # in the same order
351             for (e, a) in zip(expected, actual):
352                 self.assertTrue(isinstance(a, samr.DispEntryFull))
353                 self.assertEqual(str(e["sAMAccountName"]),
354                                   str(a.account_name))
355
356                 # The description is optional.
357                 # In the expected results they will be missing, in
358                 # samr.DispEntryGeneral the corresponding attribute will have a
359                 # length of zero.
360                 #
361                 if a.description.length == 0:
362                     self.assertFalse("description" in e)
363                 else:
364                     self.assertEqual(str(e["description"]),
365                                       str(a.description))
366
367         # Create four computer accounts
368         # to ensure that we have the minimum needed for the tests.
369         dns = self.create_computers([1, 2, 3, 4])
370
371         select = "(&(objectclass=user)(sAMAccountType={0}))".format(
372             ATYPE_WORKSTATION_TRUST)
373         attributes = ["sAMAccountName", "description"]
374         self._test_QueryDisplayInfo(
375             2, check_results, select, attributes, self.create_computers)
376
377         self.delete_dns(dns)
378
379     # Test for QueryDisplayInfo, Level 3
380     # Returns the sAMAccountName and description for all
381     # the groups.
382     #
383     def test_QueryDisplayInfo_level_3(self):
384         def check_results(expected, actual):
385             # Assume the QueryDisplayInfo and ldb.search return their results
386             # in the same order
387             for (e, a) in zip(expected, actual):
388                 self.assertTrue(isinstance(a, samr.DispEntryFullGroup))
389                 self.assertEqual(str(e["sAMAccountName"]),
390                                   str(a.account_name))
391
392                 # The description is optional.
393                 # In the expected results they will be missing, in
394                 # samr.DispEntryGeneral the corresponding attribute will have a
395                 # length of zero.
396                 #
397                 if a.description.length == 0:
398                     self.assertFalse("description" in e)
399                 else:
400                     self.assertEqual(str(e["description"]),
401                                       str(a.description))
402
403         # Create four groups
404         # to ensure that we have the minimum needed for the tests.
405         dns = self.create_groups([1, 2, 3, 4])
406
407         select = "(&(|(groupType=%d)(groupType=%d))(objectClass=group))" % (
408             GTYPE_SECURITY_UNIVERSAL_GROUP,
409             GTYPE_SECURITY_GLOBAL_GROUP)
410         attributes = ["sAMAccountName", "description"]
411         self._test_QueryDisplayInfo(
412             3, check_results, select, attributes, self.create_groups)
413
414         self.delete_dns(dns)
415
416     # Test for QueryDisplayInfo, Level 4
417     # Returns the sAMAccountName (as an ASCII string)
418     # for all the user accounts.
419     #
420     def test_QueryDisplayInfo_level_4(self):
421         def check_results(expected, actual):
422             # Assume the QueryDisplayInfo and ldb.search return their results
423             # in the same order
424             for (e, a) in zip(expected, actual):
425                 self.assertTrue(isinstance(a, samr.DispEntryAscii))
426                 self.assertTrue(
427                     isinstance(a.account_name, lsa.AsciiStringLarge))
428                 self.assertEqual(
429                     str(e["sAMAccountName"]), str(a.account_name.string))
430
431         # Create four user accounts
432         # to ensure that we have the minimum needed for the tests.
433         dns = self.create_users([1, 2, 3, 4])
434
435         select = "(&(objectclass=user)(sAMAccountType={0}))".format(
436             ATYPE_NORMAL_ACCOUNT)
437         attributes = ["sAMAccountName", "displayName", "description"]
438         self._test_QueryDisplayInfo(
439             4, check_results, select, attributes, self.create_users)
440
441         self.delete_dns(dns)
442
443     # Test for QueryDisplayInfo, Level 5
444     # Returns the sAMAccountName (as an ASCII string)
445     # for all the groups.
446     #
447     def test_QueryDisplayInfo_level_5(self):
448         def check_results(expected, actual):
449             # Assume the QueryDisplayInfo and ldb.search return their results
450             # in the same order
451             for (e, a) in zip(expected, actual):
452                 self.assertTrue(isinstance(a, samr.DispEntryAscii))
453                 self.assertTrue(
454                     isinstance(a.account_name, lsa.AsciiStringLarge))
455                 self.assertEqual(
456                     str(e["sAMAccountName"]), str(a.account_name.string))
457
458         # Create four groups
459         # to ensure that we have the minimum needed for the tests.
460         dns = self.create_groups([1, 2, 3, 4])
461
462         select = "(&(|(groupType=%d)(groupType=%d))(objectClass=group))" % (
463             GTYPE_SECURITY_UNIVERSAL_GROUP,
464             GTYPE_SECURITY_GLOBAL_GROUP)
465         attributes = ["sAMAccountName", "description"]
466         self._test_QueryDisplayInfo(
467             5, check_results, select, attributes, self.create_groups)
468
469         self.delete_dns(dns)
470
471     def test_EnumDomainGroups(self):
472         def check_results(expected, actual):
473             for (e, a) in zip(expected, actual):
474                 self.assertTrue(isinstance(a, samr.SamEntry))
475                 self.assertEqual(
476                     str(e["sAMAccountName"]), str(a.name.string))
477
478         # Create four groups
479         # to ensure that we have the minimum needed for the tests.
480         dns = self.create_groups([1, 2, 3, 4])
481
482         #
483         # Get the expected results by querying the samdb database directly.
484         # We do this rather than use a list of expected results as this runs
485         # with other tests so we do not have a known fixed list of elements
486         select = "(&(|(groupType=%d)(groupType=%d))(objectClass=group))" % (
487             GTYPE_SECURITY_UNIVERSAL_GROUP,
488             GTYPE_SECURITY_GLOBAL_GROUP)
489         attributes = ["sAMAccountName", "objectSID"]
490         unfiltered = self.samdb.search(expression=select, attrs=attributes)
491         filtered = self.filter_domain(unfiltered)
492         self.assertTrue(len(filtered) > 4)
493
494         # Sort the expected results by rid
495         expected = sorted(list(filtered), key=rid)
496
497         #
498         # Perform EnumDomainGroups with max size greater than the expected
499         # number of results. Allow for an extra 10 entries
500         #
501         max_size = calc_max_size(len(expected) + 10)
502         (resume_handle, actual, num_entries) = self.conn.EnumDomainGroups(
503             self.domain_handle, 0, max_size)
504         self.assertEqual(len(expected), num_entries)
505         check_results(expected, actual.entries)
506
507         #
508         # Perform EnumDomainGroups with size set to so that it contains
509         # 4 entries.
510         #
511         max_size = calc_max_size(4)
512         (resume_handle, actual, num_entries) = self.conn.EnumDomainGroups(
513             self.domain_handle, 0, max_size)
514         self.assertEqual(4, num_entries)
515         check_results(expected[:4], actual.entries)
516
517         #
518         # Try calling with resume_handle greater than number of entries
519         # Should return no results and a resume handle of 0
520         max_size = calc_max_size(1)
521         rh = len(expected)
522         self.conn.Close(self.handle)
523         (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
524             self.domain_handle, rh, max_size)
525
526         self.assertEqual(0, num_entries)
527         self.assertEqual(0, resume_handle)
528
529         #
530         # Enumerate through the domain groups one element at a time.
531         #
532         max_size = calc_max_size(1)
533         actual = []
534         (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
535             self.domain_handle, 0, max_size)
536         while resume_handle:
537             self.assertEqual(1, num_entries)
538             actual.append(a.entries[0])
539             (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
540                 self.domain_handle, resume_handle, max_size)
541         if num_entries:
542             actual.append(a.entries[0])
543
544         #
545         # Check that the cached results are being returned.
546         # Obtain a new resume_handle and insert new entries into the
547         # into the DB
548         #
549         actual = []
550         max_size = calc_max_size(1)
551         (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
552             self.domain_handle, 0, max_size)
553         extra_dns = self.create_groups([1000, 1002, 1003, 1004])
554         while resume_handle:
555             self.assertEqual(1, num_entries)
556             actual.append(a.entries[0])
557             (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
558                 self.domain_handle, resume_handle, max_size)
559         if num_entries:
560             actual.append(a.entries[0])
561
562         self.assertEqual(len(expected), len(actual))
563         check_results(expected, actual)
564
565         #
566         # Perform EnumDomainGroups, we should read the newly added domains
567         #
568         max_size = calc_max_size(len(expected) + len(extra_dns) + 10)
569         (resume_handle, actual, num_entries) = self.conn.EnumDomainGroups(
570             self.domain_handle, 0, max_size)
571         self.assertEqual(len(expected) + len(extra_dns), num_entries)
572
573         #
574         # Get a new expected result set by querying the database directly
575         unfiltered01 = self.samdb.search(expression=select, attrs=attributes)
576         filtered01 = self.filter_domain(unfiltered01)
577         self.assertTrue(len(filtered01) > len(expected))
578
579         # Sort the expected results by rid
580         expected01 = sorted(list(filtered01), key=rid)
581
582         #
583         # Now check that we read the new entries.
584         #
585         check_results(expected01, actual.entries)
586
587         #
588         # Check that deleted results are handled correctly.
589         # Obtain a new resume_handle and delete entries from the DB.
590         #
591         actual = []
592         max_size = calc_max_size(1)
593         (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
594             self.domain_handle, 0, max_size)
595         self.delete_dns(extra_dns)
596         while resume_handle and num_entries:
597             self.assertEqual(1, num_entries)
598             actual.append(a.entries[0])
599             (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
600                 self.domain_handle, resume_handle, max_size)
601         if num_entries:
602             actual.append(a.entries[0])
603
604         self.assertEqual(len(expected), len(actual))
605         check_results(expected, actual)
606
607         self.delete_dns(dns)
608
609     def test_EnumDomainUsers(self):
610         def check_results(expected, actual):
611             for (e, a) in zip(expected, actual):
612                 self.assertTrue(isinstance(a, samr.SamEntry))
613                 self.assertEqual(
614                     str(e["sAMAccountName"]), str(a.name.string))
615
616         # Create four users
617         # to ensure that we have the minimum needed for the tests.
618         dns = self.create_users([1, 2, 3, 4])
619
620         #
621         # Get the expected results by querying the samdb database directly.
622         # We do this rather than use a list of expected results as this runs
623         # with other tests so we do not have a known fixed list of elements
624         select = "(objectClass=user)"
625         attributes = ["sAMAccountName", "objectSID", "userAccountConrol"]
626         unfiltered = self.samdb.search(expression=select, attrs=attributes)
627         filtered = self.filter_domain(unfiltered)
628         self.assertTrue(len(filtered) > 4)
629
630         # Sort the expected results by rid
631         expected = sorted(list(filtered), key=rid)
632
633         #
634         # Perform EnumDomainUsers with max_size greater than required for the
635         # expected number of results. We should get all the results.
636         #
637         max_size = calc_max_size(len(expected) + 10)
638         (resume_handle, actual, num_entries) = self.conn.EnumDomainUsers(
639             self.domain_handle, 0, 0, max_size)
640         self.assertEqual(len(expected), num_entries)
641         check_results(expected, actual.entries)
642
643         #
644         # Perform EnumDomainUsers with size set to so that it contains
645         # 4 entries.
646         max_size = calc_max_size(4)
647         (resume_handle, actual, num_entries) = self.conn.EnumDomainUsers(
648             self.domain_handle, 0, 0, max_size)
649         self.assertEqual(4, num_entries)
650         check_results(expected[:4], actual.entries)
651
652         #
653         # Try calling with resume_handle greater than number of entries
654         # Should return no results and a resume handle of 0
655         rh = len(expected)
656         max_size = calc_max_size(1)
657         self.conn.Close(self.handle)
658         (resume_handle, a, num_entries) = self.conn.EnumDomainUsers(
659             self.domain_handle, rh, 0, max_size)
660
661         self.assertEqual(0, num_entries)
662         self.assertEqual(0, resume_handle)
663
664         #
665         # Enumerate through the domain users one element at a time.
666         # We should get all the results.
667         #
668         actual = []
669         max_size = calc_max_size(1)
670         (resume_handle, a, num_entries) = self.conn.EnumDomainUsers(
671             self.domain_handle, 0, 0, max_size)
672         while resume_handle:
673             self.assertEqual(1, num_entries)
674             actual.append(a.entries[0])
675             (resume_handle, a, num_entries) = self.conn.EnumDomainUsers(
676                 self.domain_handle, resume_handle, 0, max_size)
677         if num_entries:
678             actual.append(a.entries[0])
679
680         self.assertEqual(len(expected), len(actual))
681         check_results(expected, actual)
682
683         #
684         # Check that the cached results are being returned.
685         # Obtain a new resume_handle and insert new entries into the
686         # into the DB. As the entries were added after the results were cached
687         # they should not show up in the returned results.
688         #
689         actual = []
690         max_size = calc_max_size(1)
691         (resume_handle, a, num_entries) = self.conn.EnumDomainUsers(
692             self.domain_handle, 0, 0, max_size)
693         extra_dns = self.create_users([1000, 1002, 1003, 1004])
694         while resume_handle:
695             self.assertEqual(1, num_entries)
696             actual.append(a.entries[0])
697             (resume_handle, a, num_entries) = self.conn.EnumDomainUsers(
698                 self.domain_handle, resume_handle, 0, max_size)
699         if num_entries:
700             actual.append(a.entries[0])
701
702         self.assertEqual(len(expected), len(actual))
703         check_results(expected, actual)
704
705         #
706         # Perform EnumDomainUsers, we should read the newly added groups
707         # As resume_handle is zero, the results will be read from disk.
708         #
709         max_size = calc_max_size(len(expected) + len(extra_dns) + 10)
710         (resume_handle, actual, num_entries) = self.conn.EnumDomainUsers(
711             self.domain_handle, 0, 0, max_size)
712         self.assertEqual(len(expected) + len(extra_dns), num_entries)
713
714         #
715         # Get a new expected result set by querying the database directly
716         unfiltered01 = self.samdb.search(expression=select, attrs=attributes)
717         filtered01 = self.filter_domain(unfiltered01)
718         self.assertTrue(len(filtered01) > len(expected))
719
720         # Sort the expected results by rid
721         expected01 = sorted(list(filtered01), key=rid)
722
723         #
724         # Now check that we read the new entries.
725         #
726         self.assertEqual(len(expected01), num_entries)
727         check_results(expected01, actual.entries)
728
729         self.delete_dns(dns + extra_dns)
730
731     def test_DomGeneralInformation_num_users(self):
732         info = self.conn.QueryDomainInfo(
733             self.domain_handle, DomainGeneralInformation)
734         #
735         # Enumerate through all the domain users and compare the number
736         # returned against QueryDomainInfo they should be the same
737         max_size = calc_max_size(1)
738         (resume_handle, a, num_entries) = self.conn.EnumDomainUsers(
739             self.domain_handle, 0, 0, max_size)
740         count = num_entries
741         while resume_handle:
742             self.assertEqual(1, num_entries)
743             (resume_handle, a, num_entries) = self.conn.EnumDomainUsers(
744                 self.domain_handle, resume_handle, 0, max_size)
745             count += num_entries
746
747         self.assertEqual(count, info.num_users)
748
749     def test_DomGeneralInformation_num_groups(self):
750         info = self.conn.QueryDomainInfo(
751             self.domain_handle, DomainGeneralInformation)
752         #
753         # Enumerate through all the domain groups and compare the number
754         # returned against QueryDomainInfo they should be the same
755         max_size = calc_max_size(1)
756         (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
757             self.domain_handle, 0, max_size)
758         count = num_entries
759         while resume_handle:
760             self.assertEqual(1, num_entries)
761             (resume_handle, a, num_entries) = self.conn.EnumDomainGroups(
762                 self.domain_handle, resume_handle, max_size)
763             count += num_entries
764
765         self.assertEqual(count, info.num_groups)
766
767     def test_DomGeneralInformation_num_aliases(self):
768         info = self.conn.QueryDomainInfo(
769             self.domain_handle, DomainGeneralInformation)
770         #
771         # Enumerate through all the domain aliases and compare the number
772         # returned against QueryDomainInfo they should be the same
773         max_size = calc_max_size(1)
774         (resume_handle, a, num_entries) = self.conn.EnumDomainAliases(
775             self.domain_handle, 0, max_size)
776         count = num_entries
777         while resume_handle:
778             self.assertEqual(1, num_entries)
779             (resume_handle, a, num_entries) = self.conn.EnumDomainAliases(
780                 self.domain_handle, resume_handle, max_size)
781             count += num_entries
782
783         self.assertEqual(count, info.num_aliases)