PEP8: fix E302: expected 2 blank lines, found 1
[bbaumbach/samba-autobuild/.git] / python / samba / tests / dns_base.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Kai Blin  <kai@samba.org> 2011
3 # Copyright (C) Ralph Boehme  <slow@samba.org> 2016
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18
19 from __future__ import print_function
20 from samba.tests import TestCaseInTempDir
21 from samba.dcerpc import dns, dnsp
22 from samba import gensec, tests
23 from samba import credentials
24 import struct
25 import samba.ndr as ndr
26 import random
27 import socket
28 import uuid
29 import time
30 from samba.compat import binary_type
31
32
class DNSTest(TestCaseInTempDir):
    """Base class for DNS tests: packet builders, transports and assertions.

    Several helpers read attributes that this class does not set itself:
    ``self.creds`` (get_dns_domain) and ``self.server_ip``
    (check_query_txt).  Subclasses have to provide them.
    """

    # Readable names for the DNS RCODE values, indexed by numeric code.
    # Codes without a name in this table are shown as their hex value.
    _RCODE_NAMES = (
        "OK",
        "FORMERR",
        "SERVFAIL",
        "NXDOMAIN",
        "NOTIMP",
        "REFUSED",
        "YXDOMAIN",
        "YXRRSET",
        "NXRRSET",
        "NOTAUTH",
        "NOTZONE",
        "0x0B",
        "0x0C",
        "0x0D",
        "0x0E",
        "0x0F",
        "BADSIG",
        "BADKEY",
    )

    def setUp(self):
        super(DNSTest, self).setUp()
        # Default socket timeout used by the dns_transaction_* helpers
        # when the caller does not pass one; None leaves sockets blocking.
        self.timeout = None

    def errstr(self, errcode):
        """Return a readable name for the RCODE ``errcode``.

        Codes outside the name table are rendered as hex instead of
        raising, so that a failing assertion message can always be built.
        """
        try:
            # Negative values would silently index from the end of the
            # table; treat them like any other unknown code.
            if errcode >= 0:
                return self._RCODE_NAMES[errcode]
        except (IndexError, TypeError):
            pass
        try:
            return "0x%02X" % errcode
        except TypeError:
            return repr(errcode)

    def assert_rcode_equals(self, rcode, expected):
        """Assert that the numeric ``rcode`` equals ``expected``."""
        self.assertEqual(rcode, expected, "Expected RCODE %s, got %s" %
                         (self.errstr(expected), self.errstr(rcode)))

    def assert_dns_rcode_equals(self, packet, rcode):
        """Assert that the RCODE bits of ``packet.operation`` equal ``rcode``."""
        p_errcode = packet.operation & 0x000F
        self.assertEqual(p_errcode, rcode, "Expected RCODE %s, got %s" %
                         (self.errstr(rcode), self.errstr(p_errcode)))

    def assert_dns_opcode_equals(self, packet, opcode):
        """Assert that the OPCODE bits of ``packet.operation`` equal ``opcode``."""
        p_opcode = packet.operation & 0x7800
        self.assertEqual(p_opcode, opcode, "Expected OPCODE %s, got %s" %
                         (opcode, p_opcode))

    def make_name_packet(self, opcode, qid=None):
        """Create a dns.name_packet with empty question/additional lists.

        ``opcode`` is stored as the packet's operation field.  ``qid`` is
        used as the packet id; when it is None a random id is chosen.
        """
        p = dns.name_packet()
        # An explicitly requested id must be honoured, not dropped.
        p.id = random.randint(0x0, 0xff00) if qid is None else qid
        p.operation = opcode
        p.questions = []
        p.additional = []
        return p

    def finish_name_packet(self, packet, questions):
        """Store ``questions`` on ``packet`` and set its question count."""
        packet.qdcount = len(questions)
        packet.questions = questions

    def make_name_question(self, name, qtype, qclass):
        """Create a dns.name_question for ``name`` with the given type/class."""
        q = dns.name_question()
        q.name = name
        q.question_type = qtype
        q.question_class = qclass
        return q

    def make_txt_record(self, records):
        """Wrap the list of strings ``records`` into a dns.txt_record."""
        s_list = dnsp.string_list()
        s_list.count = len(records)
        s_list.str = records
        rdata_txt = dns.txt_record()
        rdata_txt.txt = s_list
        return rdata_txt

    def get_dns_domain(self):
        """Return the lower-cased realm of ``self.creds`` as the DNS domain."""
        return self.creds.get_realm().lower()

    @staticmethod
    def _recv_upto(sock, count):
        """Read up to ``count`` bytes from ``sock``, stopping early on EOF.

        A stream socket may hand back fewer bytes than requested, so keep
        reading until the amount is complete or the peer closes.
        """
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def dns_transaction_udp(self, packet, host,
                            dump=False, timeout=None):
        """Send ``packet`` to ``host`` port 53 over UDP and read the reply.

        ``dump`` prints a hexdump of the sent and received data.
        ``timeout`` overrides ``self.timeout`` for this transaction.
        Returns a tuple (unpacked dns.name_packet, raw reply bytes).
        """
        s = None
        if timeout is None:
            timeout = self.timeout
        try:
            send_packet = ndr.ndr_pack(packet)
            if dump:
                print(self.hexdump(send_packet))
            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
            s.settimeout(timeout)
            s.connect((host, 53))
            s.sendall(send_packet, 0)
            recv_packet = s.recv(2048, 0)
            if dump:
                print(self.hexdump(recv_packet))
            response = ndr.ndr_unpack(dns.name_packet, recv_packet)
            return (response, recv_packet)
        finally:
            if s is not None:
                s.close()

    def dns_transaction_tcp(self, packet, host,
                            dump=False, timeout=None):
        """Send ``packet`` to ``host`` port 53 over TCP and read the reply.

        The request is prefixed with its 16-bit big-endian length; the
        reply is read as a 2-byte length followed by that many bytes.
        ``dump`` prints a hexdump of the sent and received data.
        ``timeout`` overrides ``self.timeout`` for this transaction.
        Returns a tuple (unpacked dns.name_packet, raw reply bytes without
        the length prefix).  Also asserts that re-packing the unpacked
        reply reproduces the received bytes.
        """
        s = None
        if timeout is None:
            timeout = self.timeout
        try:
            send_packet = ndr.ndr_pack(packet)
            if dump:
                print(self.hexdump(send_packet))
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
            s.settimeout(timeout)
            s.connect((host, 53))
            tcp_packet = struct.pack('!H', len(send_packet)) + send_packet
            s.sendall(tcp_packet)

            # A single recv() may return only part of the reply, so read
            # the length prefix first and then exactly the announced size.
            header = self._recv_upto(s, 2)
            body = b""
            if len(header) == 2:
                (body_len,) = struct.unpack('!H', header)
                body = self._recv_upto(s, body_len)
            recv_packet = header + body
            if dump:
                print(self.hexdump(recv_packet))
            reply = recv_packet[2:]
            response = ndr.ndr_unpack(dns.name_packet, reply)

        finally:
            if s is not None:
                s.close()

        # unpacking and packing again should produce same bytestream
        my_packet = ndr.ndr_pack(response)
        self.assertEqual(my_packet, reply)
        return (response, reply)

    def make_txt_update(self, prefix, txt_array, domain=None):
        """Build an UPDATE packet adding a TXT record ``prefix``.<zone>.

        The zone is ``domain`` or, when that is not given, the value of
        get_dns_domain().  ``txt_array`` is the list of TXT strings.
        """
        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)

        name = domain or self.get_dns_domain()
        zone = self.make_name_question(name, dns.DNS_QTYPE_SOA,
                                       dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [zone])

        r = dns.res_rec()
        r.name = "%s.%s" % (prefix, name)
        r.rr_type = dns.DNS_QTYPE_TXT
        r.rr_class = dns.DNS_QCLASS_IN
        r.ttl = 900
        r.length = 0xffff
        r.rdata = self.make_txt_record(txt_array)

        updates = [r]
        p.nscount = len(updates)
        p.nsrecs = updates

        return p

    def check_query_txt(self, prefix, txt_array, zone=None):
        """Query TXT for ``prefix``.<zone> over UDP and check the answer.

        Asserts RCODE OK, exactly one answer, and that its TXT strings
        equal ``txt_array``.  The query is sent to ``self.server_ip``.
        """
        name = "%s.%s" % (prefix, zone or self.get_dns_domain())
        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)

        q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [q])

        (response, response_packet) =\
            self.dns_transaction_udp(p, host=self.server_ip)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
        self.assertEqual(response.ancount, 1)
        self.assertEqual(response.answers[0].rdata.txt.str, txt_array)
