torture/drs: move ExopBaseTest into DrsBaseTest and extend
[samba.git] / source4 / torture / drs / python / drs_base.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 #
4 # Unix SMB/CIFS implementation.
5 # Copyright (C) Kamen Mazdrashki <kamenim@samba.org> 2011
6 # Copyright (C) Andrew Bartlett <abartlet@samba.org> 2016
7 # Copyright (C) Catalyst IT Ltd. 2016
8 #
9 # This program is free software; you can redistribute it and/or modify
10 # it under the terms of the GNU General Public License as published by
11 # the Free Software Foundation; either version 3 of the License, or
12 # (at your option) any later version.
13 #
14 # This program is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17 # GNU General Public License for more details.
18 #
19 # You should have received a copy of the GNU General Public License
20 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
21 #
22
23 import sys
24 import time
25 import os
26 import ldb
27
28 sys.path.insert(0, "bin/python")
29 import samba.tests
30 from samba.tests.samba_tool.base import SambaToolCmdTest
31 from samba import dsdb
32 from samba.dcerpc import drsuapi, misc, drsblobs, security
33 from samba.ndr import ndr_unpack, ndr_pack
34 from samba.drs_utils import drs_DsBind
35
36 from ldb import (
37     SCOPE_BASE,
38     Message,
39     FLAG_MOD_REPLACE,
40     )
41
42
class DrsBaseTestCase(SambaToolCmdTest):
    """Base class implementation for all DRS python tests.

    It is intended to provide common initialization and functionality
    used by all DRS tests in the drs/python test package.  For instance,
    the DC1 and DC2 environment variables are always used to pass the
    URLs of the DCs to test against.

    Subclasses that call _check_replication() without explicit
    drs/drs_handle arguments must set self.drs and self.drs_handle
    (for example from _ds_bind()).  They may also set self.default_hwm
    to override the high-water mark used when none is passed in.
    """

    def setUp(self):
        super(DrsBaseTestCase, self).setUp()

        # connect to both DCs (LDAP only)
        url_dc = samba.tests.env_get_var_value("DC1")
        (self.ldb_dc1, self.info_dc1) = samba.tests.connect_samdb_ex(url_dc,
                                                                     ldap_only=True)
        url_dc = samba.tests.env_get_var_value("DC2")
        (self.ldb_dc2, self.info_dc2) = samba.tests.connect_samdb_ex(url_dc,
                                                                     ldap_only=True)

        # cache some of DC1's RootDSE props
        self.schema_dn = self.info_dc1["schemaNamingContext"][0]
        self.domain_dn = self.info_dc1["defaultNamingContext"][0]
        self.config_dn = self.info_dc1["configurationNamingContext"][0]
        self.forest_level = int(self.info_dc1["forestFunctionality"][0])

        # we will need DCs DNS names for 'samba-tool drs' command
        self.dnsname_dc1 = self.info_dc1["dnsHostName"][0]
        self.dnsname_dc2 = self.info_dc2["dnsHostName"][0]

    def tearDown(self):
        super(DrsBaseTestCase, self).tearDown()

    def _GUID_string(self, guid):
        """Format an objectGUID value as a string using DC1's schema."""
        return self.ldb_dc1.schema_format_value("objectGUID", guid)

    def _ldap_schemaUpdateNow(self, sam_db):
        """Write schemaUpdateNow=1 to the rootDSE ("") of *sam_db*."""
        rec = {"dn": "",
               "schemaUpdateNow": "1"}
        m = Message.from_dict(sam_db, rec, FLAG_MOD_REPLACE)
        sam_db.modify(m)

    def _deleted_objects_dn(self, sam_ldb):
        """Return the DN of the domain NC's Deleted Objects container.

        The container is located through its well-known GUID, and the
        search needs the show_deleted control because the container is
        itself a deleted object.
        """
        wkdn = "<WKGUID=18E2EA80684F11D2B9AA00C04F79F805,%s>" % self.domain_dn
        res = sam_ldb.search(base=wkdn,
                             scope=SCOPE_BASE,
                             controls=["show_deleted:1"])
        self.assertEqual(len(res), 1)
        return str(res[0]["dn"])

    def _lost_and_found_dn(self, sam_ldb, nc):
        """Return the DN of the LostAndFound container of naming context *nc*."""
        wkdn = "<WKGUID=%s,%s>" % (dsdb.DS_GUID_LOSTANDFOUND_CONTAINER, nc)
        res = sam_ldb.search(base=wkdn,
                             scope=SCOPE_BASE)
        self.assertEqual(len(res), 1)
        return str(res[0]["dn"])

    def _make_obj_name(self, prefix):
        """Return *prefix* with the current Unix time in seconds appended."""
        # time.strftime("%s", ...) is a non-portable libc extension and,
        # fed a gmtime() struct, can be skewed by the local UTC offset.
        return prefix + str(int(time.time()))

    def _samba_tool_cmdline(self, drs_command):
        """Build 'bin/samba-tool drs <drs_command> -U<domain>/<user>%<password>'."""
        # locate the samba-tool binary relative to the build directory
        samba_tool_cmd = os.path.abspath("./bin/samba-tool")
        # make command line credentials string
        # NOTE(review): the password is embedded unquoted; if check_run /
        # check_output go through a shell, special characters will break it.
        creds = self.get_credentials()
        cmdline_auth = "-U%s/%s%%%s" % (creds.get_domain(),
                                        creds.get_username(), creds.get_password())
        # bin/samba-tool drs <drs_command> <cmdline_auth>
        return "%s drs %s %s" % (samba_tool_cmd, drs_command, cmdline_auth)

    def _net_drs_replicate(self, DC, fromDC, nc_dn=None, forced=True, local=False, full_sync=False):
        """Run 'samba-tool drs replicate DC fromDC nc_dn' and return check_output()'s result.

        :param DC: destination DC name
        :param fromDC: source DC name
        :param nc_dn: naming context to replicate (defaults to the domain NC)
        :param forced: add --sync-forced
        :param local: add --local
        :param full_sync: add --full-sync
        """
        if nc_dn is None:
            nc_dn = self.domain_dn
        # make base command line
        samba_tool_cmdline = self._samba_tool_cmdline("replicate")
        if forced:
            samba_tool_cmdline += " --sync-forced"
        if local:
            samba_tool_cmdline += " --local"
        if full_sync:
            samba_tool_cmdline += " --full-sync"
        # bin/samba-tool drs replicate <Dest_DC_NAME> <Src_DC_NAME> <Naming Context>
        cmd_line = "%s %s %s %s" % (samba_tool_cmdline, DC, fromDC, nc_dn)
        return self.check_output(cmd_line)

    def _set_dsa_options(self, DC, *dsa_options):
        """Run 'samba-tool drs options DC --dsa-option=OPT' once per option, in order."""
        samba_tool_cmd = self._samba_tool_cmdline("options")
        for option in dsa_options:
            self.check_run("%s %s --dsa-option=%s" % (samba_tool_cmd, DC, option))

    def _enable_inbound_repl(self, DC):
        """Clear the DISABLE_INBOUND_REPL DSA option on *DC*."""
        self._set_dsa_options(DC, "-DISABLE_INBOUND_REPL")

    def _disable_inbound_repl(self, DC):
        """Set the DISABLE_INBOUND_REPL DSA option on *DC*."""
        self._set_dsa_options(DC, "+DISABLE_INBOUND_REPL")

    def _enable_all_repl(self, DC):
        """Clear both DISABLE_INBOUND_REPL and DISABLE_OUTBOUND_REPL on *DC*."""
        self._set_dsa_options(DC, "-DISABLE_INBOUND_REPL", "-DISABLE_OUTBOUND_REPL")

    def _disable_all_repl(self, DC):
        """Set both DISABLE_INBOUND_REPL and DISABLE_OUTBOUND_REPL on *DC*."""
        self._set_dsa_options(DC, "+DISABLE_INBOUND_REPL", "+DISABLE_OUTBOUND_REPL")

    def _get_highest_hwm_utdv(self, ldb_conn):
        """Return (hwm, utdv) built from *ldb_conn*'s highestCommittedUSN.

        The high-water mark's tmp_highest_usn/highest_usn and the single
        up-to-dateness cursor (keyed by the server's own invocationId) are
        all set to that USN.
        """
        res = ldb_conn.search("", scope=ldb.SCOPE_BASE, attrs=["highestCommittedUSN"])
        hwm = drsuapi.DsReplicaHighWaterMark()
        hwm.tmp_highest_usn = long(res[0]["highestCommittedUSN"][0])
        hwm.reserved_usn = 0
        hwm.highest_usn = hwm.tmp_highest_usn

        c1 = drsuapi.DsReplicaCursor()
        c1.source_dsa_invocation_id = misc.GUID(ldb_conn.get_invocation_id())
        c1.highest_usn = hwm.highest_usn

        utdv = drsuapi.DsReplicaCursorCtrEx()
        utdv.count = 1
        utdv.cursors = [c1]
        return (hwm, utdv)

    def _get_indentifier(self, ldb_conn, dn):
        """Return a DsReplicaObjectIdentifier (guid, optional sid, dn) for *dn*.

        The misspelled name is kept for existing callers; _get_identifier
        is an alias.
        """
        res = ldb_conn.search(dn, scope=ldb.SCOPE_BASE,
                              attrs=["objectGUID", "objectSid"])
        obj_id = drsuapi.DsReplicaObjectIdentifier()
        obj_id.guid = ndr_unpack(misc.GUID, res[0]['objectGUID'][0])
        if "objectSid" in res[0]:
            obj_id.sid = ndr_unpack(security.dom_sid, res[0]['objectSid'][0])
        obj_id.dn = str(res[0].dn)
        return obj_id

    _get_identifier = _get_indentifier

    def _check_replication(self, expected_dns, replica_flags, expected_links=None,
                           drs_error=drsuapi.DRSUAPI_EXOP_ERR_NONE, drs=None, drs_handle=None,
                           highwatermark=None, uptodateness_vector=None,
                           more_flags=0, more_data=False,
                           dn_ordered=True, links_ordered=True,
                           max_objects=133, exop=0,
                           dest_dsa=drsuapi.DRSUAPI_DS_BIND_GUID_W2K3,
                           source_dsa=None, invocation_id=None, nc_dn_str=None,
                           nc_object_count=0, nc_linked_attributes_count=0):
        """Send a v10 DsGetNCChanges request and check the ctr6 reply.

        Asserts that the reply is level 6, comes from *source_dsa* /
        *invocation_id* (DC1's by default), has extended_ret == *drs_error*,
        and matches the expectations checked by _check_ctr6().

        :return: (new_highwatermark, uptodateness_vector) from the reply,
                 suitable for feeding into a follow-up call.
        """
        if expected_links is None:
            expected_links = []
        if source_dsa is None:
            source_dsa = self.ldb_dc1.get_ntds_GUID()
        if invocation_id is None:
            invocation_id = self.ldb_dc1.get_invocation_id()
        if nc_dn_str is None:
            nc_dn_str = self.ldb_dc1.domain_dn()

        if highwatermark is None:
            # subclasses may preset default_hwm; fall back to DC1's current USN
            default_hwm = getattr(self, "default_hwm", None)
            if default_hwm is None:
                (highwatermark, _) = self._get_highest_hwm_utdv(self.ldb_dc1)
            else:
                highwatermark = default_hwm

        if drs is None:
            drs = self.drs
        if drs_handle is None:
            drs_handle = self.drs_handle

        req10 = self._getnc_req10(dest_dsa=dest_dsa,
                                  invocation_id=invocation_id,
                                  nc_dn_str=nc_dn_str,
                                  exop=exop,
                                  max_objects=max_objects,
                                  replica_flags=replica_flags,
                                  more_flags=more_flags)
        req10.highwatermark = highwatermark
        if uptodateness_vector is not None:
            # Copy the cursors into a DsReplicaCursorCtrEx; presumably the
            # caller may pass a reply's vector, whose cursor type differs.
            utdv_v1 = drsuapi.DsReplicaCursorCtrEx()
            cursors = []
            for i in range(uptodateness_vector.count):
                src = uptodateness_vector.cursors[i]
                c1 = drsuapi.DsReplicaCursor()
                c1.source_dsa_invocation_id = src.source_dsa_invocation_id
                c1.highest_usn = src.highest_usn
                cursors.append(c1)
            utdv_v1.count = len(cursors)
            utdv_v1.cursors = cursors
            req10.uptodateness_vector = utdv_v1
        (level, ctr6) = drs.DsGetNCChanges(drs_handle, 10, req10)

        self.assertEqual(level, 6, "expected level 6 response!")
        self.assertEqual(ctr6.source_dsa_guid, misc.GUID(source_dsa))
        self.assertEqual(ctr6.source_dsa_invocation_id, misc.GUID(invocation_id))
        self.assertEqual(ctr6.extended_ret, drs_error)
        self._check_ctr6(ctr6, expected_dns, expected_links,
                         dn_ordered=dn_ordered,
                         links_ordered=links_ordered,
                         more_data=more_data,
                         nc_object_count=nc_object_count,
                         nc_linked_attributes_count=nc_linked_attributes_count)
        return (ctr6.new_highwatermark, ctr6.uptodateness_vector)

    def _check_ctr6(self, ctr6, expected_dns=None, expected_links=None,
                    dn_ordered=True, links_ordered=True,
                    more_data=False, nc_object_count=0,
                    nc_linked_attributes_count=0, drs_error=0):
        """Check that a ctr6 matches the specified parameters.

        :param expected_dns: DNs expected in the reply; compared in order
            when *dn_ordered*, otherwise by membership.
        :param expected_links: AbstractLink objects expected in the reply.
            A sorted copy is compared in order against the reply when
            *links_ordered*, otherwise by membership.  The caller's list
            is not modified.
        :param drs_error: expected value of ctr6.drs_error[0].
        """
        if expected_dns is None:
            expected_dns = []
        if expected_links is None:
            expected_links = []

        self.assertEqual(ctr6.object_count, len(expected_dns))
        self.assertEqual(ctr6.linked_attributes_count, len(expected_links))
        self.assertEqual(ctr6.more_data, more_data)
        self.assertEqual(ctr6.nc_object_count, nc_object_count)
        self.assertEqual(ctr6.nc_linked_attributes_count, nc_linked_attributes_count)
        self.assertEqual(ctr6.drs_error[0], drs_error)

        # Walk the object chain; it must end exactly after object_count items.
        ctr6_dns = []
        next_object = ctr6.first_object
        for _ in range(ctr6.object_count):
            ctr6_dns.append(next_object.object.identifier.dn)
            next_object = next_object.next_object
        self.assertEqual(next_object, None)

        for i, dn in enumerate(expected_dns):
            if dn_ordered:
                # Expect them back in the exact same order as specified.
                self.assertNotEqual(ctr6_dns[i], None)
                self.assertEqual(ctr6_dns[i], dn)
            else:
                # Don't care what order
                self.assertTrue(dn in ctr6_dns, "Couldn't find DN '%s' anywhere in ctr6 response." % dn)

        # Sort a copy so the caller's list is left untouched.
        expected_links = sorted(expected_links)

        ctr6_links = []
        for lidx in range(ctr6.linked_attributes_count):
            l = ctr6.linked_attributes[lidx]
            # Plain DN-valued links unpack as DsReplicaObjectIdentifier3;
            # fall back to the Binary variant if that fails.
            try:
                target = ndr_unpack(drsuapi.DsReplicaObjectIdentifier3,
                                    l.value.blob)
            except Exception:
                target = ndr_unpack(drsuapi.DsReplicaObjectIdentifier3Binary,
                                    l.value.blob)
            ctr6_links.append(AbstractLink(l.attid, l.flags,
                                           l.identifier.guid,
                                           target.guid))

        for lidx, el in enumerate(expected_links):
            if links_ordered:
                self.assertEqual(el, ctr6_links[lidx])
            else:
                self.assertTrue(el in ctr6_links, "Couldn't find link '%s' anywhere in ctr6 response." % el)

    def _fill_getnc_req(self, req, dest_dsa, invocation_id, nc_dn_str, exop,
                        replica_flags, max_objects, partial_attribute_set,
                        partial_attribute_set_ex, mapping_ctr):
        """Populate the fields shared by GetNCChanges request v8 and v10; return *req*."""
        req.destination_dsa_guid = misc.GUID(dest_dsa) if dest_dsa else misc.GUID()
        req.source_dsa_invocation_id = misc.GUID(invocation_id)
        req.naming_context = drsuapi.DsReplicaObjectIdentifier()
        req.naming_context.dn = unicode(nc_dn_str)
        # start from USN 0 with no up-to-dateness vector
        req.highwatermark = drsuapi.DsReplicaHighWaterMark()
        req.highwatermark.tmp_highest_usn = 0
        req.highwatermark.reserved_usn = 0
        req.highwatermark.highest_usn = 0
        req.uptodateness_vector = None
        req.replica_flags = replica_flags
        req.max_object_count = max_objects
        req.max_ndr_size = 402116
        req.extended_op = exop
        req.fsmo_info = 0
        req.partial_attribute_set = partial_attribute_set
        req.partial_attribute_set_ex = partial_attribute_set_ex
        if mapping_ctr:
            req.mapping_ctr = mapping_ctr
        else:
            req.mapping_ctr.num_mappings = 0
            req.mapping_ctr.mappings = None
        return req

    def _exop_req8(self, dest_dsa, invocation_id, nc_dn_str, exop,
                   replica_flags=0, max_objects=0, partial_attribute_set=None,
                   partial_attribute_set_ex=None, mapping_ctr=None):
        """Build a DsGetNCChangesRequest8 starting from USN 0."""
        return self._fill_getnc_req(drsuapi.DsGetNCChangesRequest8(),
                                    dest_dsa, invocation_id, nc_dn_str, exop,
                                    replica_flags, max_objects,
                                    partial_attribute_set,
                                    partial_attribute_set_ex, mapping_ctr)

    def _getnc_req10(self, dest_dsa, invocation_id, nc_dn_str, exop,
                     replica_flags=0, max_objects=0, partial_attribute_set=None,
                     partial_attribute_set_ex=None, mapping_ctr=None,
                     more_flags=0):
        """Build a DsGetNCChangesRequest10 (v8 fields plus more_flags) from USN 0."""
        req10 = self._fill_getnc_req(drsuapi.DsGetNCChangesRequest10(),
                                     dest_dsa, invocation_id, nc_dn_str, exop,
                                     replica_flags, max_objects,
                                     partial_attribute_set,
                                     partial_attribute_set_ex, mapping_ctr)
        req10.more_flags = more_flags
        return req10

    def _ds_bind(self, server_name):
        """Open a sealed ncacn_ip_tcp DRSUAPI connection and DsBind to it.

        :return: (drs, drs_handle); the supported extensions are discarded.
        """
        binding_str = "ncacn_ip_tcp:%s[seal]" % server_name

        drs = drsuapi.drsuapi(binding_str, self.get_loadparm(), self.get_credentials())
        (drs_handle, _supported_extensions) = drs_DsBind(drs)
        return (drs, drs_handle)
