ab00e631772c99b0899f141a8c34865afc104155
[samba.git] / python / samba / tests / dns_base.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Kai Blin  <kai@samba.org> 2011
3 # Copyright (C) Ralph Boehme  <slow@samba.org> 2016
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18
19 from samba.tests import TestCaseInTempDir
20 from samba.dcerpc import dns, dnsp
21 from samba import gensec, tests
22 from samba import credentials
23 import struct
24 import samba.ndr as ndr
25 import random
26 import socket
27 import uuid
28 import time
29
30
class DNSTest(TestCaseInTempDir):
    """Base class for tests that talk to a DNS server.

    Provides helpers to build ``samba.dcerpc.dns`` packets, exchange them
    with a server on port 53 over UDP or TCP, and check response codes.

    Some helpers read attributes this class does not set itself:
    ``get_dns_domain`` uses ``self.creds`` and ``check_query_txt`` uses
    ``self.server_ip``; presumably subclasses or the test environment
    provide them.
    """

    def setUp(self):
        super(DNSTest, self).setUp()
        # Default socket timeout for the transaction helpers (None = block).
        self.timeout = None

    def errstr(self, errcode):
        """Return a readable name for the DNS RCODE ``errcode``.

        Codes without a known name are rendered as a hex string (for
        example ``0x12``) instead of raising, so that assertion messages
        built with this helper never mask the original failure.
        """
        string_codes = [
            "OK",
            "FORMERR",
            "SERVFAIL",
            "NXDOMAIN",
            "NOTIMP",
            "REFUSED",
            "YXDOMAIN",
            "YXRRSET",
            "NXRRSET",
            "NOTAUTH",
            "NOTZONE",
            "0x0B",
            "0x0C",
            "0x0D",
            "0x0E",
            "0x0F",
            "BADSIG",
            "BADKEY"
        ]

        if 0 <= errcode < len(string_codes):
            return string_codes[errcode]
        return "0x%02X" % errcode

    def assert_rcode_equals(self, rcode, expected):
        "Assert that ``rcode`` equals ``expected``, naming both on failure."
        self.assertEqual(rcode, expected, "Expected RCODE %s, got %s" %
                         (self.errstr(expected), self.errstr(rcode)))

    def assert_dns_rcode_equals(self, packet, rcode):
        "Assert that the RCODE bits of ``packet.operation`` equal ``rcode``."
        p_errcode = packet.operation & dns.DNS_RCODE
        self.assertEqual(p_errcode, rcode, "Expected RCODE %s, got %s" %
                         (self.errstr(rcode), self.errstr(p_errcode)))

    def assert_dns_opcode_equals(self, packet, opcode):
        """Assert that the OPCODE bits of ``packet.operation`` equal ``opcode``.

        The bits are compared in place (masked with ``dns.DNS_OPCODE``,
        not shifted), so ``opcode`` must be given in that same form.
        """
        p_opcode = packet.operation & dns.DNS_OPCODE
        self.assertEqual(p_opcode, opcode, "Expected OPCODE %s, got %s" %
                         (opcode, p_opcode))

    def make_name_packet(self, opcode, qid=None):
        """Create a ``dns.name_packet`` with the given opcode.

        :param opcode: value stored in the packet's ``operation`` field.
        :param qid: query ID to use; a random ID is chosen when ``None``.
        :return: the packet, with empty ``questions`` and ``additional``.
        """
        p = dns.name_packet()
        if qid is None:
            qid = random.randint(0x0, 0xff00)
        p.id = qid
        p.operation = opcode
        p.questions = []
        p.additional = []
        return p

    def finish_name_packet(self, packet, questions):
        "Store ``questions`` in ``packet`` and set the matching ``qdcount``."
        packet.qdcount = len(questions)
        packet.questions = questions

    def make_name_question(self, name, qtype, qclass):
        "Create a ``dns.name_question`` for ``name`` with type and class."
        q = dns.name_question()
        q.name = name
        q.question_type = qtype
        q.question_class = qclass
        return q

    def make_txt_record(self, records):
        "Build a ``dns.txt_record`` holding the strings in ``records``."
        rdata_txt = dns.txt_record()
        s_list = dnsp.string_list()
        s_list.count = len(records)
        s_list.str = records
        rdata_txt.txt = s_list
        return rdata_txt

    def get_dns_domain(self):
        "Return the lower-cased realm of ``self.creds`` as the DNS domain."
        return self.creds.get_realm().lower()

    def dns_transaction_udp(self, packet, host,
                            dump=False, timeout=None):
        """Send ``packet`` to ``host`` over UDP port 53 and read the reply.

        :param packet: ``dns.name_packet`` to send.
        :param host: server address.
        :param dump: if true, print hex dumps of the raw request and reply.
        :param timeout: socket timeout in seconds; ``self.timeout`` is used
            when ``None``.
        :return: ``(response, raw)``: the unpacked ``dns.name_packet`` and
            the raw reply bytes.
        """
        s = None
        if timeout is None:
            timeout = self.timeout
        try:
            send_packet = ndr.ndr_pack(packet)
            if dump:
                print(self.hexdump(send_packet))
            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
            s.settimeout(timeout)
            s.connect((host, 53))
            s.sendall(send_packet, 0)
            # Use a buffer large enough for any UDP datagram so that big
            # replies are not silently truncated by recv().
            recv_packet = s.recv(0xffff, 0)
            if dump:
                print(self.hexdump(recv_packet))
            response = ndr.ndr_unpack(dns.name_packet, recv_packet)
            return (response, recv_packet)
        finally:
            if s is not None:
                s.close()

    def dns_transaction_tcp(self, packet, host,
                            dump=False, timeout=None):
        """Send ``packet`` to ``host`` over TCP port 53 and read the reply.

        Messages are framed with the two-byte big-endian length prefix used
        for DNS over TCP (RFC 1035 section 4.2.2). The reply is read in
        full according to its own prefix, and the test fails if re-packing
        the unpacked reply does not reproduce the received bytes.

        :param packet: ``dns.name_packet`` to send.
        :param host: server address.
        :param dump: if true, print hex dumps of the raw request and reply
            (the reply dump includes the length prefix).
        :param timeout: socket timeout in seconds; ``self.timeout`` is used
            when ``None``.
        :return: ``(response, raw)``: the unpacked ``dns.name_packet`` and
            the raw reply bytes without the length prefix.
        :raises socket.error: if the server closes the connection before the
            whole reply has arrived.
        """
        def recv_exactly(sock, size):
            "Read exactly ``size`` bytes from the stream socket ``sock``."
            chunks = []
            remaining = size
            while remaining > 0:
                chunk = sock.recv(remaining)
                if not chunk:
                    raise socket.error(
                        "connection closed after %d of %d bytes" %
                        (size - remaining, size))
                chunks.append(chunk)
                remaining -= len(chunk)
            return b"".join(chunks)

        s = None
        if timeout is None:
            timeout = self.timeout
        try:
            send_packet = ndr.ndr_pack(packet)
            if dump:
                print(self.hexdump(send_packet))
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
            s.settimeout(timeout)
            s.connect((host, 53))
            tcp_packet = struct.pack('!H', len(send_packet))
            tcp_packet += send_packet
            s.sendall(tcp_packet)

            # A single recv() on a stream socket may return only part of
            # the reply: read the length prefix, then exactly that much.
            recv_packet = recv_exactly(s, 2)
            (reply_len,) = struct.unpack('!H', recv_packet)
            recv_packet += recv_exactly(s, reply_len)
            if dump:
                print(self.hexdump(recv_packet))
            response = ndr.ndr_unpack(dns.name_packet, recv_packet[2:])

        finally:
            if s is not None:
                s.close()

        # unpacking and packing again should produce same bytestream
        my_packet = ndr.ndr_pack(response)
        self.assertEqual(my_packet, recv_packet[2:])
        return (response, recv_packet[2:])

    def make_txt_update(self, prefix, txt_array, zone=None, ttl=900):
        """Build an update packet that adds a TXT record.

        The record is named ``prefix.zone`` and holds the strings in
        ``txt_array`` with the given ``ttl``. The packet's question is the
        SOA of ``zone``, which defaults to ``get_dns_domain()``.
        """
        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
        updates = []

        name = zone or self.get_dns_domain()
        u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
        updates.append(u)
        self.finish_name_packet(p, updates)

        updates = []
        r = dns.res_rec()
        r.name = "%s.%s" % (prefix, name)
        r.rr_type = dns.DNS_QTYPE_TXT
        r.rr_class = dns.DNS_QCLASS_IN
        r.ttl = ttl
        r.length = 0xffff
        rdata = self.make_txt_record(txt_array)
        r.rdata = rdata
        updates.append(r)
        p.nscount = len(updates)
        p.nsrecs = updates

        return p

    def check_query_txt(self, prefix, txt_array, zone=None):
        """Query ``prefix.zone`` for TXT over UDP and check the answer.

        Asserts an OK RCODE and exactly one answer whose strings equal
        ``txt_array``. ``zone`` defaults to ``get_dns_domain()``; the query
        is sent to ``self.server_ip``.
        """
        name = "%s.%s" % (prefix, zone or self.get_dns_domain())
        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
        questions = []

        q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
        questions.append(q)

        self.finish_name_packet(p, questions)
        (response, response_packet) =\
            self.dns_transaction_udp(p, host=self.server_ip)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
        self.assertEqual(response.ancount, 1)
        self.assertEqual(response.answers[0].rdata.txt.str, txt_array)