208
209
class DNSTKeyTest(DNSTest):
    """DNS test base adding TKEY negotiation and TSIG signing helpers.

    setUp reads ``self.server`` and the helpers read ``self.server_ip``;
    neither is assigned in this class, so they must already exist when
    setUp runs (presumably set by the concrete test class -- confirm).
    """

    def setUp(self):
        super(DNSTKeyTest, self).setUp()
        self.settings = {}
        self.settings["lp_ctx"] = self.lp_ctx = tests.env_loadparm()
        self.settings["target_hostname"] = self.server

        self.creds = credentials.Credentials()
        self.creds.guess(self.lp_ctx)
        self.creds.set_username(tests.env_get_var_value('USERNAME'))
        self.creds.set_password(tests.env_get_var_value('PASSWORD'))
        self.creds.set_kerberos_state(credentials.MUST_USE_KERBEROS)
        # Name of the TXT record used by make_update_request().
        self.newrecname = "tkeytsig.%s" % self.get_dns_domain()

    @staticmethod
    def _byte_values(data):
        """Return ``data`` as a list of ints.

        Iterating bytes yields ints on Python 3 but 1-char strings on
        Python 2; both are normalised to integer byte values.
        """
        return [x if isinstance(x, int) else ord(x) for x in data]

    def _set_tsig_record(self, packet, key_name, mac_list, timestamp):
        """Make a TSIG record the only additional record of ``packet``.

        ``mac_list`` is the MAC as a list of ints and ``timestamp`` the
        signing time stored in the record.
        """
        rdata = dns.tsig_record()
        rdata.algorithm_name = "gss-tsig"
        rdata.time_prefix = 0
        rdata.time = timestamp
        rdata.fudge = 300
        rdata.original_id = packet.id
        rdata.error = 0
        rdata.other_size = 0
        rdata.mac = mac_list
        rdata.mac_size = len(mac_list)

        r = dns.res_rec()
        r.name = key_name
        r.rr_type = dns.DNS_QTYPE_TSIG
        r.rr_class = dns.DNS_QCLASS_ANY
        r.ttl = 0
        r.length = 0xffff
        r.rdata = rdata

        packet.additional = [r]
        packet.arcount = 1

    def tkey_trans(self, creds=None):
        """Do a TKEY transaction and establish a gensec context.

        Uses ``creds`` or, when not given, ``self.creds``.  On return
        ``self.key_name`` holds the freshly generated key name and
        ``self.g`` the gensec client context.  The TKEY reply is sent
        over TCP to ``self.server_ip`` and verified with verify_packet().
        """

        if creds is None:
            creds = self.creds

        self.key_name = "%s.%s" % (uuid.uuid4(), self.get_dns_domain())

        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
        q = self.make_name_question(self.key_name,
                                    dns.DNS_QTYPE_TKEY,
                                    dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [q])

        r = dns.res_rec()
        r.name = self.key_name
        r.rr_type = dns.DNS_QTYPE_TKEY
        r.rr_class = dns.DNS_QCLASS_IN
        r.ttl = 0
        r.length = 0xffff

        now = int(time.time())
        rdata = dns.tkey_record()
        rdata.algorithm = "gss-tsig"
        rdata.inception = now
        rdata.expiration = now + 60 * 60
        rdata.mode = dns.DNS_TKEY_MODE_GSSAPI
        rdata.error = 0
        rdata.other_size = 0

        self.g = gensec.Security.start_client(self.settings)
        self.g.set_credentials(creds)
        self.g.set_target_service("dns")
        self.g.set_target_hostname(self.server)
        self.g.want_feature(gensec.FEATURE_SIGN)
        self.g.start_mech_by_name("spnego")

        # The security tokens are binary blobs, so start the exchange
        # with empty bytes rather than an empty text string.
        client_to_server = b""

        (finished, server_to_client) = self.g.update(client_to_server)
        self.assertFalse(finished)

        # NOTE: the first token produced by the client context is what
        # gets sent to the server as TKEY key data.
        data = self._byte_values(server_to_client)
        rdata.key_data = data
        rdata.key_size = len(data)
        r.rdata = rdata

        p.arcount = 1
        p.additional = [r]

        (response, response_packet) =\
            self.dns_transaction_tcp(p, self.server_ip)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)

        tkey_record = response.answers[0].rdata
        server_to_client = binary_type(bytearray(tkey_record.key_data))
        (finished, client_to_server) = self.g.update(server_to_client)
        self.assertTrue(finished)

        self.verify_packet(response, response_packet)

    def verify_packet(self, response, response_packet, request_mac=b""):
        """Check the TSIG MAC of a reply against the gensec context.

        ``response`` is the unpacked reply, ``response_packet`` its raw
        bytes and ``request_mac`` the MAC of the request (prepended to
        the checked data).  Requires ``self.key_name`` and ``self.g`` as
        set up by tkey_trans().
        """
        self.assertEqual(response.additional[0].rr_type, dns.DNS_QTYPE_TSIG)

        tsig_record = response.additional[0].rdata
        mac = binary_type(bytearray(tsig_record.mac))

        # Cut off tsig record from dns response packet for MAC verification
        # and reset additional record count.
        # The "+ 2" presumably accounts for the wire encoding of the key
        # name (leading length byte and terminating root label) and the
        # "+ 10" for the fixed record fields type, class, ttl and rdata
        # length -- TODO confirm against the DNS wire format.
        key_name_len = len(self.key_name) + 2
        tsig_record_len = len(ndr.ndr_pack(tsig_record)) + key_name_len + 10

        # Work on a mutable list of integer byte values.
        response_packet_list = self._byte_values(response_packet)
        del response_packet_list[-tsig_record_len:]
        # Offset 11 is where the original code zeroes the additional
        # record count; the list only ever holds ints at this point.
        response_packet_list[11] = 0

        # convert the modified list back to str/bytes
        response_packet_wo_tsig = binary_type(bytearray(response_packet_list))

        fake_tsig = dns.fake_tsig_rec()
        fake_tsig.name = self.key_name
        fake_tsig.rr_class = dns.DNS_QCLASS_ANY
        fake_tsig.ttl = 0
        fake_tsig.time_prefix = tsig_record.time_prefix
        fake_tsig.time = tsig_record.time
        fake_tsig.algorithm_name = tsig_record.algorithm_name
        fake_tsig.fudge = tsig_record.fudge
        fake_tsig.error = 0
        fake_tsig.other_size = 0
        fake_tsig_packet = ndr.ndr_pack(fake_tsig)

        data = request_mac + response_packet_wo_tsig + fake_tsig_packet
        self.g.check_packet(data, data, mac)

    def sign_packet(self, packet, key_name):
        """Sign ``packet``, add the TSIG record and return the MAC.

        The MAC is calculated with ``self.g`` over the packed packet
        followed by a packed dns.fake_tsig_rec for ``key_name``.  The
        packet's additional section is replaced by the TSIG record.
        """
        # Pack before the TSIG record is attached: the MAC covers the
        # packet without it.
        packet_data = ndr.ndr_pack(packet)

        fake_tsig = dns.fake_tsig_rec()
        fake_tsig.name = key_name
        fake_tsig.rr_class = dns.DNS_QCLASS_ANY
        fake_tsig.ttl = 0
        fake_tsig.time_prefix = 0
        fake_tsig.time = int(time.time())
        fake_tsig.algorithm_name = "gss-tsig"
        fake_tsig.fudge = 300
        fake_tsig.error = 0
        fake_tsig.other_size = 0
        fake_tsig_packet = ndr.ndr_pack(fake_tsig)

        data = packet_data + fake_tsig_packet
        mac = self.g.sign_packet(data, data)

        # The record carries the same timestamp that was signed.
        self._set_tsig_record(packet, key_name,
                              self._byte_values(mac), fake_tsig.time)

        return mac

    def bad_sign_packet(self, packet, key_name):
        """Add a TSIG record with an invalid MAC to ``packet``.

        No signature is calculated; the MAC is the fixed byte string
        "badmac".  The packet's additional section is replaced by the
        TSIG record.
        """
        self._set_tsig_record(packet, key_name,
                              self._byte_values("badmac"), int(time.time()))

    def search_record(self, name):
        """Query TXT for ``name`` over UDP and return the reply's RCODE."""
        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)

        q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [q])

        (response, response_packet) =\
            self.dns_transaction_udp(p, self.server_ip)
        return response.operation & 0x000F

    def make_update_request(self, delete=False):
        """Create a DNS update request for the ``self.newrecname`` TXT record.

        With ``delete`` set, the record uses class NONE and ttl 0 instead
        of class IN and ttl 900.
        """
        if delete:
            rr_class = dns.DNS_QCLASS_NONE
            ttl = 0
        else:
            rr_class = dns.DNS_QCLASS_IN
            ttl = 900

        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
        q = self.make_name_question(self.get_dns_domain(),
                                    dns.DNS_QTYPE_SOA,
                                    dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [q])

        r = dns.res_rec()
        r.name = self.newrecname
        r.rr_type = dns.DNS_QTYPE_TXT
        r.rr_class = rr_class
        r.ttl = ttl
        r.length = 0xffff
        r.rdata = self.make_txt_record(['"This is a test"'])

        updates = [r]
        p.nscount = len(updates)
        p.nsrecs = updates

        return p