getnc_exop.py: Fix GET_TGT behaviour in DRS tests
[nivanova/samba-autobuild/.git] / source4 / torture / drs / python / drs_base.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 #
4 # Unix SMB/CIFS implementation.
5 # Copyright (C) Kamen Mazdrashki <kamenim@samba.org> 2011
6 # Copyright (C) Andrew Bartlett <abartlet@samba.org> 2016
7 # Copyright (C) Catalyst IT Ltd. 2016
8 #
9 # This program is free software; you can redistribute it and/or modify
10 # it under the terms of the GNU General Public License as published by
11 # the Free Software Foundation; either version 3 of the License, or
12 # (at your option) any later version.
13 #
14 # This program is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17 # GNU General Public License for more details.
18 #
19 # You should have received a copy of the GNU General Public License
20 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
21 #
22
23 import sys
24 import time
25 import os
26 import ldb
27
28 sys.path.insert(0, "bin/python")
29 import samba.tests
30 from samba.tests.samba_tool.base import SambaToolCmdTest
31 from samba import dsdb
32 from samba.dcerpc import drsuapi, misc, drsblobs, security
33 from samba.ndr import ndr_unpack, ndr_pack
34 from samba.drs_utils import drs_DsBind
35 from samba import gensec
36 from ldb import (
37     SCOPE_BASE,
38     Message,
39     FLAG_MOD_REPLACE,
40     )
41
42
class DrsBaseTestCase(SambaToolCmdTest):
    """Base class implementation for all DRS python tests.

    It provides common initialization and functionality used by all DRS
    tests in the drs/python test package.  For instance, the DC1 and DC2
    environment variables are always used to pass the URLs of the DCs to
    test against.

    The DsGetNCChanges helpers (_get_replication(), _check_replication())
    fall back to self.drs and self.drs_handle when no connection is passed
    in; setUp() does not set those, so subclasses must (e.g. via
    _ds_bind()).  self.default_hwm is optional.
    """

    def setUp(self):
        super(DrsBaseTestCase, self).setUp()
        creds = self.get_credentials()
        # Ask for sealed connections when using these credentials.
        creds.set_gensec_features(creds.get_gensec_features() | gensec.FEATURE_SEAL)

        # connect to DCs (LDAP only)
        url_dc = samba.tests.env_get_var_value("DC1")
        (self.ldb_dc1, self.info_dc1) = samba.tests.connect_samdb_ex(url_dc,
                                                                     ldap_only=True)
        url_dc = samba.tests.env_get_var_value("DC2")
        (self.ldb_dc2, self.info_dc2) = samba.tests.connect_samdb_ex(url_dc,
                                                                     ldap_only=True)

        # cache some of RootDSE props
        self.schema_dn = self.info_dc1["schemaNamingContext"][0]
        self.domain_dn = self.info_dc1["defaultNamingContext"][0]
        self.config_dn = self.info_dc1["configurationNamingContext"][0]
        self.forest_level = int(self.info_dc1["forestFunctionality"][0])

        # we will need DCs DNS names for 'samba-tool drs' command
        self.dnsname_dc1 = self.info_dc1["dnsHostName"][0]
        self.dnsname_dc2 = self.info_dc2["dnsHostName"][0]

        # for debugging the test code (see _ctr6_debug())
        self._debug = False

    def tearDown(self):
        super(DrsBaseTestCase, self).tearDown()

    def _GUID_string(self, guid):
        """Format a binary objectGUID value as a string via DC1's schema."""
        return self.ldb_dc1.schema_format_value("objectGUID", guid)

    def _ldap_schemaUpdateNow(self, sam_db):
        """Write schemaUpdateNow=1 on the rootDSE of sam_db."""
        rec = {"dn": "",
               "schemaUpdateNow": "1"}
        m = Message.from_dict(sam_db, rec, FLAG_MOD_REPLACE)
        sam_db.modify(m)

    def _deleted_objects_dn(self, sam_ldb):
        """Return the DN of the domain's Deleted Objects container.

        The container is looked up through its well-known GUID under
        self.domain_dn, with the show_deleted control.
        """
        wkdn = "<WKGUID=18E2EA80684F11D2B9AA00C04F79F805,%s>" % self.domain_dn
        res = sam_ldb.search(base=wkdn,
                             scope=SCOPE_BASE,
                             controls=["show_deleted:1"])
        self.assertEqual(len(res), 1)
        return str(res[0]["dn"])

    def _lost_and_found_dn(self, sam_ldb, nc):
        """Return the DN of the LostAndFound container of naming context nc."""
        wkdn = "<WKGUID=%s,%s>" % (dsdb.DS_GUID_LOSTANDFOUND_CONTAINER, nc)
        res = sam_ldb.search(base=wkdn,
                             scope=SCOPE_BASE)
        self.assertEqual(len(res), 1)
        return str(res[0]["dn"])

    def _make_obj_name(self, prefix):
        """Return prefix followed by the current Unix time in whole seconds."""
        # time.strftime("%s", ...) relies on a non-portable platform
        # extension and treats a gmtime() struct as local time; use the
        # epoch seconds directly instead.
        return prefix + str(int(time.time()))

    def _samba_tool_cmd_list(self, drs_command):
        """Return the base argv for 'samba-tool drs <drs_command>'."""
        ccache_name = self.get_creds_ccache_name()

        # Tunnel the command line credentials down to the
        # subcommand to avoid a new kinit
        cmdline_auth = "--krb5-ccache=%s" % ccache_name

        # bin/samba-tool drs <drs_command> <cmdline_auth>
        return ["drs", drs_command, cmdline_auth]

    def _run_samba_tool_cmd(self, samba_tool_cmd):
        """Run a samba-tool sub-command; assert success and empty stderr."""
        (result, out, err) = self.runsubcmd(*samba_tool_cmd)
        self.assertCmdSuccess(result, out, err)
        self.assertEqual(err, "", "Shouldn't be any error messages")

    def _net_drs_replicate(self, DC, fromDC, nc_dn=None, forced=True, local=False, full_sync=False):
        """Run 'samba-tool drs replicate DC fromDC nc_dn' and assert it succeeded.

        nc_dn defaults to the domain naming context.  forced, local and
        full_sync add --sync-forced, --local and --full-sync respectively.
        """
        if nc_dn is None:
            nc_dn = self.domain_dn
        # bin/samba-tool drs replicate <Dest_DC_NAME> <Src_DC_NAME> <Naming Context>
        samba_tool_cmdline = self._samba_tool_cmd_list("replicate")
        samba_tool_cmdline += [DC, fromDC, nc_dn]

        if forced:
            samba_tool_cmdline += ["--sync-forced"]
        if local:
            samba_tool_cmdline += ["--local"]
        if full_sync:
            samba_tool_cmdline += ["--full-sync"]

        self._run_samba_tool_cmd(samba_tool_cmdline)

    def _enable_inbound_repl(self, DC):
        """Clear the DISABLE_INBOUND_REPL DSA option on DC."""
        samba_tool_cmd = self._samba_tool_cmd_list("options")
        # enable inbound replication
        samba_tool_cmd += [DC, "--dsa-option=-DISABLE_INBOUND_REPL"]
        self._run_samba_tool_cmd(samba_tool_cmd)

    def _disable_inbound_repl(self, DC):
        """Set the DISABLE_INBOUND_REPL DSA option on DC."""
        samba_tool_cmd = self._samba_tool_cmd_list("options")
        # disable inbound replication
        samba_tool_cmd += [DC, "--dsa-option=+DISABLE_INBOUND_REPL"]
        self._run_samba_tool_cmd(samba_tool_cmd)

    def _enable_all_repl(self, DC):
        """Clear both DISABLE_INBOUND_REPL and DISABLE_OUTBOUND_REPL on DC."""
        self._enable_inbound_repl(DC)
        samba_tool_cmd = self._samba_tool_cmd_list("options")
        # enable outbound replication
        samba_tool_cmd += [DC, "--dsa-option=-DISABLE_OUTBOUND_REPL"]
        self._run_samba_tool_cmd(samba_tool_cmd)

    def _disable_all_repl(self, DC):
        """Set both DISABLE_INBOUND_REPL and DISABLE_OUTBOUND_REPL on DC."""
        self._disable_inbound_repl(DC)
        samba_tool_cmd = self._samba_tool_cmd_list("options")
        # disable outbound replication
        samba_tool_cmd += [DC, "--dsa-option=+DISABLE_OUTBOUND_REPL"]
        self._run_samba_tool_cmd(samba_tool_cmd)

    def _get_highest_hwm_utdv(self, ldb_conn):
        """Return (highwatermark, uptodateness vector) for ldb_conn's current state.

        Both are built from the rootDSE highestCommittedUSN.  The vector
        holds a single cursor for ldb_conn's own invocation ID.
        """
        res = ldb_conn.search("", scope=ldb.SCOPE_BASE, attrs=["highestCommittedUSN"])
        hwm = drsuapi.DsReplicaHighWaterMark()
        hwm.tmp_highest_usn = long(res[0]["highestCommittedUSN"][0])
        hwm.reserved_usn = 0
        hwm.highest_usn = hwm.tmp_highest_usn

        cursor = drsuapi.DsReplicaCursor()
        cursor.source_dsa_invocation_id = misc.GUID(ldb_conn.get_invocation_id())
        cursor.highest_usn = hwm.highest_usn

        utdv = drsuapi.DsReplicaCursorCtrEx()
        cursors = [cursor]
        utdv.count = len(cursors)
        utdv.cursors = cursors
        return (hwm, utdv)

    def _get_identifier(self, ldb_conn, dn):
        """Build a DsReplicaObjectIdentifier (GUID, SID if any, DN) for dn."""
        res = ldb_conn.search(dn, scope=ldb.SCOPE_BASE,
                              attrs=["objectGUID", "objectSid"])
        identifier = drsuapi.DsReplicaObjectIdentifier()
        identifier.guid = ndr_unpack(misc.GUID, res[0]['objectGUID'][0])
        if "objectSid" in res[0]:
            identifier.sid = ndr_unpack(security.dom_sid, res[0]['objectSid'][0])
        identifier.dn = str(res[0].dn)
        return identifier

    def _get_ctr6_links(self, ctr6):
        """
        Unpacks the linked attributes from a DsGetNCChanges response
        and returns them as a list of AbstractLink objects.
        """
        ctr6_links = []
        for lidx in range(0, ctr6.linked_attributes_count):
            l = ctr6.linked_attributes[lidx]
            # The value is either a plain DN-style identifier or a
            # DN+binary one; try the plain form first.
            try:
                target = ndr_unpack(drsuapi.DsReplicaObjectIdentifier3,
                                    l.value.blob)
            except Exception:
                target = ndr_unpack(drsuapi.DsReplicaObjectIdentifier3Binary,
                                    l.value.blob)
            al = AbstractLink(l.attid, l.flags,
                              l.identifier.guid,
                              target.guid, target.dn)
            ctr6_links.append(al)

        return ctr6_links

    def _ctr6_debug(self, ctr6):
        """
        Displays basic info contained in a DsGetNCChanges response.
        Having this debug code allows us to see the difference in behaviour
        between Samba and Windows easier. Turn on the self._debug flag to see it.
        """
        if not self._debug:
            return

        print("------------ recvd CTR6 -------------")

        next_object = ctr6.first_object
        for i in range(0, ctr6.object_count):
            print("Obj %d: %s %s" % (i, next_object.object.identifier.dn[:22],
                                     next_object.object.identifier.guid))
            next_object = next_object.next_object

        print("Linked Attributes: %d" % ctr6.linked_attributes_count)
        for link in self._get_ctr6_links(ctr6):
            print("Link Tgt %s... <-- Src %s"
                  % (link.targetDN[:22], link.identifier))

        print("HWM:     %d" % (ctr6.new_highwatermark.highest_usn))
        print("Tmp HWM: %d" % (ctr6.new_highwatermark.tmp_highest_usn))
        print("More data: %d" % (ctr6.more_data))

    def _get_replication(self, replica_flags,
                         drs_error=drsuapi.DRSUAPI_EXOP_ERR_NONE, drs=None, drs_handle=None,
                         highwatermark=None, uptodateness_vector=None,
                         more_flags=0, max_objects=133, exop=0,
                         dest_dsa=drsuapi.DRSUAPI_DS_BIND_GUID_W2K3,
                         source_dsa=None, invocation_id=None, nc_dn_str=None):
        """
        Builds a DsGetNCChanges (level 10) request based on the information
        provided and returns the level 6 response received from the DC.

        source_dsa, invocation_id and nc_dn_str default to DC1's values.
        highwatermark defaults to self.default_hwm if set, otherwise to
        DC1's current highest committed USN.  drs/drs_handle default to
        self.drs/self.drs_handle.  The response's extended_ret is asserted
        to equal drs_error.
        """
        if source_dsa is None:
            source_dsa = self.ldb_dc1.get_ntds_GUID()
        if invocation_id is None:
            invocation_id = self.ldb_dc1.get_invocation_id()
        if nc_dn_str is None:
            nc_dn_str = self.ldb_dc1.domain_dn()

        if highwatermark is None:
            # default_hwm is optional; subclasses need not define it.
            default_hwm = getattr(self, "default_hwm", None)
            if default_hwm is None:
                (highwatermark, _) = self._get_highest_hwm_utdv(self.ldb_dc1)
            else:
                highwatermark = default_hwm

        if drs is None:
            drs = self.drs
        if drs_handle is None:
            drs_handle = self.drs_handle

        req10 = self._getnc_req10(dest_dsa=dest_dsa,
                                  invocation_id=invocation_id,
                                  nc_dn_str=nc_dn_str,
                                  exop=exop,
                                  max_objects=max_objects,
                                  replica_flags=replica_flags,
                                  more_flags=more_flags)
        req10.highwatermark = highwatermark
        if uptodateness_vector is not None:
            # Copy the caller's cursors into a fresh DsReplicaCursorCtrEx.
            uptodateness_vector_v1 = drsuapi.DsReplicaCursorCtrEx()
            cursors = []
            for i in range(0, uptodateness_vector.count):
                c = uptodateness_vector.cursors[i]
                c1 = drsuapi.DsReplicaCursor()
                c1.source_dsa_invocation_id = c.source_dsa_invocation_id
                c1.highest_usn = c.highest_usn
                cursors.append(c1)
            uptodateness_vector_v1.count = len(cursors)
            uptodateness_vector_v1.cursors = cursors
            req10.uptodateness_vector = uptodateness_vector_v1
        (level, ctr) = drs.DsGetNCChanges(drs_handle, 10, req10)
        self._ctr6_debug(ctr)

        self.assertEqual(level, 6, "expected level 6 response!")
        self.assertEqual(ctr.source_dsa_guid, misc.GUID(source_dsa))
        self.assertEqual(ctr.source_dsa_invocation_id, misc.GUID(invocation_id))
        self.assertEqual(ctr.extended_ret, drs_error)

        return ctr

    def _check_replication(self, expected_dns, replica_flags, expected_links=None,
                           drs_error=drsuapi.DRSUAPI_EXOP_ERR_NONE, drs=None, drs_handle=None,
                           highwatermark=None, uptodateness_vector=None,
                           more_flags=0, more_data=False,
                           dn_ordered=True, links_ordered=True,
                           max_objects=133, exop=0,
                           dest_dsa=drsuapi.DRSUAPI_DS_BIND_GUID_W2K3,
                           source_dsa=None, invocation_id=None, nc_dn_str=None,
                           nc_object_count=0, nc_linked_attributes_count=0):
        """
        Sends a DsGetNCChanges request and checks the response.

        The response's extended error must equal drs_error, and its
        objects/links/counters must match the expected values (see
        _check_ctr6()).  Returns (new_highwatermark, uptodateness_vector)
        from the response so callers can continue the replication cycle.
        """
        # send a DsGetNCChanges to the DC
        ctr6 = self._get_replication(replica_flags,
                                     drs_error, drs, drs_handle,
                                     highwatermark, uptodateness_vector,
                                     more_flags, max_objects, exop, dest_dsa,
                                     source_dsa, invocation_id, nc_dn_str)

        # check the response is what we expect (previously links_ordered
        # and nc_linked_attributes_count were accepted but never checked)
        self._check_ctr6(ctr6, expected_dns, expected_links,
                         dn_ordered=dn_ordered,
                         links_ordered=links_ordered,
                         more_data=more_data,
                         nc_object_count=nc_object_count,
                         nc_linked_attributes_count=nc_linked_attributes_count)
        return (ctr6.new_highwatermark, ctr6.uptodateness_vector)

    def _get_ctr6_dn_list(self, ctr6):
        """
        Returns the DNs contained in a DsGetNCChanges response, in order.
        """
        dn_list = []
        next_object = ctr6.first_object
        for _ in range(0, ctr6.object_count):
            dn_list.append(next_object.object.identifier.dn)
            next_object = next_object.next_object
        # object_count must match the length of the object chain
        self.assertEqual(next_object, None)

        return dn_list

    def _check_ctr6(self, ctr6, expected_dns=None, expected_links=None,
                    dn_ordered=True, links_ordered=True,
                    more_data=False, nc_object_count=0,
                    nc_linked_attributes_count=0, drs_error=0):
        """
        Check that a ctr6 matches the specified parameters.

        expected_dns are compared in order when dn_ordered, otherwise only
        for presence.  expected_links are sorted (on a copy; the caller's
        list is left untouched) and compared in order when links_ordered,
        otherwise only for presence.
        """
        if expected_dns is None:
            expected_dns = []
        if expected_links is None:
            expected_links = []

        self.assertEqual(ctr6.object_count, len(expected_dns))
        self.assertEqual(ctr6.linked_attributes_count, len(expected_links))
        self.assertEqual(ctr6.more_data, more_data)
        self.assertEqual(ctr6.nc_object_count, nc_object_count)
        self.assertEqual(ctr6.nc_linked_attributes_count, nc_linked_attributes_count)
        self.assertEqual(ctr6.drs_error[0], drs_error)

        ctr6_dns = self._get_ctr6_dn_list(ctr6)

        if dn_ordered:
            # Expect them back in the exact same order as specified.
            for (i, dn) in enumerate(expected_dns):
                self.assertNotEqual(ctr6_dns[i], None)
                self.assertEqual(ctr6_dns[i], dn)
        else:
            # Don't care what order
            for dn in expected_dns:
                self.assertTrue(dn in ctr6_dns, "Couldn't find DN '%s' anywhere in ctr6 response." % dn)

        # Extract the links from the response
        ctr6_links = self._get_ctr6_links(ctr6)
        sorted_expected_links = sorted(expected_links)

        if links_ordered:
            for (lidx, el) in enumerate(sorted_expected_links):
                self.assertEqual(el, ctr6_links[lidx])
        else:
            for el in sorted_expected_links:
                self.assertTrue(el in ctr6_links, "Couldn't find link '%s' anywhere in ctr6 response." % el)

    def _fill_getnc_req(self, req, dest_dsa, invocation_id, nc_dn_str, exop,
                        replica_flags, max_objects, partial_attribute_set,
                        partial_attribute_set_ex, mapping_ctr):
        """Populate the fields shared by DsGetNCChanges request v8 and v10; return req."""
        req.destination_dsa_guid = misc.GUID(dest_dsa) if dest_dsa else misc.GUID()
        req.source_dsa_invocation_id = misc.GUID(invocation_id)
        req.naming_context = drsuapi.DsReplicaObjectIdentifier()
        req.naming_context.dn = unicode(nc_dn_str)
        # Start from a zero highwatermark; callers may overwrite it.
        req.highwatermark = drsuapi.DsReplicaHighWaterMark()
        req.highwatermark.tmp_highest_usn = 0
        req.highwatermark.reserved_usn = 0
        req.highwatermark.highest_usn = 0
        req.uptodateness_vector = None
        req.replica_flags = replica_flags
        req.max_object_count = max_objects
        req.max_ndr_size = 402116
        req.extended_op = exop
        req.fsmo_info = 0
        req.partial_attribute_set = partial_attribute_set
        req.partial_attribute_set_ex = partial_attribute_set_ex
        if mapping_ctr:
            req.mapping_ctr = mapping_ctr
        else:
            req.mapping_ctr.num_mappings = 0
            req.mapping_ctr.mappings = None
        return req

    def _exop_req8(self, dest_dsa, invocation_id, nc_dn_str, exop,
                   replica_flags=0, max_objects=0, partial_attribute_set=None,
                   partial_attribute_set_ex=None, mapping_ctr=None):
        """Build a DsGetNCChangesRequest8 for nc_dn_str with extended op exop."""
        return self._fill_getnc_req(drsuapi.DsGetNCChangesRequest8(),
                                    dest_dsa, invocation_id, nc_dn_str, exop,
                                    replica_flags, max_objects,
                                    partial_attribute_set,
                                    partial_attribute_set_ex, mapping_ctr)

    def _getnc_req10(self, dest_dsa, invocation_id, nc_dn_str, exop,
                     replica_flags=0, max_objects=0, partial_attribute_set=None,
                     partial_attribute_set_ex=None, mapping_ctr=None,
                     more_flags=0):
        """Build a DsGetNCChangesRequest10 (v8 fields plus more_flags)."""
        req10 = self._fill_getnc_req(drsuapi.DsGetNCChangesRequest10(),
                                     dest_dsa, invocation_id, nc_dn_str, exop,
                                     replica_flags, max_objects,
                                     partial_attribute_set,
                                     partial_attribute_set_ex, mapping_ctr)
        req10.more_flags = more_flags
        return req10

    def _ds_bind(self, server_name, creds=None):
        """Open a sealed drsuapi connection to server_name and DsBind it.

        Returns (drs, drs_handle).  creds defaults to the test credentials.
        """
        binding_str = "ncacn_ip_tcp:%s[seal]" % server_name

        if creds is None:
            creds = self.get_credentials()
        drs = drsuapi.drsuapi(binding_str, self.get_loadparm(), creds)
        (drs_handle, _supported_extensions) = drs_DsBind(drs)
        return (drs, drs_handle)

    def get_partial_attribute_set(self, attids=None):
        """Build a DsPartialAttributeSet from attids (default: [objectClass])."""
        if attids is None:
            attids = [drsuapi.DRSUAPI_ATTID_objectClass]
        partial_attribute_set = drsuapi.DsPartialAttributeSet()
        partial_attribute_set.attids = attids
        partial_attribute_set.num_attids = len(attids)
        return partial_attribute_set
