s4 dns: Fix TCP handling in the DNS server
[ira/wip.git] / source4 / scripting / python / samba / tests / dns.py
1 #!/usr/bin/env python
2
3 # Unix SMB/CIFS implementation.
4 # Copyright (C) Kai Blin  <kai@samba.org> 2011
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19
20 import os
21 import sys
22 import struct
23 import random
24 from samba import socket
25 import samba.ndr as ndr
26 import samba.dcerpc.dns as dns
27 from samba.tests import TestCase
28
class DNSTest(TestCase):
    """Functional tests that talk to a DNS server over UDP and TCP.

    Every test builds a ``dns.name_packet``, sends it to port 53 of the
    server and checks the reply.  The target is taken from the environment:
    ``DC_SERVER_IP`` (address that is queried), ``DC_SERVER`` (host name of
    the server), ``REALM`` (DNS domain, ``example.com`` when unset) and,
    optionally, ``DC_SERVER_IPV6``.
    """

    # Mnemonics for the response codes, indexed by their numeric value.
    _RCODE_NAMES = (
        "OK",
        "FORMERR",
        "SERVFAIL",
        "NXDOMAIN",
        "NOTIMP",
        "REFUSED",
        "YXDOMAIN",
        "YXRRSET",
        "NXRRSET",
        "NOTAUTH",
        "NOTZONE",
    )

    def errstr(self, errcode):
        """Return a readable name for the response code ``errcode``.

        Codes without a known mnemonic are rendered as ``UNKNOWN(<code>)``
        instead of raising, so a failing assertion can always format its
        message.
        """
        if 0 <= errcode < len(self._RCODE_NAMES):
            return self._RCODE_NAMES[errcode]
        return "UNKNOWN(%s)" % (errcode,)

    def assert_dns_rcode_equals(self, packet, rcode):
        """Assert that the response code of ``packet`` equals ``rcode``."""
        # The response code is the low four bits of the operation field.
        p_errcode = packet.operation & 0x000F
        self.assertEqual(p_errcode, rcode, "Expected RCODE %s, got %s" %
                         (self.errstr(rcode), self.errstr(p_errcode)))

    def assert_dns_opcode_equals(self, packet, opcode):
        """Assert that the opcode of ``packet`` equals ``opcode``.

        The opcode bits are masked out of the operation field but not
        shifted, so ``opcode`` has to be given in the same unshifted form.
        """
        p_opcode = packet.operation & 0x7800
        self.assertEqual(p_opcode, opcode, "Expected OPCODE %s, got %s" %
                         (opcode, p_opcode))

    def make_name_packet(self, opcode, qid=None):
        """Create a ``dns.name_packet`` for ``opcode`` without questions.

        ``qid`` becomes the packet id; when it is omitted a random 16-bit
        id is picked.
        """
        p = dns.name_packet()
        if qid is None:
            qid = random.randint(0x0, 0xffff)
        # Assign unconditionally so that an explicit qid is honoured too.
        p.id = qid
        p.operation = opcode
        p.questions = []
        return p

    def finish_name_packet(self, packet, questions):
        """Attach ``questions`` to ``packet`` and set the matching count."""
        packet.qdcount = len(questions)
        packet.questions = questions

    def make_name_question(self, name, qtype, qclass):
        """Create a ``dns.name_question`` for ``name``/``qtype``/``qclass``."""
        q = dns.name_question()
        q.name = name
        q.question_type = qtype
        q.question_class = qclass
        return q

    def get_dns_domain(self):
        """Return the lower-cased ``REALM`` (default ``example.com``)."""
        return os.getenv('REALM', 'example.com').lower()

    def _recv_exact(self, sock, count):
        """Read exactly ``count`` bytes from the stream socket ``sock``.

        A single recv() on a stream may return fewer bytes than requested,
        so keep reading until everything has arrived.  Raises IOError when
        recv() returns no data before ``count`` bytes were collected.
        """
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = sock.recv(remaining, 0)
            if not chunk:
                raise IOError("connection closed with %d of %d bytes "
                              "outstanding" % (remaining, count))
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def dns_transaction_udp(self, packet, host=None):
        """Send ``packet`` over UDP and return the unpacked reply.

        ``host`` is the server address; when omitted, ``DC_SERVER_IP`` is
        read from the environment at call time.  At most 2048 bytes of
        reply are read.
        """
        if host is None:
            host = os.getenv('DC_SERVER_IP')
        s = None
        try:
            send_packet = ndr.ndr_pack(packet)
            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
            s.connect((host, 53))
            s.send(send_packet, 0)
            recv_packet = s.recv(2048, 0)
            return ndr.ndr_unpack(dns.name_packet, recv_packet)
        finally:
            if s is not None:
                s.close()

    def dns_transaction_tcp(self, packet, host=None):
        """Send ``packet`` over TCP and return the unpacked reply.

        ``host`` is the server address; when omitted, ``DC_SERVER_IP`` is
        read from the environment at call time.  The query is sent with a
        two-byte big-endian length in front of it, and the reply is read
        the same way: first its length prefix, then exactly that many
        bytes.
        """
        if host is None:
            host = os.getenv('DC_SERVER_IP')
        s = None
        try:
            send_packet = ndr.ndr_pack(packet)
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
            s.connect((host, 53))
            tcp_packet = struct.pack('!H', len(send_packet))
            tcp_packet += send_packet
            s.send(tcp_packet, 0)
            # Honour the length prefix instead of trusting one recv() call
            # to deliver the complete reply.
            (reply_len,) = struct.unpack('!H', self._recv_exact(s, 2))
            recv_packet = self._recv_exact(s, reply_len)
            return ndr.ndr_unpack(dns.name_packet, recv_packet)
        finally:
            if s is not None:
                s.close()

    def test_one_a_query(self):
        """A single A question for the server's own name, sent over UDP.

        Expects RCODE OK, opcode QUERY and exactly one answer whose rdata
        is ``DC_SERVER_IP``.
        """
        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
        questions = []

        name = "%s.%s" % (os.getenv('DC_SERVER'), self.get_dns_domain())
        q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
        # One pre-formatted argument: prints the same text whether print is
        # a statement or a function.
        print("asking for  %s" % q.name)
        questions.append(q)

        self.finish_name_packet(p, questions)
        response = self.dns_transaction_udp(p)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
        self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
        self.assertEqual(response.ancount, 1)
        self.assertEqual(response.answers[0].rdata,
                         os.getenv('DC_SERVER_IP'))

    def test_one_a_query_tcp(self):
        """A single A question for the server's own name, sent over TCP.

        Expects RCODE OK, opcode QUERY and exactly one answer whose rdata
        is ``DC_SERVER_IP``.
        """
        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
        questions = []

        name = "%s.%s" % (os.getenv('DC_SERVER'), self.get_dns_domain())
        q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
        # One pre-formatted argument: prints the same text whether print is
        # a statement or a function.
        print("asking for  %s" % q.name)
        questions.append(q)

        self.finish_name_packet(p, questions)
        response = self.dns_transaction_tcp(p)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
        self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
        self.assertEqual(response.ancount, 1)
        self.assertEqual(response.answers[0].rdata,
                         os.getenv('DC_SERVER_IP'))

    def test_two_queries(self):
        """A query packet carrying two questions is answered with FORMERR."""
        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
        questions = []

        name = "%s.%s" % (os.getenv('DC_SERVER'), self.get_dns_domain())
        q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
        questions.append(q)

        name = "%s.%s" % ('bogusname', self.get_dns_domain())
        q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
        questions.append(q)

        self.finish_name_packet(p, questions)
        response = self.dns_transaction_udp(p)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR)

    def test_qtype_all_query(self):
        """A QTYPE_ALL question for the server's own name.

        Expects one answer holding ``DC_SERVER_IP`` and, when
        ``DC_SERVER_IPV6`` is set, a second answer holding that address.
        """
        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
        questions = []

        name = "%s.%s" % (os.getenv('DC_SERVER'), self.get_dns_domain())
        q = self.make_name_question(name, dns.DNS_QTYPE_ALL,
                                    dns.DNS_QCLASS_IN)
        # One pre-formatted argument: prints the same text whether print is
        # a statement or a function.
        print("asking for  %s" % q.name)
        questions.append(q)

        self.finish_name_packet(p, questions)
        response = self.dns_transaction_udp(p)

        # One answer for the IPv4 address, plus one when IPv6 is configured.
        num_answers = 1
        dc_ipv6 = os.getenv('DC_SERVER_IPV6')
        if dc_ipv6 is not None:
            num_answers += 1

        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
        self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
        self.assertEqual(response.ancount, num_answers)
        self.assertEqual(response.answers[0].rdata,
                         os.getenv('DC_SERVER_IP'))
        if dc_ipv6 is not None:
            self.assertEqual(response.answers[1].rdata, dc_ipv6)

    def test_qclass_none_query(self):
        """A question with class QCLASS_NONE is answered with NOTIMP."""
        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
        questions = []

        name = "%s.%s" % (os.getenv('DC_SERVER'), self.get_dns_domain())
        q = self.make_name_question(name, dns.DNS_QTYPE_ALL,
                                    dns.DNS_QCLASS_NONE)
        questions.append(q)

        self.finish_name_packet(p, questions)
        response = self.dns_transaction_udp(p)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NOTIMP)

