CVE-2016-0771: tests/dns: modify tests to check via RPC
[samba.git] / python / samba / tests / dns.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Kai Blin  <kai@samba.org> 2011
3 #
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 3 of the License, or
7 # (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 #
17
18 import os
19 import struct
20 import random
21 import socket
22 import samba.ndr as ndr
23 from samba import credentials, param
24 from samba.tests import TestCase
25 from samba.dcerpc import dns, dnsp, dnsserver
26
# Translation table used by DNSTest.hexdump(): a character whose repr() is
# exactly three characters long (the character between quotes, e.g. "'a'")
# maps to itself; every other byte value maps to '.'.
FILTER = ''.join([chr(x) if len(repr(chr(x))) == 3 else '.'
                  for x in range(256)])

# This timeout only has relevance when testing against Windows.
# Format errors tend to return patchy responses, so a timeout is needed.
timeout = None
32
def make_txt_record(records):
    """Wrap a list of strings into a dns.txt_record.

    The strings are stored in a dnsp.string_list whose count is set to
    len(records); the result is suitable for assigning to res_rec.rdata.
    """
    strings = dnsp.string_list()
    strings.count = len(records)
    strings.str = records

    txt_rdata = dns.txt_record()
    txt_rdata.txt = strings
    return txt_rdata