461
462
463
class AbstractLink(object):
    """A linked-attribute value from a DsGetNCChanges response.

    Instances are ordered as in CompareLinks() (MS-DRSR section
    4.1.10.5.17).  The comparison keys are, in turn:

    - the source GUID blob;
    - attid, in descending order;
    - the ACTIVE flag;
    - the target GUID blob;
    - the flags.
    """

    def __init__(self, attid, flags, identifier, targetGUID,
                 targetDN=""):
        self.attid = attid
        self.flags = flags
        self.identifier = str(identifier)
        self.selfGUID_blob = ndr_pack(identifier)
        self.targetGUID = str(targetGUID)
        self.targetGUID_blob = ndr_pack(targetGUID)
        self.targetDN = targetDN

    def __repr__(self):
        return "AbstractLink(0x%08x, 0x%08x, %s, %s)" % (
                self.attid, self.flags, self.identifier, self.targetGUID)

    @staticmethod
    def _cmp(a, b):
        """Three-way compare returning -1/0/1, like Python 2's cmp()."""
        return (a > b) - (a < b)

    def __internal_cmp__(self, other, verbose=False):
        """See CompareLinks() in MS-DRSR section 4.1.10.5.17

        Returns a negative, zero or positive int, or NotImplemented when
        other is not an AbstractLink.  With verbose, the reason for any
        difference is printed.
        """
        if not isinstance(other, AbstractLink):
            if verbose:
                print("AbstractLink.__internal_cmp__(%r, %r) => wrong type" % (self, other))
            return NotImplemented

        c = self._cmp(self.selfGUID_blob, other.selfGUID_blob)
        if c != 0:
            if verbose:
                print("AbstractLink.__internal_cmp__(%r, %r) => %d different identifier" % (self, other, c))
            return c

        # Note: reversed operands, i.e. higher attid sorts first.
        c = other.attid - self.attid
        if c != 0:
            if verbose:
                print("AbstractLink.__internal_cmp__(%r, %r) => %d different attid" % (self, other, c))
            return c

        self_active = self.flags & drsuapi.DRSUAPI_DS_LINKED_ATTRIBUTE_FLAG_ACTIVE
        other_active = other.flags & drsuapi.DRSUAPI_DS_LINKED_ATTRIBUTE_FLAG_ACTIVE

        c = self_active - other_active
        if c != 0:
            if verbose:
                print("AbstractLink.__internal_cmp__(%r, %r) => %d different FLAG_ACTIVE" % (self, other, c))
            return c

        c = self._cmp(self.targetGUID_blob, other.targetGUID_blob)
        if c != 0:
            if verbose:
                print("AbstractLink.__internal_cmp__(%r, %r) => %d different target" % (self, other, c))
            return c

        c = self.flags - other.flags
        if c != 0:
            if verbose:
                print("AbstractLink.__internal_cmp__(%r, %r) => %d different flags" % (self, other, c))
            return c

        return 0

    def __lt__(self, other):
        c = self.__internal_cmp__(other)
        if c is NotImplemented:
            return NotImplemented
        return c < 0

    def __le__(self, other):
        c = self.__internal_cmp__(other)
        if c is NotImplemented:
            return NotImplemented
        return c <= 0

    def __eq__(self, other):
        # verbose: report why two links differ (useful in test failures)
        c = self.__internal_cmp__(other, verbose=True)
        if c is NotImplemented:
            return NotImplemented
        return c == 0

    def __ne__(self, other):
        c = self.__internal_cmp__(other)
        if c is NotImplemented:
            return NotImplemented
        return c != 0

    def __gt__(self, other):
        c = self.__internal_cmp__(other)
        if c is NotImplemented:
            return NotImplemented
        return c > 0

    def __ge__(self, other):
        c = self.__internal_cmp__(other)
        if c is NotImplemented:
            return NotImplemented
        return c >= 0

    def __hash__(self):
        return hash((self.attid, self.flags, self.identifier, self.targetGUID))