# Only returns an authority section entry in BIND and Win DNS
# FIXME: Enable once Samba implements this feature
#    def test_soa_hostname_query(self):
#        "create a SOA query for a hostname"
#        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
#        questions = []
#
#        name = "%s.%s" % (os.getenv('DC_SERVER'), self.get_dns_domain())
#        q = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
#        questions.append(q)
#
#        self.finish_name_packet(p, questions)
#        response = self.dns_transaction_udp(p)
#        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
#        self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
#        # We don't get SOA records for single hosts
#        self.assertEquals(response.ancount, 0)

    def test_soa_domain_query(self):
        """A SOA question for the domain itself yields exactly one answer."""
        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
        questions = []

        name = self.get_dns_domain()
        q = self.make_name_question(name, dns.DNS_QTYPE_SOA,
                                    dns.DNS_QCLASS_IN)
        questions.append(q)

        self.finish_name_packet(p, questions)
        response = self.dns_transaction_udp(p)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
        self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
        self.assertEqual(response.ancount, 1)

    def test_two_updates(self):
        """An update packet with two question entries yields FORMERR."""
        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
        updates = []

        name = "%s.%s" % (os.getenv('DC_SERVER'), self.get_dns_domain())
        u = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
        updates.append(u)

        name = self.get_dns_domain()
        u = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
        updates.append(u)

        self.finish_name_packet(p, updates)
        response = self.dns_transaction_udp(p)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR)

    def test_update_wrong_qclass(self):
        """An update whose question has class QCLASS_NONE yields NOTIMP."""
        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
        updates = []

        name = self.get_dns_domain()
        u = self.make_name_question(name, dns.DNS_QTYPE_A,
                                    dns.DNS_QCLASS_NONE)
        updates.append(u)

        self.finish_name_packet(p, updates)
        response = self.dns_transaction_udp(p)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NOTIMP)

    def test_update_prereq_with_non_null_ttl(self):
        """An update prerequisite with a non-zero TTL yields FORMERR."""
        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
        updates = []

        name = self.get_dns_domain()

        u = self.make_name_question(name, dns.DNS_QTYPE_SOA,
                                    dns.DNS_QCLASS_IN)
        updates.append(u)
        self.finish_name_packet(p, updates)

        prereqs = []
        r = dns.res_rec()
        r.name = "%s.%s" % (os.getenv('DC_SERVER'), self.get_dns_domain())
        r.rr_type = dns.DNS_QTYPE_TXT
        r.rr_class = dns.DNS_QCLASS_NONE
        r.ttl = 1
        r.length = 0
        prereqs.append(r)

        # The prerequisites are carried in the packet's answer fields.
        p.ancount = len(prereqs)
        p.answers = prereqs

        response = self.dns_transaction_udp(p)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR)

