1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Kai Blin <kai@samba.org> 2011
3 # Copyright (C) Ralph Boehme <slow@samba.org> 2016
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 from __future__ import print_function
20 from samba.tests import TestCaseInTempDir
21 from samba.dcerpc import dns, dnsp
22 from samba import gensec, tests
23 from samba import credentials
25 import samba.ndr as ndr
30 from samba.compat import binary_type
33 class DNSTest(TestCaseInTempDir):
36 super(DNSTest, self).setUp()
39 def errstr(self, errcode):
40 "Return a readable error code"
62 return string_codes[errcode]
def assert_rcode_equals(self, rcode, expected):
    """Assert that a DNS RCODE equals the expected value.

    :param rcode: the RCODE that was actually received.
    :param expected: the RCODE the caller expects.

    On mismatch the failure message shows both codes, each rendered
    through self.errstr().
    """
    # assertEqual, not the deprecated assertEquals alias: the alias was
    # removed from unittest in Python 3.12, while assertEqual also exists
    # on Python 2.7.
    self.assertEqual(rcode, expected, "Expected RCODE %s, got %s" %
                     (self.errstr(expected), self.errstr(rcode)))
def assert_dns_rcode_equals(self, packet, rcode):
    """Assert that the RCODE carried in *packet* equals *rcode*.

    :param packet: a response object exposing an integer ``operation``
        field; the RCODE is read from its low four bits.
    :param rcode: the RCODE the caller expects.

    On mismatch the failure message shows both codes, each rendered
    through self.errstr().
    """
    # Mask off everything but the 4-bit RCODE in the operation field.
    p_errcode = packet.operation & 0x000F
    # assertEqual, not the deprecated assertEquals alias (removed in
    # Python 3.12; assertEqual is also available on Python 2.7).
    self.assertEqual(p_errcode, rcode, "Expected RCODE %s, got %s" %
                     (self.errstr(rcode), self.errstr(p_errcode)))