40
class DNSTest(TestCase):
    """Common helpers for the DNS server tests.

    Subclasses use these helpers to build dns.name_packet requests, send them
    to port 53 of the server at SERVER_IP over UDP or TCP, and check the
    RCODE/OPCODE of the reply.
    """

    def get_loadparm(self):
        "Return a LoadParm loaded from the file named by SMB_CONF_PATH"
        lp = param.LoadParm()
        lp.load(os.getenv("SMB_CONF_PATH"))
        return lp

    def errstr(self, errcode):
        """Return a readable error code

        The RCODE header field is 4 bits wide, but only 0-10 have names in
        the table below.  Other values are rendered numerically instead of
        raising IndexError, which would otherwise hide the real result of
        assert_dns_rcode_equals() (its message is built eagerly).
        """
        string_codes = [
            "OK",
            "FORMERR",
            "SERVFAIL",
            "NXDOMAIN",
            "NOTIMP",
            "REFUSED",
            "YXDOMAIN",
            "YXRRSET",
            "NXRRSET",
            "NOTAUTH",
            "NOTZONE",
        ]

        if 0 <= errcode < len(string_codes):
            return string_codes[errcode]
        return "RCODE %d" % errcode

    def assert_dns_rcode_equals(self, packet, rcode):
        "Helper function to check return code"
        # The RCODE lives in the low 4 bits of the operation field.
        p_errcode = packet.operation & 0x000F
        self.assertEquals(p_errcode, rcode, "Expected RCODE %s, got %s" %
                          (self.errstr(rcode), self.errstr(p_errcode)))

    def assert_dns_opcode_equals(self, packet, opcode):
        "Helper function to check opcode"
        p_opcode = packet.operation & 0x7800
        self.assertEquals(p_opcode, opcode, "Expected OPCODE %s, got %s" %
                          (opcode, p_opcode))

    def make_name_packet(self, opcode, qid=None):
        """Helper creating a dns.name_packet

        :param opcode: value stored in the packet's operation field
        :param qid: query ID to use; a random 16-bit ID is picked when None
        """
        p = dns.name_packet()
        if qid is None:
            qid = random.randint(0x0, 0xffff)
        # Previously a caller-supplied qid was silently ignored.
        p.id = qid
        p.operation = opcode
        p.questions = []
        return p

    def finish_name_packet(self, packet, questions):
        "Helper to finalize a dns.name_packet"
        packet.qdcount = len(questions)
        packet.questions = questions

    def make_name_question(self, name, qtype, qclass):
        "Helper creating a dns.name_question"
        q = dns.name_question()
        q.name = name
        q.question_type = qtype
        q.question_class = qclass
        return q

    def get_dns_domain(self):
        "Helper to get dns domain"
        return os.getenv('REALM', 'example.com').lower()

    def dns_transaction_udp(self, packet, host=os.getenv('SERVER_IP'),
                            dump=False, timeout=timeout):
        """send a DNS query and read the reply

        :param packet: dns.name_packet to send
        :param host: server address (defaults to SERVER_IP at import time)
        :param dump: print hex dumps of the request and reply
        :param timeout: socket timeout in seconds, None to block
        :return: the reply unpacked as a dns.name_packet
        """
        s = None
        try:
            send_packet = ndr.ndr_pack(packet)
            if dump:
                print(self.hexdump(send_packet))
            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
            s.settimeout(timeout)
            s.connect((host, 53))
            s.send(send_packet, 0)
            recv_packet = s.recv(2048, 0)
            if dump:
                print(self.hexdump(recv_packet))
            return ndr.ndr_unpack(dns.name_packet, recv_packet)
        finally:
            if s is not None:
                s.close()

    def _recv_exactly(self, sock, count):
        """Read count bytes from a stream socket.

        Loops because recv() on a stream socket may return fewer bytes than
        requested.  Stops early (returning what was read) only if the peer
        closes the connection.
        """
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = sock.recv(remaining, 0)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def dns_transaction_tcp(self, packet, host=os.getenv('SERVER_IP'),
                            dump=False, timeout=timeout):
        """send a DNS query and read the reply

        DNS over TCP prefixes every message with its length as a 2-byte
        big-endian integer (RFC 1035 section 4.2.2).  The reply's prefix is
        read first and then the body is read until complete, since a single
        recv() is not guaranteed to deliver the whole message.

        :return: the reply unpacked as a dns.name_packet
        """
        s = None
        try:
            send_packet = ndr.ndr_pack(packet)
            if dump:
                print(self.hexdump(send_packet))
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
            s.settimeout(timeout)
            s.connect((host, 53))
            tcp_packet = struct.pack('!H', len(send_packet))
            tcp_packet += send_packet
            # send() may write only part of the buffer; sendall() loops.
            s.sendall(tcp_packet)
            recv_packet = self._recv_exactly(s, 2)
            if len(recv_packet) == 2:
                (length,) = struct.unpack('!H', recv_packet)
                recv_packet += self._recv_exactly(s, length)
            if dump:
                print(self.hexdump(recv_packet))
            # Strip the 2-byte length prefix before unpacking.
            return ndr.ndr_unpack(dns.name_packet, recv_packet[2:])
        finally:
            if s is not None:
                s.close()

    def hexdump(self, src, length=8):
        """Return a hex dump of src, length bytes per line.

        Each line shows the offset, the bytes in hex and the bytes passed
        through the module-level FILTER table.
        """
        offset = 0
        result = ''
        while src:
            chunk, src = src[:length], src[length:]
            hexa = ' '.join(["%02X" % ord(x) for x in chunk])
            text = chunk.translate(FILTER)
            result += "%04X   %-*s   %s\n" % (offset, length * 3, hexa, text)
            offset += length
        return result

    def make_txt_update(self, prefix, txt_array):
        """Build an UPDATE packet adding a TXT record.

        The record is "<prefix>.<dns domain>" with class IN and TTL 900,
        holding txt_array; the zone section names the domain's SOA.
        """
        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
        updates = []

        name = self.get_dns_domain()
        u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
        updates.append(u)
        self.finish_name_packet(p, updates)

        updates = []
        r = dns.res_rec()
        r.name = "%s.%s" % (prefix, self.get_dns_domain())
        r.rr_type = dns.DNS_QTYPE_TXT
        r.rr_class = dns.DNS_QCLASS_IN
        r.ttl = 900
        r.length = 0xffff
        rdata = make_txt_record(txt_array)
        r.rdata = rdata
        updates.append(r)
        p.nscount = len(updates)
        p.nsrecs = updates

        return p

    def check_query_txt(self, prefix, txt_array):
        """Query the TXT record "<prefix>.<dns domain>" over UDP and assert
        the reply is OK with exactly one answer holding txt_array."""
        name = "%s.%s" % (prefix, self.get_dns_domain())
        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
        questions = []

        q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
        questions.append(q)

        self.finish_name_packet(p, questions)
        response = self.dns_transaction_udp(p)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
        self.assertEquals(response.ancount, 1)
        self.assertEquals(response.answers[0].rdata.txt.str, txt_array)

    def assertIsNotNone(self, item, msg=None):
        "Fail unless item is not None (accepts unittest's optional msg)"
        self.assertTrue(item is not None, msg)