# I'd love to test this one, but it segfaults. :)
#    def test_update_prereq_with_non_null_length(self):
#        "test update with a non-null length"
#        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
#        updates = []
#
#        name = self.get_dns_domain()
#
#        u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
#        updates.append(u)
#        self.finish_name_packet(p, updates)
#
#        prereqs = []
#        r = dns.res_rec()
#        r.name = "%s.%s" % (os.getenv('DC_SERVER'), self.get_dns_domain())
#        r.rr_type = dns.DNS_QTYPE_TXT
#        r.rr_class = dns.DNS_QCLASS_ANY
#        r.ttl = 0
#        r.length = 1
#        prereqs.append(r)
#
#        p.ancount = len(prereqs)
#        p.answers = prereqs
#
#        response = self.dns_transaction_udp(p)
#        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR)

    def test_update_prereq_nonexisting_name(self):
        """An update prerequisite on a non-existing name yields NXRRSET."""
        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
        updates = []

        name = self.get_dns_domain()

        u = self.make_name_question(name, dns.DNS_QTYPE_SOA,
                                    dns.DNS_QCLASS_IN)
        updates.append(u)
        self.finish_name_packet(p, updates)

        prereqs = []
        r = dns.res_rec()
        r.name = "idontexist.%s" % self.get_dns_domain()
        r.rr_type = dns.DNS_QTYPE_TXT
        r.rr_class = dns.DNS_QCLASS_ANY
        r.ttl = 0
        r.length = 0
        prereqs.append(r)

        # The prerequisites are carried in the packet's answer fields.
        p.ancount = len(prereqs)
        p.answers = prereqs

        response = self.dns_transaction_udp(p)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXRRSET)
351
if __name__ == "__main__":
    # Run directly as a script: hand control to unittest's command-line
    # runner, which collects the test cases defined in this module.
    from unittest import main
    main()