75 def assert_dns_opcode_equals(self, packet, opcode):
76 "Helper function to check opcode"
77 p_opcode = packet.operation & 0x7800
78 self.assertEquals(p_opcode, opcode, "Expected OPCODE %s, got %s" %
81 def make_name_packet(self, opcode, qid=None):
82 "Helper creating a dns.name_packet"
85 p.id = random.randint(0x0, 0xff00)
def finish_name_packet(self, packet, questions):
    """Attach *questions* to *packet* and record how many there are.

    packet.qdcount is set to len(questions), then packet.questions is
    bound to the very same sequence object. Returns None.
    """
    question_total = len(questions)
    packet.qdcount = question_total
    packet.questions = questions
96 def make_name_question(self, name, qtype, qclass):
97 "Helper creating a dns.name_question"
98 q = dns.name_question()
100 q.question_type = qtype
101 q.question_class = qclass
104 def make_txt_record(self, records):
105 rdata_txt = dns.txt_record()
106 s_list = dnsp.string_list()
107 s_list.count = len(records)
109 rdata_txt.txt = s_list
def get_dns_domain(self):
    """Return the DNS domain name used by these tests.

    This is the realm reported by self.creds.get_realm(), lower-cased.
    """
    realm = self.creds.get_realm()
    return realm.lower()
116 def dns_transaction_udp(self, packet, host,
117 dump=False, timeout=None):
118 "send a DNS query and read the reply"
121 timeout = self.timeout
123 send_packet = ndr.ndr_pack(packet)
125 print(self.hexdump(send_packet))
126 s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
127 s.settimeout(timeout)
128 s.connect((host, 53))
129 s.sendall(send_packet, 0)
130 recv_packet = s.recv(2048, 0)
132 print(self.hexdump(recv_packet))
133 response = ndr.ndr_unpack(dns.name_packet, recv_packet)
134 return (response, recv_packet)
139 def dns_transaction_tcp(self, packet, host,
140 dump=False, timeout=None):
141 "send a DNS query and read the reply, also return the raw packet"
144 timeout = self.timeout
146 send_packet = ndr.ndr_pack(packet)
148 print(self.hexdump(send_packet))
149 s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
150 s.settimeout(timeout)
151 s.connect((host, 53))
152 tcp_packet = struct.pack('!H', len(send_packet))
153 tcp_packet += send_packet
154 s.sendall(tcp_packet)
156 recv_packet = s.recv(0xffff + 2, 0)
158 print(self.hexdump(recv_packet))
159 response = ndr.ndr_unpack(dns.name_packet, recv_packet[2:])
165 # unpacking and packing again should produce same bytestream
166 my_packet = ndr.ndr_pack(response)
167 self.assertEquals(my_packet, recv_packet[2:])
168 return (response, recv_packet[2:])
170 def make_txt_update(self, prefix, txt_array, domain=None):
171 p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
174 name = domain or self.get_dns_domain()
175 u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
177 self.finish_name_packet(p, updates)
181 r.name = "%s.%s" % (prefix, name)
182 r.rr_type = dns.DNS_QTYPE_TXT
183 r.rr_class = dns.DNS_QCLASS_IN
186 rdata = self.make_txt_record(txt_array)
189 p.nscount = len(updates)
194 def check_query_txt(self, prefix, txt_array, zone=None):
195 name = "%s.%s" % (prefix, zone or self.get_dns_domain())
196 p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
199 q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
202 self.finish_name_packet(p, questions)
203 (response, response_packet) =\
204 self.dns_transaction_udp(p, host=self.server_ip)
205 self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
206 self.assertEquals(response.ancount, 1)
207 self.assertEquals(response.answers[0].rdata.txt.str, txt_array)
210 class DNSTKeyTest(DNSTest):
212 super(DNSTKeyTest, self).setUp()
214 self.settings["lp_ctx"] = self.lp_ctx = tests.env_loadparm()
215 self.settings["target_hostname"] = self.server
217 self.creds = credentials.Credentials()
218 self.creds.guess(self.lp_ctx)
219 self.creds.set_username(tests.env_get_var_value('USERNAME'))
220 self.creds.set_password(tests.env_get_var_value('PASSWORD'))
221 self.creds.set_kerberos_state(credentials.MUST_USE_KERBEROS)
222 self.newrecname = "tkeytsig.%s" % self.get_dns_domain()
224 def tkey_trans(self, creds=None):
225 "Do a TKEY transaction and establish a gensec context"
230 self.key_name = "%s.%s" % (uuid.uuid4(), self.get_dns_domain())
232 p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
233 q = self.make_name_question(self.key_name,
238 self.finish_name_packet(p, questions)
241 r.name = self.key_name
242 r.rr_type = dns.DNS_QTYPE_TKEY
243 r.rr_class = dns.DNS_QCLASS_IN
246 rdata = dns.tkey_record()
247 rdata.algorithm = "gss-tsig"
248 rdata.inception = int(time.time())
249 rdata.expiration = int(time.time()) + 60 * 60
250 rdata.mode = dns.DNS_TKEY_MODE_GSSAPI
254 self.g = gensec.Security.start_client(self.settings)
255 self.g.set_credentials(creds)
256 self.g.set_target_service("dns")
257 self.g.set_target_hostname(self.server)
258 self.g.want_feature(gensec.FEATURE_SIGN)
259 self.g.start_mech_by_name("spnego")
262 client_to_server = ""
264 (finished, server_to_client) = self.g.update(client_to_server)
265 self.assertFalse(finished)
267 data = [x if isinstance(x, int) else ord(x) for x in list(server_to_client)]
268 rdata.key_data = data
269 rdata.key_size = len(data)
274 p.additional = additional
276 (response, response_packet) =\
277 self.dns_transaction_tcp(p, self.server_ip)
278 self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
280 tkey_record = response.answers[0].rdata
281 server_to_client = binary_type(bytearray(tkey_record.key_data))
282 (finished, client_to_server) = self.g.update(server_to_client)
283 self.assertTrue(finished)
285 self.verify_packet(response, response_packet)
287 def verify_packet(self, response, response_packet, request_mac=b""):
288 self.assertEqual(response.additional[0].rr_type, dns.DNS_QTYPE_TSIG)
290 tsig_record = response.additional[0].rdata
291 mac = binary_type(bytearray(tsig_record.mac))
293 # Cut off tsig record from dns response packet for MAC verification
294 # and reset additional record count.
295 key_name_len = len(self.key_name) + 2
296 tsig_record_len = len(ndr.ndr_pack(tsig_record)) + key_name_len + 10
298 # convert str/bytes to a list (of string char or int)
299 # so it can be modified
300 response_packet_list = [x if isinstance(x, int) else ord(x) for x in response_packet]
301 del response_packet_list[-tsig_record_len:]
302 if isinstance(response_packet_list[11], int):
303 response_packet_list[11] = 0
305 response_packet_list[11] = chr(0)
307 # convert modified list (of string char or int) to str/bytes
308 response_packet_wo_tsig = binary_type(bytearray(response_packet_list))
310 fake_tsig = dns.fake_tsig_rec()
311 fake_tsig.name = self.key_name
312 fake_tsig.rr_class = dns.DNS_QCLASS_ANY
314 fake_tsig.time_prefix = tsig_record.time_prefix
315 fake_tsig.time = tsig_record.time
316 fake_tsig.algorithm_name = tsig_record.algorithm_name
317 fake_tsig.fudge = tsig_record.fudge
319 fake_tsig.other_size = 0
320 fake_tsig_packet = ndr.ndr_pack(fake_tsig)
322 data = request_mac + response_packet_wo_tsig + fake_tsig_packet
323 self.g.check_packet(data, data, mac)
325 def sign_packet(self, packet, key_name):
326 "Sign a packet, calculate a MAC and add TSIG record"
327 packet_data = ndr.ndr_pack(packet)
329 fake_tsig = dns.fake_tsig_rec()
330 fake_tsig.name = key_name
331 fake_tsig.rr_class = dns.DNS_QCLASS_ANY
333 fake_tsig.time_prefix = 0
334 fake_tsig.time = int(time.time())
335 fake_tsig.algorithm_name = "gss-tsig"
336 fake_tsig.fudge = 300
338 fake_tsig.other_size = 0
339 fake_tsig_packet = ndr.ndr_pack(fake_tsig)
341 data = packet_data + fake_tsig_packet
342 mac = self.g.sign_packet(data, data)
343 mac_list = [x if isinstance(x, int) else ord(x) for x in list(mac)]
345 rdata = dns.tsig_record()
346 rdata.algorithm_name = "gss-tsig"
347 rdata.time_prefix = 0
348 rdata.time = fake_tsig.time
350 rdata.original_id = packet.id
354 rdata.mac_size = len(mac_list)
358 r.rr_type = dns.DNS_QTYPE_TSIG
359 r.rr_class = dns.DNS_QCLASS_ANY
365 packet.additional = additional
370 def bad_sign_packet(self, packet, key_name):
371 '''Add bad signature for a packet by bitflipping
372 the final byte in the MAC'''
374 mac_list = [x if isinstance(x, int) else ord(x) for x in list("badmac")]
376 rdata = dns.tsig_record()
377 rdata.algorithm_name = "gss-tsig"
378 rdata.time_prefix = 0
379 rdata.time = int(time.time())
381 rdata.original_id = packet.id
385 rdata.mac_size = len(mac_list)
389 r.rr_type = dns.DNS_QTYPE_TSIG
390 r.rr_class = dns.DNS_QCLASS_ANY
396 packet.additional = additional
399 def search_record(self, name):
400 p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
403 q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
406 self.finish_name_packet(p, questions)
407 (response, response_packet) =\
408 self.dns_transaction_udp(p, self.server_ip)
409 return response.operation & 0x000F
411 def make_update_request(self, delete=False):
412 "Create a DNS update request"
414 rr_class = dns.DNS_QCLASS_IN
418 rr_class = dns.DNS_QCLASS_NONE
421 p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
422 q = self.make_name_question(self.get_dns_domain(),
427 self.finish_name_packet(p, questions)
431 r.name = self.newrecname
432 r.rr_type = dns.DNS_QTYPE_TXT
433 r.rr_class = rr_class
436 rdata = self.make_txt_record(['"This is a test"'])
439 p.nscount = len(updates)