197
198 class TestSimpleQueries(DNSTest):
199
200     def test_one_a_query(self):
201         "create a query packet containing one query record"
202         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
203         questions = []
204
205         name = "%s.%s" % (os.getenv('SERVER'), self.get_dns_domain())
206         q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
207         print "asking for ", q.name
208         questions.append(q)
209
210         self.finish_name_packet(p, questions)
211         response = self.dns_transaction_udp(p)
212         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
213         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
214         self.assertEquals(response.ancount, 1)
215         self.assertEquals(response.answers[0].rdata,
216                           os.getenv('SERVER_IP'))
217
218     def test_one_a_query_tcp(self):
219         "create a query packet containing one query record via TCP"
220         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
221         questions = []
222
223         name = "%s.%s" % (os.getenv('SERVER'), self.get_dns_domain())
224         q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
225         print "asking for ", q.name
226         questions.append(q)
227
228         self.finish_name_packet(p, questions)
229         response = self.dns_transaction_tcp(p)
230         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
231         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
232         self.assertEquals(response.ancount, 1)
233         self.assertEquals(response.answers[0].rdata,
234                           os.getenv('SERVER_IP'))
235
236     def test_one_mx_query(self):
237         "create a query packet causing an empty RCODE_OK answer"
238         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
239         questions = []
240
241         name = "%s.%s" % (os.getenv('SERVER'), self.get_dns_domain())
242         q = self.make_name_question(name, dns.DNS_QTYPE_MX, dns.DNS_QCLASS_IN)
243         print "asking for ", q.name
244         questions.append(q)
245
246         self.finish_name_packet(p, questions)
247         response = self.dns_transaction_udp(p)
248         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
249         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
250         self.assertEquals(response.ancount, 0)
251
252         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
253         questions = []
254
255         name = "invalid-%s.%s" % (os.getenv('SERVER'), self.get_dns_domain())
256         q = self.make_name_question(name, dns.DNS_QTYPE_MX, dns.DNS_QCLASS_IN)
257         print "asking for ", q.name
258         questions.append(q)
259
260         self.finish_name_packet(p, questions)
261         response = self.dns_transaction_udp(p)
262         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN)
263         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
264         self.assertEquals(response.ancount, 0)
265
266     def test_two_queries(self):
267         "create a query packet containing two query records"
268         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
269         questions = []
270
271         name = "%s.%s" % (os.getenv('SERVER'), self.get_dns_domain())
272         q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
273         questions.append(q)
274
275         name = "%s.%s" % ('bogusname', self.get_dns_domain())
276         q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
277         questions.append(q)
278
279         self.finish_name_packet(p, questions)
280         try:
281             response = self.dns_transaction_udp(p)
282             self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR)
283         except socket.timeout:
284             # Windows chooses not to respond to incorrectly formatted queries.
285             # Although this appears to be non-deterministic even for the same
286             # request twice, it also appears to be based on a how poorly the
287             # request is formatted.
288             pass
289
290     def test_qtype_all_query(self):
291         "create a QTYPE_ALL query"
292         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
293         questions = []
294
295         name = "%s.%s" % (os.getenv('SERVER'), self.get_dns_domain())
296         q = self.make_name_question(name, dns.DNS_QTYPE_ALL, dns.DNS_QCLASS_IN)
297         print "asking for ", q.name
298         questions.append(q)
299
300         self.finish_name_packet(p, questions)
301         response = self.dns_transaction_udp(p)
302
303         num_answers = 1
304         dc_ipv6 = os.getenv('SERVER_IPV6')
305         if dc_ipv6 is not None:
306             num_answers += 1
307
308         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
309         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
310         self.assertEquals(response.ancount, num_answers)
311         self.assertEquals(response.answers[0].rdata,
312                           os.getenv('SERVER_IP'))
313         if dc_ipv6 is not None:
314             self.assertEquals(response.answers[1].rdata, dc_ipv6)
315
316     def test_qclass_none_query(self):
317         "create a QCLASS_NONE query"
318         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
319         questions = []
320
321         name = "%s.%s" % (os.getenv('SERVER'), self.get_dns_domain())
322         q = self.make_name_question(name, dns.DNS_QTYPE_ALL, dns.DNS_QCLASS_NONE)
323         questions.append(q)
324
325         self.finish_name_packet(p, questions)
326         try:
327             response = self.dns_transaction_udp(p)
328             self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NOTIMP)
329         except socket.timeout:
330             # Windows chooses not to respond to incorrectly formatted queries.
331             # Although this appears to be non-deterministic even for the same
332             # request twice, it also appears to be based on a how poorly the
333             # request is formatted.
334             pass
335
336 # Only returns an authority section entry in BIND and Win DNS
337 # FIXME: Enable one Samba implements this feature
338 #    def test_soa_hostname_query(self):
339 #        "create a SOA query for a hostname"
340 #        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
341 #        questions = []
342 #
343 #        name = "%s.%s" % (os.getenv('SERVER'), self.get_dns_domain())
344 #        q = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
345 #        questions.append(q)
346 #
347 #        self.finish_name_packet(p, questions)
348 #        response = self.dns_transaction_udp(p)
349 #        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
350 #        self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
351 #        # We don't get SOA records for single hosts
352 #        self.assertEquals(response.ancount, 0)
353
354     def test_soa_domain_query(self):
355         "create a SOA query for a domain"
356         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
357         questions = []
358
359         name = self.get_dns_domain()
360         q = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
361         questions.append(q)
362
363         self.finish_name_packet(p, questions)
364         response = self.dns_transaction_udp(p)
365         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
366         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
367         self.assertEquals(response.ancount, 1)
368         self.assertEquals(response.answers[0].rdata.minimum, 3600)
369
370
class TestDNSUpdates(DNSTest):
    """DNS UPDATE tests: malformed requests, prerequisites, and adding or
    deleting TXT and MX records in the test domain."""

    def _make_update_packet(self):
        """Return an UPDATE packet whose zone section is the SOA of the DNS
        domain; callers attach prerequisite or update records to it."""
        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
        u = self.make_name_question(self.get_dns_domain(), dns.DNS_QTYPE_SOA,
                                    dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [u])
        return p

    def _make_txt_prereq_packet(self, name, rr_class, ttl, length):
        "Return an UPDATE packet carrying a single TXT prerequisite record"
        p = self._make_update_packet()
        r = dns.res_rec()
        r.name = name
        r.rr_type = dns.DNS_QTYPE_TXT
        r.rr_class = rr_class
        r.ttl = ttl
        r.length = length
        p.ancount = 1
        p.answers = [r]
        return p

    def _make_txt_delete(self, prefix, txt_array):
        """Counterpart of make_txt_update(): build an UPDATE packet that
        deletes the TXT record "<prefix>.<dns domain>" holding txt_array.

        Class NONE with TTL 0 is how RFC 2136 section 2.5.4 encodes
        "delete an RR from an RRset".
        """
        p = self._make_update_packet()
        r = dns.res_rec()
        r.name = "%s.%s" % (prefix, self.get_dns_domain())
        r.rr_type = dns.DNS_QTYPE_TXT
        r.rr_class = dns.DNS_QCLASS_NONE
        r.ttl = 0
        r.length = 0xffff
        r.rdata = make_txt_record(txt_array)
        p.nscount = 1
        p.nsrecs = [r]
        return p

    def _add_txt(self, prefix, txt_array):
        "Add a TXT record via UDP and assert the server accepted it"
        response = self.dns_transaction_udp(
            self.make_txt_update(prefix, txt_array))
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)

    def _delete_txt(self, prefix, txt_array):
        "Delete a TXT record via UDP and assert the server accepted it"
        response = self.dns_transaction_udp(
            self._make_txt_delete(prefix, txt_array))
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)

    def _check_txt_absent(self, prefix):
        "Assert a TXT query for <prefix>.<dns domain> returns NXDOMAIN"
        name = "%s.%s" % (prefix, self.get_dns_domain())
        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
        q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [q])
        response = self.dns_transaction_udp(p)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN)

    def test_two_updates(self):
        "create two update requests"
        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
        updates = []

        name = "%s.%s" % (os.getenv('SERVER'), self.get_dns_domain())
        u = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
        updates.append(u)

        name = self.get_dns_domain()
        u = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
        updates.append(u)

        self.finish_name_packet(p, updates)
        try:
            response = self.dns_transaction_udp(p)
            self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR)
        except socket.timeout:
            # Windows chooses not to respond to incorrectly formatted queries.
            # Although this appears to be non-deterministic even for the same
            # request twice, it also appears to be based on how poorly the
            # request is formatted.
            pass

    def test_update_wrong_qclass(self):
        "create update with DNS_QCLASS_NONE"
        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
        u = self.make_name_question(self.get_dns_domain(), dns.DNS_QTYPE_A,
                                    dns.DNS_QCLASS_NONE)
        self.finish_name_packet(p, [u])
        response = self.dns_transaction_udp(p)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NOTIMP)

    def test_update_prereq_with_non_null_ttl(self):
        "test update with a non-null TTL"
        p = self._make_txt_prereq_packet(
            "%s.%s" % (os.getenv('SERVER'), self.get_dns_domain()),
            dns.DNS_QCLASS_NONE, ttl=1, length=0)

        try:
            response = self.dns_transaction_udp(p)
            self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR)
        except socket.timeout:
            # Windows chooses not to respond to incorrectly formatted queries.
            # Although this appears to be non-deterministic even for the same
            # request twice, it also appears to be based on how poorly the
            # request is formatted.
            pass

    def test_update_prereq_with_non_null_length(self):
        "test update with a non-null length"
        p = self._make_txt_prereq_packet(
            "%s.%s" % (os.getenv('SERVER'), self.get_dns_domain()),
            dns.DNS_QCLASS_ANY, ttl=0, length=1)

        response = self.dns_transaction_udp(p)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXRRSET)

    def test_update_prereq_nonexisting_name(self):
        "test update with a nonexisting name"
        p = self._make_txt_prereq_packet(
            "idontexist.%s" % self.get_dns_domain(),
            dns.DNS_QCLASS_ANY, ttl=0, length=0)

        response = self.dns_transaction_udp(p)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXRRSET)

    def test_update_add_txt_record(self):
        "test adding records works"
        prefix, txt = 'textrec', ['"This is a test"']
        p = self.make_txt_update(prefix, txt)
        response = self.dns_transaction_udp(p)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
        self.check_query_txt(prefix, txt)

    def test_delete_record(self):
        "Test if deleting records works"
        prefix, txt = 'deleterec', ['"This is a test"']

        # First, create a record to make sure we have a record to delete.
        self._add_txt(prefix, txt)

        # Now check the record is around.  check_query_txt() also pins the
        # answer count and data, so an empty NOERROR reply no longer passes.
        self.check_query_txt(prefix, txt)

        # Now delete the record
        self._delete_txt(prefix, txt)

        # And finally check it's gone
        self._check_txt_absent(prefix)

    def test_readd_record(self):
        "Test if adding, deleting and then readding a records works"
        prefix, txt = 'readdrec', ['"This is a test"']

        # Create the record and check it is around
        self._add_txt(prefix, txt)
        self.check_query_txt(prefix, txt)

        # Now delete the record and check it's gone
        self._delete_txt(prefix, txt)
        self._check_txt_absent(prefix)

        # Recreate the record and check it is around again
        self._add_txt(prefix, txt)
        self.check_query_txt(prefix, txt)

    def test_update_add_mx_record(self):
        "test adding MX records works"
        p = self._make_update_packet()

        r = dns.res_rec()
        r.name = "%s" % self.get_dns_domain()
        r.rr_type = dns.DNS_QTYPE_MX
        r.rr_class = dns.DNS_QCLASS_IN
        r.ttl = 900
        r.length = 0xffff
        rdata = dns.mx_record()
        rdata.preference = 10
        rdata.exchange = 'mail.%s' % self.get_dns_domain()
        r.rdata = rdata
        p.nscount = 1
        p.nsrecs = [r]

        response = self.dns_transaction_udp(p)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)

        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
        name = "%s" % self.get_dns_domain()
        q = self.make_name_question(name, dns.DNS_QTYPE_MX, dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [q])
        response = self.dns_transaction_udp(p)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
        self.assertEqual(response.ancount, 1)
        ans = response.answers[0]
        self.assertEqual(ans.rr_type, dns.DNS_QTYPE_MX)
        self.assertEqual(ans.rdata.preference, 10)
        self.assertEqual(ans.rdata.exchange, 'mail.%s' % self.get_dns_domain())
