Correct "overriden" typos.
[amitay/samba.git] / python / samba / tests / dcerpc / dnsserver.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Amitay Isaacs <amitay@gmail.com> 2011
3 #
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 3 of the License, or
7 # (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 #
17
18 """Tests for samba.dcerpc.dnsserver"""
19
20 import os
21 import ldb
22
23 from samba.auth import system_session
24 from samba.samdb import SamDB
25 from samba.ndr import ndr_unpack, ndr_pack
26 from samba.dcerpc import dnsp, dnsserver, security
27 from samba.tests import RpcInterfaceTestCase, env_get_var_value
28 from samba.netcmd.dns import ARecord, AAAARecord, PTRRecord, CNameRecord, NSRecord, MXRecord, SRVRecord, TXTRecord
29 from samba import sd_utils, descriptor
30
31 class DnsserverTests(RpcInterfaceTestCase):
32
33     @classmethod
34     def setUpClass(cls):
35         good_dns = ["SAMDOM.EXAMPLE.COM",
36                     "1.EXAMPLE.COM",
37                     "%sEXAMPLE.COM" % ("1."*100),
38                     "EXAMPLE",
39                     "\n.COM",
40                     "!@#$%^&*()_",
41                     "HIGH\xFFBYTE",
42                     "@.EXAMPLE.COM",
43                     "."]
44         bad_dns = ["...",
45                    ".EXAMPLE.COM",
46                    ".EXAMPLE.",
47                    "",
48                    "SAMDOM..EXAMPLE.COM"]
49
50         good_mx = ["SAMDOM.EXAMPLE.COM 65535"]
51         bad_mx = []
52
53         good_srv = ["SAMDOM.EXAMPLE.COM 65535 65535 65535"]
54         bad_srv = []
55
56         for bad_dn in bad_dns:
57             bad_mx.append("%s 1" % bad_dn)
58             bad_srv.append("%s 0 0 0" % bad_dn)
59         for good_dn in good_dns:
60             good_mx.append("%s 1" % good_dn)
61             good_srv.append("%s 0 0 0" % good_dn)
62
63         cls.good_records = {
64             "A": ["192.168.0.1",
65                   "255.255.255.255"],
66             "AAAA": ["1234:5678:9ABC:DEF0:0000:0000:0000:0000",
67                      "0000:0000:0000:0000:0000:0000:0000:0000",
68                      "1234:5678:9ABC:DEF0:1234:5678:9ABC:DEF0",
69                      "1234:1234:1234::",
70                      "1234:1234:1234:1234:1234::",
71                      "1234:5678:9ABC:DEF0::",
72                      "0000:0000::0000",
73                      "1234::5678:9ABC:0000:0000:0000:0000",
74                      "::1",
75                      "::",
76                      "1:1:1:1:1:1:1:1"],
77             "PTR": good_dns,
78             "CNAME": good_dns,
79             "NS": good_dns,
80             "MX": good_mx,
81             "SRV": good_srv,
82             "TXT": ["text", "", "@#!", "\n"]
83         }
84
85         cls.bad_records = {
86             "A": ["192.168.0.500",
87                   "255.255.255.255/32"],
88             "AAAA": ["GGGG:1234:5678:9ABC:0000:0000:0000:0000",
89                      "0000:0000:0000:0000:0000:0000:0000:0000/1",
90                      "AAAA:AAAA:AAAA:AAAA:G000:0000:0000:1234",
91                      "1234:5678:9ABC:DEF0:1234:5678:9ABC:DEF0:1234",
92                      "1234:5678:9ABC:DEF0:1234:5678:9ABC",
93                      "1111::1111::1111"],
94             "PTR": bad_dns,
95             "CNAME": bad_dns,
96             "NS": bad_dns,
97             "MX": bad_mx,
98             "SRV": bad_srv
99         }
100
101         # Because we use uint16_t for these numbers, we can't
102         # actually create these records.
103         invalid_mx = ["SAMDOM.EXAMPLE.COM -1",
104                       "SAMDOM.EXAMPLE.COM 65536",
105                       "%s 1" % "A"*256]
106         invalid_srv = ["SAMDOM.EXAMPLE.COM 0 65536 0",
107                        "SAMDOM.EXAMPLE.COM 0 0 65536",
108                        "SAMDOM.EXAMPLE.COM 65536 0 0"]
109         cls.invalid_records = {
110             "MX": invalid_mx,
111             "SRV": invalid_srv
112         }
113
114     def setUp(self):
115         super(DnsserverTests, self).setUp()
116         self.server = os.environ["DC_SERVER"]
117         self.zone = env_get_var_value("REALM").lower()
118         self.conn = dnsserver.dnsserver("ncacn_ip_tcp:%s[sign]" % (self.server),
119                                         self.get_loadparm(),
120                                         self.get_credentials())
121
122         self.samdb = SamDB(url="ldap://%s" % os.environ["DC_SERVER_IP"],
123                            lp = self.get_loadparm(),
124                            session_info=system_session(),
125                            credentials=self.get_credentials())
126
127
128         self.custom_zone = "zone"
129         zone_create_info = dnsserver.DNS_RPC_ZONE_CREATE_INFO_LONGHORN()
130         zone_create_info.pszZoneName = self.custom_zone
131         zone_create_info.dwZoneType = dnsp.DNS_ZONE_TYPE_PRIMARY
132         zone_create_info.fAging = 0
133         zone_create_info.fDsIntegrated = 1
134         zone_create_info.fLoadExisting = 1
135         zone_create_info.dwDpFlags = dnsserver.DNS_DP_DOMAIN_DEFAULT
136
137         self.conn.DnssrvOperation2(dnsserver.DNS_CLIENT_VERSION_LONGHORN,
138                                    0,
139                                    self.server,
140                                    None,
141                                    0,
142                                    'ZoneCreate',
143                                    dnsserver.DNSSRV_TYPEID_ZONE_CREATE,
144                                    zone_create_info)
145
146     def tearDown(self):
147         self.conn.DnssrvOperation2(dnsserver.DNS_CLIENT_VERSION_LONGHORN,
148                                    0,
149                                    self.server,
150                                    self.custom_zone,
151                                    0,
152                                    'DeleteZoneFromDs',
153                                    dnsserver.DNSSRV_TYPEID_NULL,
154                                    None)
155         super(DnsserverTests, self).tearDown()
156
157     # This test fails against Samba (but passes against Windows),
158     # because Samba does not return the record when we enum records.
159     # Records can be given DNS_RANK_NONE when the zone they are in
160     # does not have DNS_ZONE_TYPE_PRIMARY. Since such records can be
161     # deleted, however, we do not consider this urgent to fix and
162     # so this test is a knownfail.
163     def test_rank_none(self):
164         """
165         See what happens when we set a record's rank to
166         DNS_RANK_NONE.
167         """
168
169         record_str = "192.168.50.50"
170         record_type_str = "A"
171         self.add_record(self.custom_zone, "testrecord", record_type_str, record_str)
172
173         dn, record = self.get_record_from_db(self.custom_zone, "testrecord")
174         record.rank = 0 # DNS_RANK_NONE
175         res = self.samdb.dns_replace_by_dn(dn, [record])
176         if res is not None:
177             self.fail("Unable to update dns record to have DNS_RANK_NONE.")
178
179         self.assert_num_records(self.custom_zone, "testrecord", record_type_str)
180         self.add_record(self.custom_zone, "testrecord", record_type_str, record_str, assertion=False)
181         self.delete_record(self.custom_zone, "testrecord", record_type_str, record_str)
182         self.assert_num_records(self.custom_zone, "testrecord", record_type_str, 0)
183
184     def test_dns_tombstoned(self):
185         """
186         See what happens when we set a record to be tombstoned.
187         """
188
189         record_str = "192.168.50.50"
190         record_type_str = "A"
191         self.add_record(self.custom_zone, "testrecord", record_type_str, record_str)
192
193         dn, record = self.get_record_from_db(self.custom_zone, "testrecord")
194         record.wType = dnsp.DNS_TYPE_TOMBSTONE
195         res = self.samdb.dns_replace_by_dn(dn, [record])
196         if res is not None:
197             self.fail("Unable to update dns record to be tombstoned.")
198
199         self.assert_num_records(self.custom_zone, "testrecord", record_type_str)
200         self.delete_record(self.custom_zone, "testrecord", record_type_str, record_str)
201         self.assert_num_records(self.custom_zone, "testrecord", record_type_str, 0)
202
203     def get_record_from_db(self, zone_name, record_name):
204         """
205         Returns (dn of record, record)
206         """
207
208         zones = self.samdb.search(base="DC=DomainDnsZones,%s" % self.samdb.get_default_basedn(), scope=ldb.SCOPE_SUBTREE,
209                                   expression="(objectClass=dnsZone)",
210                                   attrs=["cn"])
211
212         zone_dn = None
213         for zone in zones:
214             if zone_name in str(zone.dn):
215                 zone_dn = zone.dn
216                 break
217
218         if zone_dn is None:
219             raise AssertionError("Couldn't find zone '%s'." % zone_name)
220
221         records = self.samdb.search(base=zone_dn, scope=ldb.SCOPE_SUBTREE,
222                                     expression="(objectClass=dnsNode)",
223                                     attrs=["dnsRecord"])
224
225         for old_packed_record in records:
226             if record_name in str(old_packed_record.dn):
227                 return (old_packed_record.dn, ndr_unpack(dnsp.DnssrvRpcRecord, old_packed_record["dnsRecord"][0]))
228
229     def test_duplicate_matching(self):
230         """
231         Make sure that records which should be distinct from each other or duplicate
232         to each other behave as expected.
233         """
234
235         distinct_dns = [("SAMDOM.EXAMPLE.COM",
236                          "SAMDOM.EXAMPLE.CO",
237                          "EXAMPLE.COM", "SAMDOM.EXAMPLE")]
238         duplicate_dns = [("SAMDOM.EXAMPLE.COM", "samdom.example.com", "SAMDOM.example.COM"),
239                          ("EXAMPLE.", "EXAMPLE")]
240
241         # Every tuple has entries which should be considered duplicate to one another.
242         duplicates = {
243             "AAAA": [("AAAA::", "aaaa::"),
244                      ("AAAA::", "AAAA:0000::"),
245                      ("AAAA::", "AAAA:0000:0000:0000:0000:0000:0000:0000"),
246                      ("AAAA::", "AAAA:0:0:0:0:0:0:0"),
247                      ("0123::", "123::"),
248                      ("::", "::0", "0000:0000:0000:0000:0000:0000:0000:0000")],
249         }
250
251         # Every tuple has entries which should be considered distinct from one another.
252         distinct = {
253             "A": [("192.168.1.0", "192.168.1.1", "192.168.2.0", "192.169.1.0", "193.168.1.0")],
254             "AAAA": [("AAAA::1234:5678:9ABC", "::AAAA:1234:5678:9ABC"),
255                      ("1000::", "::1000"),
256                      ("::1", "::11", "::1111"),
257                      ("1234::", "0234::")],
258             "SRV": [("SAMDOM.EXAMPLE.COM 1 1 1", "SAMDOM.EXAMPLE.COM 1 1 0", "SAMDOM.EXAMPLE.COM 1 0 1",
259                      "SAMDOM.EXAMPLE.COM 0 1 1", "SAMDOM.EXAMPLE.COM 2 1 0", "SAMDOM.EXAMPLE.COM 2 2 2")],
260             "MX": [("SAMDOM.EXAMPLE.COM 1", "SAMDOM.EXAMPLE.COM 0")],
261             "TXT": [("A RECORD", "B RECORD", "a record")]
262         }
263
264         for record_type_str in ("PTR", "CNAME", "NS"):
265             distinct[record_type_str] = distinct_dns
266             duplicates[record_type_str] = duplicate_dns
267
268         for record_type_str in duplicates:
269             for duplicate_tuple in duplicates[record_type_str]:
270                 # Attempt to add duplicates and make sure that all after the first fails
271                 self.add_record(self.custom_zone, "testrecord", record_type_str, duplicate_tuple[0])
272                 for record in duplicate_tuple:
273                     self.add_record(self.custom_zone, "testrecord", record_type_str, record, assertion=False)
274                     self.assert_num_records(self.custom_zone, "testrecord", record_type_str)
275                 self.delete_record(self.custom_zone, "testrecord", record_type_str, duplicate_tuple[0])
276
277                 # Repeatedly: add the first duplicate, and attempt to remove all of the others, making sure this succeeds
278                 for record in duplicate_tuple:
279                     self.add_record(self.custom_zone, "testrecord", record_type_str, duplicate_tuple[0])
280                     self.delete_record(self.custom_zone, "testrecord", record_type_str, record)
281
282         for record_type_str in distinct:
283             for distinct_tuple in distinct[record_type_str]:
284                 # Attempt to add distinct and make sure that they all succeed within a tuple
285                 i = 0
286                 for record in distinct_tuple:
287                     i = i + 1
288                     try:
289                         self.add_record(self.custom_zone, "testrecord", record_type_str, record)
290                         # All records should have been added.
291                         self.assert_num_records(self.custom_zone, "testrecord", record_type_str, expected_num=i)
292                     except AssertionError as e:
293                         raise AssertionError("Failed to add %s, which should be distinct from all others in the set. "
294                                              "Original error: %s\nDistinct set: %s." % (record, e, distinct_tuple))
295                 for record in distinct_tuple:
296                     self.delete_record(self.custom_zone, "testrecord", record_type_str, record)
297                     # CNAMEs should not have been added, since they conflict.
298                     if record_type_str == 'CNAME':
299                         continue
300
301                 # Add the first distinct and attempt to remove all of the others, making sure this fails
302                 # Windows fails this test. This is probably due to weird tombstoning behavior.
303                 self.add_record(self.custom_zone, "testrecord", record_type_str, distinct_tuple[0])
304                 for record in distinct_tuple:
305                     if record == distinct_tuple[0]:
306                         continue
307                     try:
308                         self.delete_record(self.custom_zone, "testrecord", record_type_str, record, assertion=False)
309                     except AssertionError as e:
310                         raise AssertionError("Managed to remove %s by attempting to remove %s. Original error: %s"
311                                              % (distinct_tuple[0], record, e))
312                 self.delete_record(self.custom_zone, "testrecord", record_type_str, distinct_tuple[0])
313
314     def test_accept_valid_commands(self):
315         """
316         Make sure that we can add, update and delete a variety
317         of valid records.
318         """
319         for record_type_str in self.good_records:
320             for record_str in self.good_records[record_type_str]:
321                 self.add_record(self.custom_zone, "testrecord", record_type_str, record_str)
322                 self.assert_num_records(self.custom_zone, "testrecord", record_type_str)
323                 self.delete_record(self.custom_zone, "testrecord", record_type_str, record_str)
324
325     def test_reject_invalid_commands(self):
326         """
327         Make sure that we can't add a variety of invalid records,
328         and that we can't update valid records to invalid ones.
329         """
330         num_failures = 0
331         for record_type_str in self.bad_records:
332             for record_str in self.bad_records[record_type_str]:
333                 # Attempt to add the bad record, which should fail. Then, attempt to query for and delete
334                 # it. Since it shouldn't exist, these should fail too.
335                 try:
336                     self.add_record(self.custom_zone, "testrecord", record_type_str, record_str, assertion=False)
337                     self.assert_num_records(self.custom_zone, "testrecord", record_type_str, expected_num=0)
338                     self.delete_record(self.custom_zone, "testrecord", record_type_str, record_str, assertion=False)
339                 except AssertionError as e:
340                     print e
341                     num_failures = num_failures + 1
342
343         # Also try to update valid records to invalid ones, making sure this fails
344         for record_type_str in self.bad_records:
345             for record_str in self.bad_records[record_type_str]:
346                 good_record_str = self.good_records[record_type_str][0]
347                 self.add_record(self.custom_zone, "testrecord", record_type_str, good_record_str)
348                 try:
349                     self.add_record(self.custom_zone, "testrecord", record_type_str, record_str, assertion=False)
350                 except AssertionError as e:
351                     print e
352                     num_failures = num_failures + 1
353                 self.delete_record(self.custom_zone, "testrecord", record_type_str, good_record_str)
354
355         self.assertTrue(num_failures == 0, "Failed to reject invalid commands. Total failures: %d." % num_failures)
356
357     def test_add_duplicate_different_type(self):
358         """
359         Attempt to add some values which have the same name as
360         existing ones, just a different type.
361         """
362         num_failures = 0
363         for record_type_str_1 in self.good_records:
364             record1 = self.good_records[record_type_str_1][0]
365             self.add_record(self.custom_zone, "testrecord", record_type_str_1, record1)
366             for record_type_str_2 in self.good_records:
367                 if record_type_str_1 == record_type_str_2:
368                     continue
369
370                 record2 = self.good_records[record_type_str_2][0]
371
372                 has_a = record_type_str_1 == 'A' or record_type_str_2 == 'A'
373                 has_aaaa = record_type_str_1 == 'AAAA' or record_type_str_2 == 'AAAA'
374                 has_cname = record_type_str_1 == 'CNAME' or record_type_str_2 == 'CNAME'
375                 has_ptr = record_type_str_1 == 'PTR' or record_type_str_2 == 'PTR'
376                 has_mx = record_type_str_1 == 'MX' or record_type_str_2 == 'MX'
377                 has_srv = record_type_str_1 == 'SRV' or record_type_str_2 == 'SRV'
378                 has_txt = record_type_str_1 == 'TXT' or record_type_str_2 == 'TXT'
379
380                 # If we attempt to add any record except A or AAAA when we already have an NS record,
381                 # the add should fail.
382                 add_error_ok = False
383                 if record_type_str_1 == 'NS' and not has_a and not has_aaaa:
384                     add_error_ok = True
385                 # If we attempt to add a CNAME when an A, PTR or MX record exists, the add should fail.
386                 if record_type_str_2 == 'CNAME' and (has_ptr or has_mx or has_a or has_aaaa):
387                     add_error_ok = True
388                 # If we have a CNAME, adding an A, AAAA, SRV or TXT record should fail.
389                 # If we have an A, AAAA, SRV or TXT record, adding a CNAME should fail.
390                 if has_cname and (has_a or has_aaaa or has_srv or has_txt):
391                     add_error_ok = True
392
393                 try:
394                     self.add_record(self.custom_zone, "testrecord", record_type_str_2, record2)
395                     if add_error_ok:
396                         num_failures = num_failures + 1
397                         print("Expected error when adding %s while a %s existed."
398                               % (record_type_str_2, record_type_str_1))
399                 except AssertionError as e:
400                     if not add_error_ok:
401                         num_failures = num_failures + 1
402                         print("Didn't expect error when adding %s while a %s existed."
403                               % (record_type_str_2, record_type_str_1))
404
405                 if not add_error_ok:
406                     # In the "normal" case, we expect the add to work and us to have one of each type of record afterwards.
407                     expected_num_type_1 = 1
408                     expected_num_type_2 = 1
409
410                     # If we have an MX record, a PTR record should replace it when added.
411                     # If we have a PTR record, an MX record should replace it when added.
412                     if has_ptr and has_mx:
413                         expected_num_type_1 = 0
414
415                     # If we have a CNAME, SRV or TXT record, a PTR or MX record should replace it when added.
416                     if (has_cname or has_srv or has_txt) and (record_type_str_2 == 'PTR' or record_type_str_2 == 'MX'):
417                         expected_num_type_1 = 0
418
419                     if (record_type_str_1 == 'NS' and (has_a or has_aaaa)):
420                         expected_num_type_2 = 0
421
422                     try:
423                         self.assert_num_records(self.custom_zone, "testrecord", record_type_str_1, expected_num=expected_num_type_1)
424                     except AssertionError as e:
425                         num_failures = num_failures + 1
426                         print("Expected %s %s records after adding a %s record and a %s record already existed."
427                               % (expected_num_type_1, record_type_str_1, record_type_str_2, record_type_str_1))
428                     try:
429                         self.assert_num_records(self.custom_zone, "testrecord", record_type_str_2, expected_num=expected_num_type_2)
430                     except AssertionError as e:
431                         num_failures = num_failures + 1
432                         print("Expected %s %s records after adding a %s record and a %s record already existed."
433                               % (expected_num_type_2, record_type_str_2, record_type_str_2, record_type_str_1))
434
435                 try:
436                     self.delete_record(self.custom_zone, "testrecord", record_type_str_2, record2)
437                 except AssertionError as e:
438                     pass
439
440             self.delete_record(self.custom_zone, "testrecord", record_type_str_1, record1)
441
442         self.assertTrue(num_failures == 0, "Failed collision and replacement behavior. Total failures: %d." % num_failures)
443
444     # Windows fails this test in the same way we do.
445     def _test_cname(self):
446         """
447         Test some special properties of CNAME records.
448         """
449
450         # RFC 1912: When there is a CNAME record, there must not be any other records with the same alias
451         cname_record = self.good_records["CNAME"][1]
452         self.add_record(self.custom_zone, "testrecord", "CNAME", cname_record)
453
454         for record_type_str in self.good_records:
455             other_record = self.good_records[record_type_str][0]
456             self.add_record(self.custom_zone, "testrecord", record_type_str, other_record, assertion=False)
457             self.assert_num_records(self.custom_zone, "testrecord", record_type_str, expected_num=0)
458
459         # RFC 2181: MX & NS records must not be allowed to point to a CNAME alias
460         mx_record = "testrecord 1"
461         ns_record = "testrecord"
462
463         self.add_record(self.custom_zone, "mxrec", "MX", mx_record, assertion=False)
464         self.add_record(self.custom_zone, "nsrec", "NS", ns_record, assertion=False)
465
466         self.delete_record(self.custom_zone, "testrecord", "CNAME", cname_record)
467
468     def test_add_duplicate_value(self):
469         """
470         Make sure that we can't add duplicate values of any type.
471         """
472         for record_type_str in self.good_records:
473             record = self.good_records[record_type_str][0]
474
475             self.add_record(self.custom_zone, "testrecord", record_type_str, record)
476             self.add_record(self.custom_zone, "testrecord", record_type_str, record, assertion=False)
477             self.assert_num_records(self.custom_zone, "testrecord", record_type_str)
478             self.delete_record(self.custom_zone, "testrecord", record_type_str, record)
479
480     def test_add_similar_value(self):
481         """
482         Attempt to add values with the same name and type in the same
483         zone. This should work, and should result in both values
484         existing (except with some types).
485         """
486         for record_type_str in self.good_records:
487             for i in range(1, len(self.good_records[record_type_str])):
488                 record1 = self.good_records[record_type_str][i-1]
489                 record2 = self.good_records[record_type_str][i]
490
491                 if record_type_str == 'CNAME':
492                     continue
493                 # We expect CNAME records to override one another, as
494                 # an alias can only map to one CNAME record.
495                 # Also, on Windows, when the empty string is added and
496                 # another record is added afterwards, the empty string
497                 # will be silently overridden by the new one, so it
498                 # fails this test for the empty string.
499                 expected_num = 1 if record_type_str == 'CNAME' else 2
500
501                 self.add_record(self.custom_zone, "testrecord", record_type_str, record1)
502                 self.add_record(self.custom_zone, "testrecord", record_type_str, record2)
503                 self.assert_num_records(self.custom_zone, "testrecord", record_type_str, expected_num=expected_num)
504                 self.delete_record(self.custom_zone, "testrecord", record_type_str, record1)
505                 self.delete_record(self.custom_zone, "testrecord", record_type_str, record2)
506
507     def assert_record(self, zone, name, record_type_str, expected_record_str,
508                       assertion=True, client_version=dnsserver.DNS_CLIENT_VERSION_LONGHORN):
509         """
510         Asserts whether or not the given record with the given type exists in the
511         given zone.
512         """
513         try:
514             _, result = self.query_records(zone, name, record_type_str)
515         except RuntimeError as e:
516             if assertion:
517                 raise AssertionError("Record '%s' of type '%s' was not present when it should have been."
518                                      % (expected_record_str, record_type_str))
519             else:
520                 return
521
522         found = False
523         for record in result.rec[0].records:
524             if record.data == expected_record_str:
525                 found = True
526                 break
527
528         if found and not assertion:
529             raise AssertionError("Record '%s' of type '%s' was present when it shouldn't have been." % (expected_record_str, record_type_str))
530         elif not found and assertion:
531             raise AssertionError("Record '%s' of type '%s' was not present when it should have been." % (expected_record_str, record_type_str))
532
533     def assert_num_records(self, zone, name, record_type_str, expected_num=1,
534                            client_version=dnsserver.DNS_CLIENT_VERSION_LONGHORN):
535         """
536         Asserts that there are a given amount of records with the given type in
537         the given zone.
538         """
539         try:
540             _, result = self.query_records(zone, name, record_type_str)
541             num_results = len(result.rec[0].records)
542             if not num_results == expected_num:
543                 raise AssertionError("There were %d records of type '%s' with the name '%s' when %d were expected."
544                                      % (num_results, record_type_str, name, expected_num))
545         except RuntimeError:
546             if not expected_num == 0:
547                 raise AssertionError("There were no records of type '%s' with the name '%s' when %d were expected."
548                                      % (record_type_str, name, expected_num))
549
550     def query_records(self, zone, name, record_type_str, client_version=dnsserver.DNS_CLIENT_VERSION_LONGHORN):
551         return self.conn.DnssrvEnumRecords2(client_version,
552                                             0,
553                                             self.server,
554                                             zone,
555                                             name,
556                                             None,
557                                             self.record_type_int(record_type_str),
558                                             dnsserver.DNS_RPC_VIEW_AUTHORITY_DATA | dnsserver.DNS_RPC_VIEW_NO_CHILDREN,
559                                             None,
560                                             None)
561
562     def record_obj_from_str(self, record_type_str, record_str):
563         if record_type_str == 'A':
564             return ARecord(record_str)
565         elif record_type_str == 'AAAA':
566             return AAAARecord(record_str)
567         elif record_type_str == 'PTR':
568             return PTRRecord(record_str)
569         elif record_type_str == 'CNAME':
570             return CNameRecord(record_str)
571         elif record_type_str == 'NS':
572             return NSRecord(record_str)
573         elif record_type_str == 'MX':
574             split = record_str.split(' ')
575             return MXRecord(split[0], int(split[1]))
576         elif record_type_str == 'SRV':
577             split = record_str.split(' ')
578             target = split[0]
579             port = int(split[1])
580             priority = int(split[2])
581             weight = int(split[3])
582             return SRVRecord(target, port, priority, weight)
583         elif record_type_str == 'TXT':
584             return TXTRecord(record_str)
585
586     def record_type_int(self, record_type_str):
587         if record_type_str == 'A':
588             return dnsp.DNS_TYPE_A
589         elif record_type_str == 'AAAA':
590             return dnsp.DNS_TYPE_AAAA
591         elif record_type_str == 'PTR':
592             return dnsp.DNS_TYPE_PTR
593         elif record_type_str == 'CNAME':
594             return dnsp.DNS_TYPE_CNAME
595         elif record_type_str == 'NS':
596             return dnsp.DNS_TYPE_NS
597         elif record_type_str == 'MX':
598             return dnsp.DNS_TYPE_MX
599         elif record_type_str == 'SRV':
600             return dnsp.DNS_TYPE_SRV
601         elif record_type_str == 'TXT':
602             return dnsp.DNS_TYPE_TXT
603
604     def add_record(self, zone, name, record_type_str, record_str,
605                    assertion=True, client_version=dnsserver.DNS_CLIENT_VERSION_LONGHORN):
606         """
607         Attempts to add a map from the given name to a record of the given type,
608         in the given zone.
609         Also asserts whether or not the add was successful.
610         This can also update existing records if they have the same name.
611         """
612         record = self.record_obj_from_str(record_type_str, record_str)
613         add_rec_buf = dnsserver.DNS_RPC_RECORD_BUF()
614         add_rec_buf.rec = record
615
616         try:
617             self.conn.DnssrvUpdateRecord2(client_version,
618                                           0,
619                                           self.server,
620                                           zone,
621                                           name,
622                                           add_rec_buf,
623                                           None)
624             if not assertion:
625                 raise AssertionError("Successfully added record '%s' of type '%s', which should have failed."
626                                      % (record_str, record_type_str))
627         except RuntimeError as e:
628             if assertion:
629                 raise AssertionError("Failed to add record '%s' of type '%s', which should have succeeded. Error was '%s'."
630                                      % (record_str, record_type_str, str(e)))
631
632     def delete_record(self, zone, name, record_type_str, record_str,
633                       assertion=True, client_version=dnsserver.DNS_CLIENT_VERSION_LONGHORN):
634         """
635         Attempts to delete a record with the given name, record and record type
636         from the given zone.
637         Also asserts whether or not the deletion was successful.
638         """
639         record = self.record_obj_from_str(record_type_str, record_str)
640         del_rec_buf = dnsserver.DNS_RPC_RECORD_BUF()
641         del_rec_buf.rec = record
642
643         try:
644             self.conn.DnssrvUpdateRecord2(client_version,
645                                                    0,
646                                                    self.server,
647                                                    zone,
648                                                    name,
649                                                    None,
650                                                    del_rec_buf)
651             if not assertion:
652                 raise AssertionError("Successfully deleted record '%s' of type '%s', which should have failed." % (record_str, record_type_str))
653         except RuntimeError as e:
654             if assertion:
655                 raise AssertionError("Failed to delete record '%s' of type '%s', which should have succeeded. Error was '%s'." % (record_str, record_type_str, str(e)))
656
657     def test_query2(self):
658         typeid, result = self.conn.DnssrvQuery2(dnsserver.DNS_CLIENT_VERSION_W2K,
659                                                 0,
660                                                 self.server,
661                                                 None,
662                                                 'ServerInfo')
663         self.assertEquals(dnsserver.DNSSRV_TYPEID_SERVER_INFO_W2K, typeid)
664
665         typeid, result = self.conn.DnssrvQuery2(dnsserver.DNS_CLIENT_VERSION_DOTNET,
666                                                 0,
667                                                 self.server,
668                                                 None,
669                                                 'ServerInfo')
670         self.assertEquals(dnsserver.DNSSRV_TYPEID_SERVER_INFO_DOTNET, typeid)
671
672         typeid, result = self.conn.DnssrvQuery2(dnsserver.DNS_CLIENT_VERSION_LONGHORN,
673                                                 0,
674                                                 self.server,
675                                                 None,
676                                                 'ServerInfo')
677         self.assertEquals(dnsserver.DNSSRV_TYPEID_SERVER_INFO, typeid)
678
679     def test_operation2(self):
680         client_version = dnsserver.DNS_CLIENT_VERSION_LONGHORN
681         rev_zone = '1.168.192.in-addr.arpa'
682
683         zone_create = dnsserver.DNS_RPC_ZONE_CREATE_INFO_LONGHORN()
684         zone_create.pszZoneName = rev_zone
685         zone_create.dwZoneType = dnsp.DNS_ZONE_TYPE_PRIMARY
686         zone_create.fAllowUpdate = dnsp.DNS_ZONE_UPDATE_SECURE
687         zone_create.fAging = 0
688         zone_create.dwDpFlags = dnsserver.DNS_DP_DOMAIN_DEFAULT
689
690         # Create zone
691         self.conn.DnssrvOperation2(client_version,
692                                     0,
693                                     self.server,
694                                     None,
695                                     0,
696                                     'ZoneCreate',
697                                     dnsserver.DNSSRV_TYPEID_ZONE_CREATE,
698                                     zone_create)
699
700         request_filter = (dnsserver.DNS_ZONE_REQUEST_REVERSE |
701                             dnsserver.DNS_ZONE_REQUEST_PRIMARY)
702         _, zones = self.conn.DnssrvComplexOperation2(client_version,
703                                                      0,
704                                                      self.server,
705                                                      None,
706                                                      'EnumZones',
707                                                      dnsserver.DNSSRV_TYPEID_DWORD,
708                                                      request_filter)
709         self.assertEquals(1, zones.dwZoneCount)
710
711         # Delete zone
712         self.conn.DnssrvOperation2(client_version,
713                                     0,
714                                     self.server,
715                                     rev_zone,
716                                     0,
717                                     'DeleteZoneFromDs',
718                                     dnsserver.DNSSRV_TYPEID_NULL,
719                                     None)
720
721         typeid, zones = self.conn.DnssrvComplexOperation2(client_version,
722                                                             0,
723                                                             self.server,
724                                                             None,
725                                                             'EnumZones',
726                                                             dnsserver.DNSSRV_TYPEID_DWORD,
727                                                             request_filter)
728         self.assertEquals(0, zones.dwZoneCount)
729
730
731     def test_complexoperation2(self):
732         client_version = dnsserver.DNS_CLIENT_VERSION_LONGHORN
733         request_filter = (dnsserver.DNS_ZONE_REQUEST_FORWARD |
734                             dnsserver.DNS_ZONE_REQUEST_PRIMARY)
735
736         typeid, zones = self.conn.DnssrvComplexOperation2(client_version,
737                                                             0,
738                                                             self.server,
739                                                             None,
740                                                             'EnumZones',
741                                                             dnsserver.DNSSRV_TYPEID_DWORD,
742                                                             request_filter)
743         self.assertEquals(dnsserver.DNSSRV_TYPEID_ZONE_LIST, typeid)
744         self.assertEquals(3, zones.dwZoneCount)
745
746         request_filter = (dnsserver.DNS_ZONE_REQUEST_REVERSE |
747                             dnsserver.DNS_ZONE_REQUEST_PRIMARY)
748         typeid, zones = self.conn.DnssrvComplexOperation2(client_version,
749                                                             0,
750                                                             self.server,
751                                                             None,
752                                                             'EnumZones',
753                                                             dnsserver.DNSSRV_TYPEID_DWORD,
754                                                             request_filter)
755         self.assertEquals(dnsserver.DNSSRV_TYPEID_ZONE_LIST, typeid)
756         self.assertEquals(0, zones.dwZoneCount)
757
758     def test_enumrecords2(self):
759         client_version = dnsserver.DNS_CLIENT_VERSION_LONGHORN
760         record_type = dnsp.DNS_TYPE_NS
761         select_flags = (dnsserver.DNS_RPC_VIEW_ROOT_HINT_DATA |
762                         dnsserver.DNS_RPC_VIEW_ADDITIONAL_DATA)
763         _, roothints = self.conn.DnssrvEnumRecords2(client_version,
764                                                     0,
765                                                     self.server,
766                                                     '..RootHints',
767                                                     '.',
768                                                     None,
769                                                     record_type,
770                                                     select_flags,
771                                                     None,
772                                                     None)
773         self.assertEquals(14, roothints.count)  # 1 NS + 13 A records (a-m)
774
775     def test_updaterecords2(self):
776         client_version = dnsserver.DNS_CLIENT_VERSION_LONGHORN
777         record_type = dnsp.DNS_TYPE_A
778         select_flags = dnsserver.DNS_RPC_VIEW_AUTHORITY_DATA
779
780         name = 'dummy'
781         rec = ARecord('1.2.3.4')
782         rec2 = ARecord('5.6.7.8')
783
784         # Add record
785         add_rec_buf = dnsserver.DNS_RPC_RECORD_BUF()
786         add_rec_buf.rec = rec
787         self.conn.DnssrvUpdateRecord2(client_version,
788                                         0,
789                                         self.server,
790                                         self.zone,
791                                         name,
792                                         add_rec_buf,
793                                         None)
794
795         _, result = self.conn.DnssrvEnumRecords2(client_version,
796                                                  0,
797                                                  self.server,
798                                                  self.zone,
799                                                  name,
800                                                  None,
801                                                  record_type,
802                                                  select_flags,
803                                                  None,
804                                                  None)
805         self.assertEquals(1, result.count)
806         self.assertEquals(1, result.rec[0].wRecordCount)
807         self.assertEquals(dnsp.DNS_TYPE_A, result.rec[0].records[0].wType)
808         self.assertEquals('1.2.3.4', result.rec[0].records[0].data)
809
810         # Update record
811         add_rec_buf = dnsserver.DNS_RPC_RECORD_BUF()
812         add_rec_buf.rec = rec2
813         del_rec_buf = dnsserver.DNS_RPC_RECORD_BUF()
814         del_rec_buf.rec = rec
815         self.conn.DnssrvUpdateRecord2(client_version,
816                                         0,
817                                         self.server,
818                                         self.zone,
819                                         name,
820                                         add_rec_buf,
821                                         del_rec_buf)
822
823         buflen, result = self.conn.DnssrvEnumRecords2(client_version,
824                                                         0,
825                                                         self.server,
826                                                         self.zone,
827                                                         name,
828                                                         None,
829                                                         record_type,
830                                                         select_flags,
831                                                         None,
832                                                         None)
833         self.assertEquals(1, result.count)
834         self.assertEquals(1, result.rec[0].wRecordCount)
835         self.assertEquals(dnsp.DNS_TYPE_A, result.rec[0].records[0].wType)
836         self.assertEquals('5.6.7.8', result.rec[0].records[0].data)
837
838         # Delete record
839         del_rec_buf = dnsserver.DNS_RPC_RECORD_BUF()
840         del_rec_buf.rec = rec2
841         self.conn.DnssrvUpdateRecord2(client_version,
842                                         0,
843                                         self.server,
844                                         self.zone,
845                                         name,
846                                         None,
847                                         del_rec_buf)
848
849         self.assertRaises(RuntimeError, self.conn.DnssrvEnumRecords2,
850                                         client_version,
851                                         0,
852                                         self.server,
853                                         self.zone,
854                                         name,
855                                         None,
856                                         record_type,
857                                         select_flags,
858                                         None,
859                                         None)
860
    # The following tests do not pass against Samba because the owner and
    # group, as well as some ACEs, are not consistent with Windows.
    #
    # The following ACEs are also required for 2012R2:
    #
    # (OA;CIIO;WP;ea1b7b93-5e48-46d5-bc6c-4df4fda78a35;bf967a86-0de6-11d0-a285-00aa003049e2;PS)
    # (OA;OICI;RPWP;3f78c3e5-f79a-46bd-a0b8-9d18116ddc79;;PS)
    #
    # [TPM + Allowed-To-Act-On-Behalf-Of-Other-Identity]