206
207
class DNSTKeyTest(DNSTest):
    """DNS tests that authenticate with GSS-TSIG.

    ``tkey_trans`` negotiates a gensec (SPNEGO) context with the server
    through a TKEY exchange. The other helpers use that context
    (``self.g``) to sign requests and to verify TSIG-signed responses.
    """

    def setUp(self):
        """Set up loadparm, gensec settings and Kerberos credentials.

        Reads ``self.server`` (the target hostname), which this class does
        not set; presumably a subclass or the test environment provides it.
        """
        super(DNSTKeyTest, self).setUp()
        self.settings = {}
        self.settings["lp_ctx"] = self.lp_ctx = tests.env_loadparm()
        self.settings["target_hostname"] = self.server

        self.creds = credentials.Credentials()
        self.creds.guess(self.lp_ctx)
        self.creds.set_username(tests.env_get_var_value('USERNAME'))
        self.creds.set_password(tests.env_get_var_value('PASSWORD'))
        self.creds.set_kerberos_state(credentials.MUST_USE_KERBEROS)
        self.newrecname = "tkeytsig.%s" % self.get_dns_domain()

    @staticmethod
    def _byte_list(data):
        "Convert bytes (or a Python 2 str) to a list of ints for NDR arrays."
        return [x if isinstance(x, int) else ord(x) for x in data]

    def tkey_trans(self, creds=None):
        """Do a TKEY transaction and establish a gensec context.

        Sends a TKEY query (GSSAPI mode, ``gss-tsig`` algorithm) over TCP
        carrying the client's first SPNEGO token. It then feeds the
        server's reply token back into gensec, asserts the exchange
        finished, and verifies the TSIG signature on the response.

        On return ``self.key_name`` holds the generated key name and
        ``self.g`` the gensec context.

        :param creds: credentials to authenticate with; ``self.creds`` is
            used when ``None``.
        """
        if creds is None:
            creds = self.creds

        self.key_name = "%s.%s" % (uuid.uuid4(), self.get_dns_domain())

        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
        q = self.make_name_question(self.key_name,
                                    dns.DNS_QTYPE_TKEY,
                                    dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [q])

        # Key is valid for one hour from now.
        now = int(time.time())
        rdata = dns.tkey_record()
        rdata.algorithm = "gss-tsig"
        rdata.inception = now
        rdata.expiration = now + 60 * 60
        rdata.mode = dns.DNS_TKEY_MODE_GSSAPI
        rdata.error = 0
        rdata.other_size = 0

        self.g = gensec.Security.start_client(self.settings)
        self.g.set_credentials(creds)
        self.g.set_target_service("dns")
        self.g.set_target_hostname(self.server)
        self.g.want_feature(gensec.FEATURE_SIGN)
        self.g.start_mech_by_name("spnego")

        # Start the exchange with an empty input; the output is the token
        # the client sends to the server.
        (finished, client_to_server) = self.g.update(b"")
        self.assertFalse(finished)

        key_data = self._byte_list(client_to_server)
        rdata.key_data = key_data
        rdata.key_size = len(key_data)

        r = dns.res_rec()
        r.name = self.key_name
        r.rr_type = dns.DNS_QTYPE_TKEY
        r.rr_class = dns.DNS_QCLASS_IN
        r.ttl = 0
        r.length = 0xffff
        r.rdata = rdata

        p.arcount = 1
        p.additional = [r]

        (response, response_packet) =\
            self.dns_transaction_tcp(p, self.server_ip)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)

        tkey_record = response.answers[0].rdata
        server_to_client = bytes(tkey_record.key_data)
        (finished, _) = self.g.update(server_to_client)
        self.assertTrue(finished)

        self.verify_packet(response, response_packet)

    def verify_packet(self, response, response_packet, request_mac=b""):
        """Check the TSIG signature on a server response with ``self.g``.

        :param response: unpacked ``dns.name_packet``; its first additional
            record must be the TSIG record, and that record is assumed to
            be the last thing in the packet.
        :param response_packet: raw wire bytes of ``response``.
        :param request_mac: MAC of the request this response answers; it
            is prepended to the signed data (empty by default).
        """
        self.assertEqual(response.additional[0].rr_type, dns.DNS_QTYPE_TSIG)

        tsig_record = response.additional[0].rdata
        mac = bytes(tsig_record.mac)

        # Wire size of the trailing TSIG RR. The owner name is the key
        # name, assumed uncompressed: len(name) + 2 bytes, for the leading
        # label-length byte and the root label. TYPE, CLASS, TTL and
        # RDLENGTH add 10 bytes, plus the packed RDATA.
        key_name_len = len(self.key_name) + 2
        tsig_record_len = len(ndr.ndr_pack(tsig_record)) + key_name_len + 10

        # The MAC covers the message as it was before the TSIG RR was
        # added: strip the RR and decrement ARCOUNT (header bytes 10-11).
        packet_wo_tsig = bytearray(response_packet)
        del packet_wo_tsig[-tsig_record_len:]
        (arcount,) = struct.unpack_from('!H', packet_wo_tsig, 10)
        struct.pack_into('!H', packet_wo_tsig, 10, arcount - 1)

        # The TSIG variables that follow the message in the signed data.
        fake_tsig = dns.fake_tsig_rec()
        fake_tsig.name = self.key_name
        fake_tsig.rr_class = dns.DNS_QCLASS_ANY
        fake_tsig.ttl = 0
        fake_tsig.time_prefix = tsig_record.time_prefix
        fake_tsig.time = tsig_record.time
        fake_tsig.algorithm_name = tsig_record.algorithm_name
        fake_tsig.fudge = tsig_record.fudge
        fake_tsig.error = 0
        fake_tsig.other_size = 0
        fake_tsig_packet = ndr.ndr_pack(fake_tsig)

        data = request_mac + bytes(packet_wo_tsig) + fake_tsig_packet
        self.g.check_packet(data, data, mac)

    def sign_packet(self, packet, key_name):
        """Sign ``packet`` with ``self.g`` and attach a TSIG record.

        The MAC is computed over the packed packet followed by the packed
        TSIG variables (``dns.fake_tsig_rec``). The resulting TSIG RR
        replaces ``packet.additional``, discarding any existing additional
        records, and ``arcount`` is set to 1.

        :param packet: ``dns.name_packet`` to sign (modified in place).
        :param key_name: TKEY key name used as the TSIG owner name.
        :return: the raw MAC, for example to pass as ``request_mac`` to
            ``verify_packet``.
        """
        algorithm = "gss-tsig"
        fudge = 300
        packet_data = ndr.ndr_pack(packet)

        fake_tsig = dns.fake_tsig_rec()
        fake_tsig.name = key_name
        fake_tsig.rr_class = dns.DNS_QCLASS_ANY
        fake_tsig.ttl = 0
        fake_tsig.time_prefix = 0
        fake_tsig.time = int(time.time())
        fake_tsig.algorithm_name = algorithm
        fake_tsig.fudge = fudge
        fake_tsig.error = 0
        fake_tsig.other_size = 0
        fake_tsig_packet = ndr.ndr_pack(fake_tsig)

        data = packet_data + fake_tsig_packet
        mac = self.g.sign_packet(data, data)
        mac_list = self._byte_list(mac)

        rdata = dns.tsig_record()
        rdata.algorithm_name = algorithm
        rdata.time_prefix = 0
        rdata.time = fake_tsig.time
        rdata.fudge = fudge
        rdata.original_id = packet.id
        rdata.error = 0
        rdata.other_size = 0
        rdata.mac = mac_list
        rdata.mac_size = len(mac_list)

        r = dns.res_rec()
        r.name = key_name
        r.rr_type = dns.DNS_QTYPE_TSIG
        r.rr_class = dns.DNS_QCLASS_ANY
        r.ttl = 0
        r.length = 0xffff
        r.rdata = rdata

        packet.additional = [r]
        packet.arcount = 1

        return mac

    def bad_sign_packet(self, packet, key_name):
        """Attach a TSIG record with a bogus MAC to ``packet``.

        The MAC is the literal bytes of ``"badmac"``, not a signature
        computed by the gensec context. As in ``sign_packet``, the TSIG RR
        replaces ``packet.additional`` and ``arcount`` is set to 1.

        :param packet: ``dns.name_packet`` to modify in place.
        :param key_name: TKEY key name used as the TSIG owner name.
        """
        mac_list = self._byte_list("badmac")

        rdata = dns.tsig_record()
        rdata.algorithm_name = "gss-tsig"
        rdata.time_prefix = 0
        rdata.time = int(time.time())
        rdata.fudge = 300
        rdata.original_id = packet.id
        rdata.error = 0
        rdata.other_size = 0
        rdata.mac = mac_list
        rdata.mac_size = len(mac_list)

        r = dns.res_rec()
        r.name = key_name
        r.rr_type = dns.DNS_QTYPE_TSIG
        r.rr_class = dns.DNS_QCLASS_ANY
        r.ttl = 0
        r.length = 0xffff
        r.rdata = rdata

        packet.additional = [r]
        packet.arcount = 1

    def search_record(self, name):
        "Query ``name`` for TXT records over UDP and return the reply RCODE."
        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
        q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [q])

        (response, response_packet) =\
            self.dns_transaction_udp(p, self.server_ip)
        return response.operation & dns.DNS_RCODE

    def make_update_request(self, delete=False):
        """Create a DNS update for the TXT record ``self.newrecname``.

        The update's question is the SOA of ``get_dns_domain()``. The
        update adds the TXT value ``"This is a test"`` (including the
        literal quotes) with a TTL of 900. When ``delete`` is true it
        instead uses class NONE with TTL 0, i.e. a request to delete that
        record.
        """
        if delete:
            rr_class, ttl = dns.DNS_QCLASS_NONE, 0
        else:
            rr_class, ttl = dns.DNS_QCLASS_IN, 900

        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
        q = self.make_name_question(self.get_dns_domain(),
                                    dns.DNS_QTYPE_SOA,
                                    dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [q])

        r = dns.res_rec()
        r.name = self.newrecname
        r.rr_type = dns.DNS_QTYPE_TXT
        r.rr_class = rr_class
        r.ttl = ttl
        r.length = 0xffff
        r.rdata = self.make_txt_record(['"This is a test"'])

        updates = [r]
        p.nscount = len(updates)
        p.nsrecs = updates

        return p