1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Kai Blin <kai@samba.org> 2011
3 # Copyright (C) Ralph Boehme <slow@samba.org> 2016
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 from samba.tests import TestCaseInTempDir
20 from samba.dcerpc import dns, dnsp
21 from samba import gensec, tests
22 from samba import credentials
24 import samba.ndr as ndr
# Base class for DNS server tests: helpers that build dns.name_packet
# requests, send them to port 53 over UDP or TCP, and check RCODE/OPCODE.
31 class DNSTest(TestCaseInTempDir):
# setUp: chain to TestCaseInTempDir.setUp().
34 super(DNSTest, self).setUp()
# Map a numeric RCODE to a readable mnemonic by indexing string_codes
# (the table itself is on lines not shown); an unknown code makes that
# lookup raise.
37 def errstr(self, errcode):
38 "Return a readable error code"
60 return string_codes[errcode]
# Compare two raw RCODE values; on failure both are shown via errstr().
62 def assert_rcode_equals(self, rcode, expected):
63 "Helper function to check return code"
64 self.assertEqual(rcode, expected, "Expected RCODE %s, got %s" %
65 (self.errstr(expected), self.errstr(rcode)))
# Extract the RCODE bits from packet.operation (mask dns.DNS_RCODE) and
# compare them with the expected rcode.
67 def assert_dns_rcode_equals(self, packet, rcode):
68 "Helper function to check return code"
69 p_errcode = packet.operation & dns.DNS_RCODE
70 self.assertEqual(p_errcode, rcode, "Expected RCODE %s, got %s" %
71 (self.errstr(rcode), self.errstr(p_errcode)))
# Same check for the OPCODE bits (mask dns.DNS_OPCODE).
# NOTE(review): the format-argument line of this assertEqual is not shown.
73 def assert_dns_opcode_equals(self, packet, opcode):
74 "Helper function to check opcode"
75 p_opcode = packet.operation & dns.DNS_OPCODE
76 self.assertEqual(p_opcode, opcode, "Expected OPCODE %s, got %s" %
# Create a dns.name_packet for `opcode`. The visible line draws a random
# message ID with random.randint (both bounds inclusive); presumably this
# only happens when qid is None (guard line not shown) — confirm.
# NOTE(review): the DNS header ID field is 16 bits (0..0xffff); the 0xff00
# upper bound leaves the top of that range unused — confirm it is intended.
79 def make_name_packet(self, opcode, qid=None):
80 "Helper creating a dns.name_packet"
83 p.id = random.randint(0x0, 0xff00)
# Attach the question list and keep qdcount equal to its length.
89 def finish_name_packet(self, packet, questions):
90 "Helper to finalize a dns.name_packet"
91 packet.qdcount = len(questions)
92 packet.questions = questions
# Build a dns.name_question with the given query type and class.
94 def make_name_question(self, name, qtype, qclass):
95 "Helper creating a dns.name_question"
96 q = dns.name_question()
98 q.question_type = qtype
99 q.question_class = qclass
# Wrap a list of strings in a dns.txt_record backed by a dnsp.string_list
# whose count is len(records).
102 def make_txt_record(self, records):
103 rdata_txt = dns.txt_record()
104 s_list = dnsp.string_list()
105 s_list.count = len(records)
107 rdata_txt.txt = s_list
# The DNS domain is the lower-cased Kerberos realm of self.creds, so the
# instance must have self.creds set (DNSTKeyTest.setUp does this).
110 def get_dns_domain(self):
111 "Helper to get dns domain"
112 return self.creds.get_realm().lower()
# Send `packet` to host:53 as one UDP datagram and parse the reply.
# Returns (parsed dns.name_packet, raw reply bytes). `dump` is meant to
# print hexdumps of request and reply (its guard lines are not shown).
# The reply is read with a single recv() of at most 2048 bytes.
114 def dns_transaction_udp(self, packet, host,
115 dump=False, timeout=None):
116 "send a DNS query and read the reply"
# Presumably applied only when no timeout was passed (guard not shown).
119 timeout = self.timeout
121 send_packet = ndr.ndr_pack(packet)
123 print(self.hexdump(send_packet))
# NOTE(review): no s.close() is visible; confirm the socket is closed
# (e.g. in a finally block on lines not shown).
124 s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
125 s.settimeout(timeout)
126 s.connect((host, 53))
127 s.sendall(send_packet, 0)
128 recv_packet = s.recv(2048, 0)
130 print(self.hexdump(recv_packet))
131 response = ndr.ndr_unpack(dns.name_packet, recv_packet)
132 return (response, recv_packet)
# Send `packet` over TCP to host:53. DNS over TCP prefixes each message
# with a 2-byte big-endian length, hence struct.pack('!H', ...). Returns
# (parsed dns.name_packet, raw reply bytes without the length prefix) and
# asserts that re-packing the parsed reply reproduces the received bytes.
137 def dns_transaction_tcp(self, packet, host,
138 dump=False, timeout=None):
139 "send a DNS query and read the reply, also return the raw packet"
# Presumably applied only when no timeout was passed (guard not shown).
142 timeout = self.timeout
144 send_packet = ndr.ndr_pack(packet)
146 print(self.hexdump(send_packet))
147 s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
148 s.settimeout(timeout)
149 s.connect((host, 53))
150 tcp_packet = struct.pack('!H', len(send_packet))
151 tcp_packet += send_packet
152 s.sendall(tcp_packet)
# NOTE(review): TCP is a byte stream, so one recv() may return only part
# of the reply; looping until the 2-byte length prefix is satisfied would
# be more robust. Also confirm the socket is closed on lines not shown.
154 recv_packet = s.recv(0xffff + 2, 0)
156 print(self.hexdump(recv_packet))
# Skip the 2-byte length prefix before unpacking.
157 response = ndr.ndr_unpack(dns.name_packet, recv_packet[2:])
163 # unpacking and packing again should produce same bytestream
164 my_packet = ndr.ndr_pack(response)
165 self.assertEqual(my_packet, recv_packet[2:])
166 return (response, recv_packet[2:])
# Build a DNS UPDATE packet adding a TXT record "<prefix>.<zone>" holding
# txt_array (class IN in the update section means "add" in DNS UPDATE).
# The zone section is an SOA entry for `zone`, defaulting to
# get_dns_domain(), passed through finish_name_packet(). Where `ttl` is
# applied is on lines not shown.
168 def make_txt_update(self, prefix, txt_array, zone=None, ttl=900):
169 p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
172 name = zone or self.get_dns_domain()
173 u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
175 self.finish_name_packet(p, updates)
179 r.name = "%s.%s" % (prefix, name)
180 r.rr_type = dns.DNS_QTYPE_TXT
181 r.rr_class = dns.DNS_QCLASS_IN
184 rdata = self.make_txt_record(txt_array)
187 p.nscount = len(updates)
# Query the TXT record "<prefix>.<zone>" over UDP at self.server_ip and
# assert RCODE OK, exactly one answer, and TXT strings equal to txt_array.
192 def check_query_txt(self, prefix, txt_array, zone=None):
193 name = "%s.%s" % (prefix, zone or self.get_dns_domain())
194 p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
197 q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
200 self.finish_name_packet(p, questions)
201 (response, response_packet) =\
202 self.dns_transaction_udp(p, host=self.server_ip)
203 self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
204 self.assertEqual(response.ancount, 1)
205 self.assertEqual(response.answers[0].rdata.txt.str, txt_array)
# DNS tests that authenticate with GSS-TSIG: tkey_trans() negotiates a
# gensec (SPNEGO) context through a TKEY exchange, sign_packet() uses that
# context to attach TSIG records, and bad_sign_packet() attaches a bogus one.
208 class DNSTKeyTest(DNSTest):
# setUp: chain to DNSTest.setUp(), fill the settings dict later passed to
# gensec.Security.start_client(), and build Kerberos-only credentials from
# the USERNAME/PASSWORD test environment variables.
# NOTE(review): self.settings and self.server are assigned on lines not
# shown — confirm.
210 super(DNSTKeyTest, self).setUp()
212 self.settings["lp_ctx"] = self.lp_ctx = tests.env_loadparm()
213 self.settings["target_hostname"] = self.server
215 self.creds = credentials.Credentials()
216 self.creds.guess(self.lp_ctx)
217 self.creds.set_username(tests.env_get_var_value('USERNAME'))
218 self.creds.set_password(tests.env_get_var_value('PASSWORD'))
219 self.creds.set_kerberos_state(credentials.MUST_USE_KERBEROS)
# Fully qualified name of the TXT record used by make_update_request().
220 self.newrecname = "tkeytsig.%s" % self.get_dns_domain()
# Negotiate a GSS-TSIG security context with the server and keep it in
# self.g. The negotiation is a TKEY query sent over TCP under a fresh
# random key name (self.key_name). `creds` are the credentials handed to
# gensec; how the None default is resolved is on lines not shown
# (presumably self.creds — confirm).
222 def tkey_trans(self, creds=None):
223 "Do a TKEY transaction and establish a gensec context"
# Key name: a random UUID label under the DNS domain.
228 self.key_name = "%s.%s" % (uuid.uuid4(), self.get_dns_domain())
230 p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
231 q = self.make_name_question(self.key_name,
236 self.finish_name_packet(p, questions)
239 r.name = self.key_name
240 r.rr_type = dns.DNS_QTYPE_TKEY
241 r.rr_class = dns.DNS_QCLASS_IN
244 rdata = dns.tkey_record()
245 rdata.algorithm = "gss-tsig"
# The key is requested for one hour (60 * 60 s) starting now.
246 rdata.inception = int(time.time())
247 rdata.expiration = int(time.time()) + 60 * 60
248 rdata.mode = dns.DNS_TKEY_MODE_GSSAPI
# Client-side SPNEGO context with signing enabled, targeting the "dns"
# service on self.server.
252 self.g = gensec.Security.start_client(self.settings)
253 self.g.set_credentials(creds)
254 self.g.set_target_service("dns")
255 self.g.set_target_hostname(self.server)
256 self.g.want_feature(gensec.FEATURE_SIGN)
257 self.g.start_mech_by_name("spnego")
# First gensec step: empty input produces the client's initial token.
# NOTE(review): despite its name, `server_to_client` from this call holds
# that client token; it is sent to the server as the TKEY key_data below.
260 client_to_server = b""
262 (finished, server_to_client) = self.g.update(client_to_server)
263 self.assertFalse(finished)
# key_data is a list of ints; the comprehension accepts both Python 2 str
# and Python 3 bytes tokens.
265 data = [x if isinstance(x, int) else ord(x) for x in list(server_to_client)]
266 rdata.key_data = data
267 rdata.key_size = len(data)
272 p.additional = additional
274 (response, response_packet) =\
275 self.dns_transaction_tcp(p, self.server_ip)
276 self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
# Feed the server's reply token back in; the context must now be complete.
# NOTE(review): on Python 2, bytes is str, so bytes(list) yields the list's
# repr rather than the raw token — this line only works on Python 3.
278 tkey_record = response.answers[0].rdata
279 server_to_client = bytes(tkey_record.key_data)
280 (finished, client_to_server) = self.g.update(server_to_client)
281 self.assertTrue(finished)
# The TKEY response must itself carry a valid TSIG signature.
283 self.verify_packet(response, response_packet)
# Verify the TSIG MAC on a server reply using the gensec context self.g.
# - The first additional RR must be a TSIG record.
# - Its bytes are cut from the raw reply, and the TSIG variables are
#   rebuilt as a fake_tsig_rec.
# - The MAC is checked over request_mac + reply-without-TSIG + TSIG
#   variables, the digest order RFC 2845 uses for responses.
# response_packet is the raw reply without the TCP length prefix, e.g. as
# returned by dns_transaction_tcp(); request_mac defaults to empty.
285 def verify_packet(self, response, response_packet, request_mac=b""):
286 self.assertEqual(response.additional[0].rr_type, dns.DNS_QTYPE_TSIG)
288 tsig_record = response.additional[0].rdata
289 mac = bytes(tsig_record.mac)
291 # Cut off tsig record from dns response packet for MAC verification
292 # and reset additional record count.
# Wire length of the uncompressed key name: one length byte per label plus
# the terminating zero byte, i.e. len(name) + 2. This assumes the server
# did not compress the TSIG owner name — confirm.
293 key_name_len = len(self.key_name) + 2
# + 10 for the fixed RR fields TYPE(2) + CLASS(2) + TTL(4) + RDLENGTH(2).
294 tsig_record_len = len(ndr.ndr_pack(tsig_record)) + key_name_len + 10
296 # convert str/bytes to a list (of string char or int)
297 # so it can be modified
298 response_packet_list = [x if isinstance(x, int) else ord(x) for x in response_packet]
299 del response_packet_list[-tsig_record_len:]
# Byte 11 is the low byte of ARCOUNT in the 12-byte DNS header; zeroing it
# assumes the TSIG record was the only additional RR.
300 response_packet_list[11] = 0
302 # convert modified list (of string char or int) to str/bytes
# NOTE(review): on Python 2, bytes is str, so bytes(list) yields the list's
# repr instead of the packet bytes; this conversion only works on Python 3.
303 response_packet_wo_tsig = bytes(response_packet_list)
# TSIG variables, copied from the received TSIG record.
305 fake_tsig = dns.fake_tsig_rec()
306 fake_tsig.name = self.key_name
307 fake_tsig.rr_class = dns.DNS_QCLASS_ANY
309 fake_tsig.time_prefix = tsig_record.time_prefix
310 fake_tsig.time = tsig_record.time
311 fake_tsig.algorithm_name = tsig_record.algorithm_name
312 fake_tsig.fudge = tsig_record.fudge
314 fake_tsig.other_size = 0
315 fake_tsig_packet = ndr.ndr_pack(fake_tsig)
317 data = request_mac + response_packet_wo_tsig + fake_tsig_packet
318 self.g.check_packet(data, data, mac)
# Sign `packet` with the GSS-TSIG context self.g created by tkey_trans().
# The MAC covers the packed packet followed by the TSIG variables. A TSIG
# record carrying that MAC is then attached through packet.additional.
320 def sign_packet(self, packet, key_name):
321 "Sign a packet, calculate a MAC and add TSIG record"
322 packet_data = ndr.ndr_pack(packet)
# TSIG variables used as MAC input.
324 fake_tsig = dns.fake_tsig_rec()
325 fake_tsig.name = key_name
326 fake_tsig.rr_class = dns.DNS_QCLASS_ANY
# Signing time is split into time_prefix (high bits) and time (low 32 bits).
328 fake_tsig.time_prefix = 0
329 fake_tsig.time = int(time.time())
330 fake_tsig.algorithm_name = "gss-tsig"
# fudge: permitted clock skew, in seconds.
331 fake_tsig.fudge = 300
333 fake_tsig.other_size = 0
334 fake_tsig_packet = ndr.ndr_pack(fake_tsig)
336 data = packet_data + fake_tsig_packet
337 mac = self.g.sign_packet(data, data)
# Normalise the MAC to a list of ints (Python 2 str or Python 3 bytes).
338 mac_list = [x if isinstance(x, int) else ord(x) for x in list(mac)]
# The TSIG record must repeat the signing time used in the MAC input.
340 rdata = dns.tsig_record()
341 rdata.algorithm_name = "gss-tsig"
342 rdata.time_prefix = 0
343 rdata.time = fake_tsig.time
345 rdata.original_id = packet.id
349 rdata.mac_size = len(mac_list)
353 r.rr_type = dns.DNS_QTYPE_TSIG
354 r.rr_class = dns.DNS_QCLASS_ANY
360 packet.additional = additional
# Counterpart of sign_packet() that attaches a TSIG record with a fixed
# bogus MAC; presumably used to check that the server rejects it.
365 def bad_sign_packet(self, packet, key_name):
366 """Add a TSIG record carrying a deliberately invalid MAC to a packet.
367 The MAC is the literal bytes of "badmac"; it is not computed from the packet."""
# Literal "badmac" converted to a list of ints.
369 mac_list = [x if isinstance(x, int) else ord(x) for x in list("badmac")]
371 rdata = dns.tsig_record()
372 rdata.algorithm_name = "gss-tsig"
373 rdata.time_prefix = 0
374 rdata.time = int(time.time())
376 rdata.original_id = packet.id
380 rdata.mac_size = len(mac_list)
384 r.rr_type = dns.DNS_QTYPE_TSIG
385 r.rr_class = dns.DNS_QCLASS_ANY
391 packet.additional = additional
# Send a TXT query for `name` over UDP to self.server_ip and return the
# reply's RCODE (the low 4 bits of the header flags word).
# NOTE(review): other helpers mask with dns.DNS_RCODE; 0x000F is presumably
# the same value — confirm and consider using the named constant.
394 def search_record(self, name):
395 p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
398 q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
401 self.finish_name_packet(p, questions)
402 (response, response_packet) =\
403 self.dns_transaction_udp(p, self.server_ip)
404 return response.operation & 0x000F
# Build a DNS UPDATE packet for the TXT record self.newrecname.
# - rr_class IN in the update section adds the RR.
# - rr_class NONE deletes that specific RR (RFC 2136 section 2.5.4). It is
#   presumably selected when delete=True; the branch lines are not shown.
406 def make_update_request(self, delete=False):
407 "Create a DNS update request"
409 rr_class = dns.DNS_QCLASS_IN
413 rr_class = dns.DNS_QCLASS_NONE
# Zone section: the DNS domain, passed through the question list.
416 p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
417 q = self.make_name_question(self.get_dns_domain(),
422 self.finish_name_packet(p, questions)
426 r.name = self.newrecname
427 r.rr_type = dns.DNS_QTYPE_TXT
428 r.rr_class = rr_class
# The TXT string contains literal double-quote characters as part of its value.
431 rdata = self.make_txt_record(['"This is a test"'])
434 p.nscount = len(updates)