870     def test_security_descriptor_msdcs_zone(self):
871         """
872         Make sure that security descriptors of the msdcs zone is
873         as expected.
874         """
875
876         zones = self.samdb.search(base="DC=ForestDnsZones,%s" % self.samdb.get_default_basedn(),
877                                   scope=ldb.SCOPE_SUBTREE,
878                                   expression="(&(objectClass=dnsZone)(name=_msdcs*))",
879                                   attrs=["nTSecurityDescriptor", "objectClass"])
880         self.assertEqual(len(zones), 1)
881         self.assertTrue("nTSecurityDescriptor" in zones[0])
882         tmp = zones[0]["nTSecurityDescriptor"][0]
883         utils = sd_utils.SDUtils(self.samdb)
884         sd = ndr_unpack(security.descriptor, tmp)
885
886         domain_sid = security.dom_sid(self.samdb.get_domain_sid())
887
888         res = self.samdb.search(base=self.samdb.get_default_basedn(), scope=ldb.SCOPE_SUBTREE,
889                                 expression="(sAMAccountName=DnsAdmins)",
890                                 attrs=["objectSid"])
891
892         dns_admin = str(ndr_unpack(security.dom_sid, res[0]['objectSid'][0]))
893
894         packed_sd = descriptor.sddl2binary("O:SYG:BA" \
895                                            "D:AI(A;;RPWPCRCCDCLCLORCWOWDSDDTSW;;;DA)" \
896                                            "(A;;CC;;;AU)" \
897                                            "(A;;RPLCLORC;;;WD)" \
898                                            "(A;;RPWPCRCCDCLCLORCWOWDSDDTSW;;;SY)" \
899                                            "(A;CI;RPWPCRCCDCLCRCWOWDSDDTSW;;;ED)",
900                                            domain_sid, {"DnsAdmins": dns_admin})
901         expected_sd = descriptor.get_clean_sd(ndr_unpack(security.descriptor, packed_sd))
902
903         diff = descriptor.get_diff_sds(expected_sd, sd, domain_sid)
904         self.assertEqual(diff, '', "SD of msdcs zone different to expected.\n"
905                          "Difference was:\n%s\nExpected: %s\nGot: %s" %
906                          (diff, expected_sd.as_sddl(utils.domain_sid),
907                           sd.as_sddl(utils.domain_sid)))
908
909     def test_security_descriptor_forest_zone(self):
910         """
911         Make sure that security descriptors of forest dns zones are
912         as expected.
913         """
914         forest_zone = "test_forest_zone"
915         zone_create_info = dnsserver.DNS_RPC_ZONE_CREATE_INFO_LONGHORN()
916         zone_create_info.dwZoneType = dnsp.DNS_ZONE_TYPE_PRIMARY
917         zone_create_info.fAging = 0
918         zone_create_info.fDsIntegrated = 1
919         zone_create_info.fLoadExisting = 1
920
921         zone_create_info.pszZoneName = forest_zone
922         zone_create_info.dwDpFlags = dnsserver.DNS_DP_FOREST_DEFAULT
923
924         self.conn.DnssrvOperation2(dnsserver.DNS_CLIENT_VERSION_LONGHORN,
925                                    0,
926                                    self.server,
927                                    None,
928                                    0,
929                                    'ZoneCreate',
930                                    dnsserver.DNSSRV_TYPEID_ZONE_CREATE,
931                                    zone_create_info)
932
933         partition_dn = self.samdb.get_default_basedn()
934         partition_dn.add_child("DC=ForestDnsZones")
935         zones = self.samdb.search(base=partition_dn, scope=ldb.SCOPE_SUBTREE,
936                                   expression="(name=%s)" % forest_zone,
937                                   attrs=["nTSecurityDescriptor"])
938         self.assertEqual(len(zones), 1)
939         current_dn = zones[0].dn
940         self.assertTrue("nTSecurityDescriptor" in zones[0])
941         tmp = zones[0]["nTSecurityDescriptor"][0]
942         utils = sd_utils.SDUtils(self.samdb)
943         sd = ndr_unpack(security.descriptor, tmp)
944
945         domain_sid = security.dom_sid(self.samdb.get_domain_sid())
946
947         res = self.samdb.search(base=self.samdb.get_default_basedn(),
948                                 scope=ldb.SCOPE_SUBTREE,
949                                 expression="(sAMAccountName=DnsAdmins)",
950                                 attrs=["objectSid"])
951
952         dns_admin = str(ndr_unpack(security.dom_sid, res[0]['objectSid'][0]))
953
954         packed_sd = descriptor.sddl2binary("O:DAG:DA" \
955                                            "D:AI(A;;RPWPCRCCDCLCLORCWOWDSDDTSW;;;DA)" \
956                                            "(A;;CC;;;AU)" \
957                                            "(A;;RPLCLORC;;;WD)" \
958                                            "(A;;RPWPCRCCDCLCLORCWOWDSDDTSW;;;SY)" \
959                                            "(A;CI;RPWPCRCCDCLCRCWOWDSDDTSW;;;ED)",
960                                            domain_sid, {"DnsAdmins": dns_admin})
961         expected_sd = descriptor.get_clean_sd(ndr_unpack(security.descriptor, packed_sd))
962
963         packed_msdns = descriptor.get_dns_forest_microsoft_dns_descriptor(domain_sid,
964                                                                           {"DnsAdmins": dns_admin})
965         expected_msdns_sd = descriptor.get_clean_sd(ndr_unpack(security.descriptor, packed_msdns))
966
967         packed_part_sd = descriptor.get_dns_partition_descriptor(domain_sid)
968         expected_part_sd = descriptor.get_clean_sd(ndr_unpack(security.descriptor,
969                                                               packed_part_sd))
970         try:
971             msdns_dn = ldb.Dn(self.samdb, "CN=MicrosoftDNS,%s" % str(partition_dn))
972             security_desc_dict = [(current_dn.get_linearized(),  expected_sd),
973                                   (msdns_dn.get_linearized(), expected_msdns_sd),
974                                   (partition_dn.get_linearized(), expected_part_sd)]
975
976             for (key, sec_desc) in security_desc_dict:
977                 zones = self.samdb.search(base=key, scope=ldb.SCOPE_BASE,
978                                           attrs=["nTSecurityDescriptor"])
979                 self.assertTrue("nTSecurityDescriptor" in zones[0])
980                 tmp = zones[0]["nTSecurityDescriptor"][0]
981                 utils = sd_utils.SDUtils(self.samdb)
982
983                 sd = ndr_unpack(security.descriptor, tmp)
984                 diff = descriptor.get_diff_sds(sec_desc, sd, domain_sid)
985
986                 self.assertEqual(diff, '', "Security descriptor of forest DNS zone with DN '%s' different to expected. Difference was:\n%s\nExpected: %s\nGot: %s"
987                                  % (key, diff, sec_desc.as_sddl(utils.domain_sid), sd.as_sddl(utils.domain_sid)))
988
989         finally:
990             self.conn.DnssrvOperation2(dnsserver.DNS_CLIENT_VERSION_LONGHORN,
991                                        0,
992                                        self.server,
993                                        forest_zone,
994                                        0,
995                                        'DeleteZoneFromDs',
996                                        dnsserver.DNSSRV_TYPEID_NULL,
997                                        None)
998
999     def test_security_descriptor_domain_zone(self):
1000         """
1001         Make sure that security descriptors of domain dns zones are
1002         as expected.
1003         """
1004
1005         partition_dn = self.samdb.get_default_basedn()
1006         partition_dn.add_child("DC=DomainDnsZones")
1007         zones = self.samdb.search(base=partition_dn, scope=ldb.SCOPE_SUBTREE,
1008                                   expression="(name=%s)" % self.custom_zone,
1009                                   attrs=["nTSecurityDescriptor"])
1010         self.assertEqual(len(zones), 1)
1011         current_dn = zones[0].dn
1012         self.assertTrue("nTSecurityDescriptor" in zones[0])
1013         tmp = zones[0]["nTSecurityDescriptor"][0]
1014         utils = sd_utils.SDUtils(self.samdb)
1015         sd = ndr_unpack(security.descriptor, tmp)
1016         sddl = sd.as_sddl(utils.domain_sid)
1017
1018         domain_sid = security.dom_sid(self.samdb.get_domain_sid())
1019
1020         res = self.samdb.search(base=self.samdb.get_default_basedn(), scope=ldb.SCOPE_SUBTREE,
1021                                 expression="(sAMAccountName=DnsAdmins)",
1022                                 attrs=["objectSid"])
1023
1024         dns_admin = str(ndr_unpack(security.dom_sid, res[0]['objectSid'][0]))
1025
1026         packed_sd = descriptor.sddl2binary("O:DAG:DA" \
1027                                            "D:AI(A;;RPWPCRCCDCLCLORCWOWDSDDTSW;;;DA)" \
1028                                            "(A;;CC;;;AU)" \
1029                                            "(A;;RPLCLORC;;;WD)" \
1030                                            "(A;;RPWPCRCCDCLCLORCWOWDSDDTSW;;;SY)" \
1031                                            "(A;CI;RPWPCRCCDCLCRCWOWDSDDTSW;;;ED)",
1032                                            domain_sid, {"DnsAdmins": dns_admin})
1033         expected_sd = descriptor.get_clean_sd(ndr_unpack(security.descriptor, packed_sd))
1034
1035         packed_msdns = descriptor.get_dns_domain_microsoft_dns_descriptor(domain_sid,
1036                                                                           {"DnsAdmins": dns_admin})
1037         expected_msdns_sd = descriptor.get_clean_sd(ndr_unpack(security.descriptor, packed_msdns))
1038
1039         packed_part_sd = descriptor.get_dns_partition_descriptor(domain_sid)
1040         expected_part_sd = descriptor.get_clean_sd(ndr_unpack(security.descriptor,
1041                                                               packed_part_sd))
1042
1043         msdns_dn = ldb.Dn(self.samdb, "CN=MicrosoftDNS,%s" % str(partition_dn))
1044         security_desc_dict = [(current_dn.get_linearized(),  expected_sd),
1045                               (msdns_dn.get_linearized(), expected_msdns_sd),
1046                               (partition_dn.get_linearized(), expected_part_sd)]
1047
1048         for (key, sec_desc) in security_desc_dict:
1049             zones = self.samdb.search(base=key, scope=ldb.SCOPE_BASE,
1050                                       attrs=["nTSecurityDescriptor"])
1051             self.assertTrue("nTSecurityDescriptor" in zones[0])
1052             tmp = zones[0]["nTSecurityDescriptor"][0]
1053             utils = sd_utils.SDUtils(self.samdb)
1054
1055             sd = ndr_unpack(security.descriptor, tmp)
1056             diff = descriptor.get_diff_sds(sec_desc, sd, domain_sid)
1057
1058             self.assertEqual(diff, '', "Security descriptor of domain DNS zone with DN '%s' different to expected. Difference was:\n%s\nExpected: %s\nGot: %s"
1059                              % (key, diff, sec_desc.as_sddl(utils.domain_sid), sd.as_sddl(utils.domain_sid)))