739
740
741 class TestComplexQueries(DNSTest):
742
743     def setUp(self):
744         super(TestComplexQueries, self).setUp()
745         p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
746         updates = []
747
748         name = self.get_dns_domain()
749
750         u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
751         updates.append(u)
752         self.finish_name_packet(p, updates)
753
754         updates = []
755         r = dns.res_rec()
756         r.name = "cname_test.%s" % self.get_dns_domain()
757         r.rr_type = dns.DNS_QTYPE_CNAME
758         r.rr_class = dns.DNS_QCLASS_IN
759         r.ttl = 900
760         r.length = 0xffff
761         r.rdata = "%s.%s" % (os.getenv('SERVER'), self.get_dns_domain())
762         updates.append(r)
763         p.nscount = len(updates)
764         p.nsrecs = updates
765
766         response = self.dns_transaction_udp(p)
767         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
768
769     def tearDown(self):
770         super(TestComplexQueries, self).tearDown()
771         p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
772         updates = []
773
774         name = self.get_dns_domain()
775
776         u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
777         updates.append(u)
778         self.finish_name_packet(p, updates)
779
780         updates = []
781         r = dns.res_rec()
782         r.name = "cname_test.%s" % self.get_dns_domain()
783         r.rr_type = dns.DNS_QTYPE_CNAME
784         r.rr_class = dns.DNS_QCLASS_NONE
785         r.ttl = 0
786         r.length = 0xffff
787         r.rdata = "%s.%s" % (os.getenv('SERVER'), self.get_dns_domain())
788         updates.append(r)
789         p.nscount = len(updates)
790         p.nsrecs = updates
791
792         response = self.dns_transaction_udp(p)
793         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
794
795     def test_one_a_query(self):
796         "create a query packet containing one query record"
797         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
798         questions = []
799
800         name = "cname_test.%s" % self.get_dns_domain()
801         q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
802         print "asking for ", q.name
803         questions.append(q)
804
805         self.finish_name_packet(p, questions)
806         response = self.dns_transaction_udp(p)
807         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
808         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
809         self.assertEquals(response.ancount, 2)
810         self.assertEquals(response.answers[0].rr_type, dns.DNS_QTYPE_CNAME)
811         self.assertEquals(response.answers[0].rdata, "%s.%s" %
812                           (os.getenv('SERVER'), self.get_dns_domain()))
813         self.assertEquals(response.answers[1].rr_type, dns.DNS_QTYPE_A)
814         self.assertEquals(response.answers[1].rdata,
815                           os.getenv('SERVER_IP'))
816
817 class TestInvalidQueries(DNSTest):
818
819     def test_one_a_query(self):
820         "send 0 bytes follows by create a query packet containing one query record"
821
822         s = None
823         try:
824             s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
825             s.connect((os.getenv('SERVER_IP'), 53))
826             s.send("", 0)
827         finally:
828             if s is not None:
829                 s.close()
830
831         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
832         questions = []
833
834         name = "%s.%s" % (os.getenv('SERVER'), self.get_dns_domain())
835         q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
836         print "asking for ", q.name
837         questions.append(q)
838
839         self.finish_name_packet(p, questions)
840         response = self.dns_transaction_udp(p)
841         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
842         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
843         self.assertEquals(response.ancount, 1)
844         self.assertEquals(response.answers[0].rdata,
845                           os.getenv('SERVER_IP'))
846
847     def test_one_a_reply(self):
848         "send a reply instead of a query"
849         global timeout
850
851         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
852         questions = []
853
854         name = "%s.%s" % ('fakefakefake', self.get_dns_domain())
855         q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
856         print "asking for ", q.name
857         questions.append(q)
858
859         self.finish_name_packet(p, questions)
860         p.operation |= dns.DNS_FLAG_REPLY
861         s = None
862         try:
863             send_packet = ndr.ndr_pack(p)
864             s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
865             s.settimeout(timeout)
866             host=os.getenv('SERVER_IP')
867             s.connect((host, 53))
868             tcp_packet = struct.pack('!H', len(send_packet))
869             tcp_packet += send_packet
870             s.send(tcp_packet, 0)
871             recv_packet = s.recv(0xffff + 2, 0)
872             self.assertEquals(0, len(recv_packet))
873         except socket.timeout:
874             # Windows chooses not to respond to incorrectly formatted queries.
875             # Although this appears to be non-deterministic even for the same
876             # request twice, it also appears to be based on a how poorly the
877             # request is formatted.
878             pass
879         finally:
880             if s is not None:
881                 s.close()
882
class TestRPCRoundtrip(DNSTest):
    """TXT records added via DNS UPDATE must also be visible over RPC.

    Each test does the following:
      1. Adds a TXT record with a UDP DNS update.
      2. Checks the record with a DNS query.
      3. Looks the record up through the dnsserver RPC connection, comparing
         it against the textual form the RPC side is expected to report.
    """

    def get_credentials(self, lp):
        """Return machine-account credentials built from loadparm `lp`."""
        creds = credentials.Credentials()
        creds.guess(lp)
        creds.set_machine_account(lp)
        creds.set_krb_forwardable(credentials.NO_KRB_FORWARDABLE)
        return creds

    def setUp(self):
        super(TestRPCRoundtrip, self).setUp()
        self.lp = self.get_loadparm()
        self.creds = self.get_credentials(self.lp)
        self.server = os.getenv("SERVER_IP")
        # Signed RPC over TCP to the dnsserver interface on SERVER_IP.
        self.rpc_conn = dnsserver.dnsserver("ncacn_ip_tcp:%s[sign]" % (self.server),
                                            self.lp, self.creds)

    def _assert_txt_roundtrip(self, prefix, txt, rpc_txt, query_txt=None):
        """Add TXT record `prefix` holding `txt`, then verify it via DNS and RPC.

        :param prefix: leftmost label; the record name is prefix.<dns domain>.
        :param txt: list of strings sent in the DNS update.
        :param rpc_txt: string dns_record_match() must find for the record.
        :param query_txt: strings expected from the DNS query. Defaults to
            `txt`. It differs when the server truncates the data, e.g. at
            embedded NUL bytes.
        """
        if query_txt is None:
            query_txt = txt
        p = self.make_txt_update(prefix, txt)
        response = self.dns_transaction_udp(p)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
        self.check_query_txt(prefix, query_txt)
        domain = self.get_dns_domain()
        self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server,
                                              domain,
                                              "%s.%s" % (prefix, domain),
                                              dnsp.DNS_TYPE_TXT, rpc_txt))

    def test_update_add_null_padded_txt_record(self):
        "test adding TXT records padded with empty strings works"
        self._assert_txt_roundtrip(
            'pad1textrec', ['"This is a test"', '', ''],
            '"\\"This is a test\\"" "" ""')
        self._assert_txt_roundtrip(
            'pad2textrec', ['"This is a test"', '', '', 'more text'],
            '"\\"This is a test\\"" "" "" "more text"')
        self._assert_txt_roundtrip(
            'pad3textrec', ['', '', '"This is a test"'],
            '"" "" "\\"This is a test\\""')

    # Test is incomplete due to strlen against txt records: both the DNS
    # query and the RPC view only contain the text before the first NUL.
    def test_update_add_null_char_txt_record(self):
        "test adding TXT records containing NUL bytes works"
        self._assert_txt_roundtrip('nulltextrec', ['NULL\x00BYTE'],
                                   '"NULL"', query_txt=['NULL'])
        self._assert_txt_roundtrip('nulltextrec2',
                                   ['NULL\x00BYTE', 'NULL\x00BYTE'],
                                   '"NULL" "NULL"',
                                   query_txt=['NULL', 'NULL'])

    def test_update_add_hex_char_txt_record(self):
        "test adding a TXT record containing a 0xFF byte works"
        self._assert_txt_roundtrip('hextextrec', ['HIGH\xFFBYTE'],
                                   '"HIGH\xFFBYTE"')

    def test_update_add_slash_txt_record(self):
        "test adding a TXT record containing a backslash works"
        self._assert_txt_roundtrip('slashtextrec', ['Th\\=is=is a test'],
                                   '"Th\\\\=is=is a test"')

    def test_update_add_two_txt_records(self):
        "test adding two txt records works"
        self._assert_txt_roundtrip(
            'textrec2', ['"This is a test"', '"and this is a test, too"'],
            '"\\"This is a test\\""' + ' "\\"and this is a test, too\\""')

    def test_update_add_empty_txt_records(self):
        "test adding a TXT record with no strings works"
        self._assert_txt_roundtrip('emptytextrec', [], '')
1006
if __name__ == "__main__":
    # When executed as a script, hand control to the unittest runner.
    from unittest import main as run_tests
    run_tests()