361
362
class AbstractLink(object):
    """Comparable summary of one linked-attribute value.

    Holds the attid, the flags, the source object's GUID (identifier) and
    the target GUID.  The ordering is intended to follow CompareLinks()
    (MS-DRSR 4.1.10.5.17), so that a sorted list of expected links can
    be compared element-by-element with the links a DC returns.
    Written to run under both Python 2 and Python 3.
    """

    def __init__(self, attid, flags, identifier, targetGUID):
        self.attid = attid
        self.flags = flags
        # String forms feed repr()/hash(); the NDR-packed blobs are what
        # the ordering compares byte-wise.
        self.identifier = str(identifier)
        self.selfGUID_blob = ndr_pack(identifier)
        self.targetGUID = str(targetGUID)
        self.targetGUID_blob = ndr_pack(targetGUID)

    def __repr__(self):
        return "AbstractLink(0x%08x, 0x%08x, %s, %s)" % (
                self.attid, self.flags, self.identifier, self.targetGUID)

    @staticmethod
    def _cmp(a, b):
        """Python-2 style cmp(): -1, 0 or 1 (cmp() is gone in Python 3)."""
        return (a > b) - (a < b)

    def _report(self, other, c, reason, verbose):
        """Print why *self* and *other* differ when *verbose*, then return *c*."""
        if verbose:
            print("AbstractLink.__internal_cmp__(%r, %r) => %d %s" % (self, other, c, reason))
        return c

    def __internal_cmp__(self, other, verbose=False):
        """See CompareLinks() in MS-DRSR section 4.1.10.5.17

        Returns a negative, zero or positive number, or NotImplemented when
        *other* is not an AbstractLink.  Keys are compared in this order:
        source identifier, attid, FLAG_ACTIVE bit, target GUID, full flags.
        """
        if not isinstance(other, AbstractLink):
            if verbose:
                print("AbstractLink.__internal_cmp__(%r, %r) => wrong type" % (self, other))
            return NotImplemented

        c = self._cmp(self.selfGUID_blob, other.selfGUID_blob)
        if c != 0:
            return self._report(other, c, "different identifier", verbose)

        # NOTE(review): unlike every other key this is (other - self), so
        # attids sort in descending order.  Confirm against CompareLinks
        # before changing: _check_ctr6 compares sorted expected links
        # against the server's order.
        c = other.attid - self.attid
        if c != 0:
            return self._report(other, c, "different attid", verbose)

        self_active = self.flags & drsuapi.DRSUAPI_DS_LINKED_ATTRIBUTE_FLAG_ACTIVE
        other_active = other.flags & drsuapi.DRSUAPI_DS_LINKED_ATTRIBUTE_FLAG_ACTIVE

        c = self_active - other_active
        if c != 0:
            return self._report(other, c, "different FLAG_ACTIVE", verbose)

        c = self._cmp(self.targetGUID_blob, other.targetGUID_blob)
        if c != 0:
            return self._report(other, c, "different target", verbose)

        c = self.flags - other.flags
        if c != 0:
            return self._report(other, c, "different flags", verbose)

        return 0

    def __lt__(self, other):
        c = self.__internal_cmp__(other)
        if c is NotImplemented:
            return NotImplemented
        return c < 0

    def __le__(self, other):
        c = self.__internal_cmp__(other)
        if c is NotImplemented:
            return NotImplemented
        return c <= 0

    def __eq__(self, other):
        # verbose=True prints the first differing key, which helps when an
        # assertEqual on links fails.
        c = self.__internal_cmp__(other, verbose=True)
        if c is NotImplemented:
            return NotImplemented
        return c == 0

    def __ne__(self, other):
        c = self.__internal_cmp__(other)
        if c is NotImplemented:
            return NotImplemented
        return c != 0

    def __gt__(self, other):
        c = self.__internal_cmp__(other)
        if c is NotImplemented:
            return NotImplemented
        return c > 0

    def __ge__(self, other):
        c = self.__internal_cmp__(other)
        if c is NotImplemented:
            return NotImplemented
        return c >= 0

    def __hash__(self):
        return hash((self.attid, self.flags, self.identifier, self.targetGUID))