dns_hub: Add some debug as to what DNS proxying is happening
[kai/samba-autobuild/.git] / selftest / target / dns_hub.py
1 #!/usr/bin/env python3
2 #
3 # Unix SMB/CIFS implementation.
4 # Copyright (C) Volker Lendecke 2017
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19 # Used by selftest to proxy DNS queries to the correct testenv DC.
20 # See selftest/target/README for more details.
21 # Based on the EchoServer example from python docs
22
23 import threading
24 import sys
25 import select
26 import socket
27 import time
28 from samba.dcerpc import dns
29 import samba.ndr as ndr
30
# Python 3 only: DnsHandler.handle() relies on time.monotonic(), so the old
# Python 2 SocketServer fallback could never be exercised.  Both names stay
# bound so existing references keep working.
import socketserver
sserver = socketserver

# Seconds to wait for an upstream testenv DC to answer a forwarded UDP query
# (passed to socket.settimeout()).  Queries whose total handling time exceeds
# a fifth of this value are logged by DnsHandler.handle().
DNS_REQUEST_TIMEOUT = 10
39
40
class DnsHandler(sserver.BaseRequestHandler):
    """Answer or proxy a single UDP DNS query for the selftest DNS hub.

    A few names used by the test suites are answered locally (no answer,
    SERVFAIL or NXDOMAIN).  Every other name is forwarded over UDP to the PDC
    address found in ``self.server.realm_to_ip_mappings`` for the longest
    matching realm, or answered NXDOMAIN if no realm matches.
    """

    # Reverse lookup tables (constant value -> constant name) built from the
    # constants exported by samba.dcerpc.dns; only used for debug output.
    dns_qtype_strings = {v: k for k, v in vars(dns).items()
                         if k.startswith('DNS_QTYPE_')}
    dns_rcode_strings = {v: k for k, v in vars(dns).items()
                         if k.startswith('DNS_RCODE_')}

    def dns_qtype_string(self, qtype):
        """Return the DNS_QTYPE_* constant name for ``qtype``.

        Unknown values yield 'DNS_QTYPE_UNKNOWN(<value>)' instead of raising,
        because this is called on the reply path and must not stop the
        response from being sent.
        """
        return self.dns_qtype_strings.get(qtype,
                                          'DNS_QTYPE_UNKNOWN(%s)' % (qtype,))

    def dns_rcode_string(self, rcode):
        """Return the DNS_RCODE_* constant name for ``rcode``.

        Unknown values yield 'DNS_RCODE_UNKNOWN(<value>)' instead of raising,
        for the same reason as dns_qtype_string().
        """
        return self.dns_rcode_strings.get(rcode,
                                          'DNS_RCODE_UNKNOWN(%s)' % (rcode,))

    def dns_transaction_udp(self, packet, host):
        """Send ``packet`` to ``host`` on UDP port 53 and return the reply.

        :param packet: the dns.name_packet to forward.
        :param host: IPv4 address of the DNS server to query.
        :return: the reply, unpacked as a dns.name_packet.
        :raises socket.error: on any socket failure, including no reply
            within DNS_REQUEST_TIMEOUT seconds.  The error is printed before
            being re-raised.
        """
        try:
            send_packet = ndr.ndr_pack(packet)
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0) as s:
                s.settimeout(DNS_REQUEST_TIMEOUT)
                s.connect((host, 53))
                s.sendall(send_packet, 0)
                recv_packet = s.recv(2048, 0)
        except socket.error as err:
            # Log the whole exception: err.errno is None for socket.timeout,
            # which hid the actual cause.
            print("Error sending to host %s for name %s: %s\n" %
                  (host, packet.questions[0].name, err))
            raise
        return ndr.ndr_unpack(dns.name_packet, recv_packet)

    def get_pdc_ipv4_addr(self, lookup_name):
        """Map a DNS name to the IPv4 address of the PDC for its testenv realm.

        Realms are tried longest first, so the most specific realm wins when
        one testenv realm is a suffix of another.  The comparison ignores
        case, as DNS names are case-insensitive.

        :return: the mapped address, or None when no realm matches.
        """
        realm_to_ip_mappings = self.server.realm_to_ip_mappings
        name = lookup_name.lower()

        for realm in sorted(realm_to_ip_mappings, key=len, reverse=True):
            if name.endswith(realm.lower()):
                return realm_to_ip_mappings[realm]

        return None

    def forwarder(self, name):
        """Decide how the query for ``name`` is answered.

        :return: 'ignore' (send no reply), 'fail' (reply SERVFAIL),
            'torture' (reply NXDOMAIN), the IPv4 address of the testenv PDC
            to forward to, or None when no realm matches (reply NXDOMAIN).
        """
        lname = name.lower()

        # special cases used by tests (e.g. dns_forwarder.py)
        if lname.endswith('an-address-that-will-not-resolve'):
            return 'ignore'
        if lname.endswith('dsfsdfs'):
            return 'fail'
        # Checks lname minus its last two characters, so this catches
        # TORTURE100, TORTURE101, ...
        if lname.endswith("torture1", 0, len(lname) - 2):
            return 'torture'
        if lname.endswith(('_none_.example.com',
                           'torturedom.samba.example.com')):
            return 'torture'

        # return the testenv PDC matching the realm being requested
        return self.get_pdc_ipv4_addr(lname)

    def _error_reply(self, query, rcode):
        """Turn ``query`` in place into a reply carrying ``rcode``; return it."""
        query.operation |= dns.DNS_FLAG_REPLY
        query.operation |= dns.DNS_FLAG_RECURSION_AVAIL
        query.operation |= rcode
        return query

    def handle(self):
        """Handle one datagram: answer or forward it, then reply to the client."""
        start = time.monotonic()
        # For a UDP server the request is the (datagram, socket) pair.
        data, sock = self.request
        query = ndr.ndr_unpack(dns.name_packet, data)
        name = query.questions[0].name
        forwarder = self.forwarder(name)

        # Compare with == (not 'is'): string identity depends on interning.
        if forwarder == 'ignore':
            return
        if forwarder == 'fail':
            response = self._error_reply(query, dns.DNS_RCODE_SERVFAIL)
        elif forwarder in ('torture', None):
            response = self._error_reply(query, dns.DNS_RCODE_NXDOMAIN)
        else:
            response = self.dns_transaction_udp(query, forwarder)

        send_packet = ndr.ndr_pack(response)

        tdiff = time.monotonic() - start
        # Only report queries that took a noticeable fraction of the upstream
        # timeout, so slow proxying shows up without flooding the log.
        if tdiff > DNS_REQUEST_TIMEOUT / 5:
            errcode = response.operation & dns.DNS_RCODE
            qtype = query.questions[0].question_type
            print("dns_hub: forwarder[%s] client[%s] name[%s][%s] %s response.operation[0x%x] tdiff[%s]\n" %
                  (forwarder, self.client_address, name,
                   self.dns_qtype_string(qtype),
                   self.dns_rcode_string(errcode), response.operation, tdiff))

        try:
            sock.sendto(send_packet, self.client_address)
        except socket.error as err:
            print("dns_hub: Error sending response to client[%s] for name[%s] tdiff[%s]: %s\n" %
                  (self.client_address, name, tdiff, err))
153
154
class server_thread(threading.Thread):
    """Thread whose body is the given server's serve_forever() loop."""

    def __init__(self, server):
        super().__init__()
        # The server whose serve_forever() is run by run().
        self.server = server

    def run(self):
        """Serve requests until serve_forever() returns, then say so."""
        srv = self.server
        srv.serve_forever()
        print("dns_hub: after serve_forever()")
163
164
def main():
    """Run the DNS hub until stdin reports an event or the timeout expires.

    Command line: <timeout-seconds> <listen-address> <realm=ip[,realm=ip...]>
    """
    # poll() takes milliseconds; with a 32-bit int it can't take more
    timeout = min(int(sys.argv[1]) * 1000, 2**31 - 1)
    host = sys.argv[2]

    server = sserver.UDPServer((host, 53), DnsHandler)

    # we pass in the realm-to-IP mappings as a comma-separated key=value
    # string. Convert this back into a dictionary that the DnsHandler can use.
    # Empty items (e.g. from a trailing comma) are skipped.
    server.realm_to_ip_mappings = dict(
        kv.split('=', 1) for kv in sys.argv[3].split(',') if kv)

    print("dns_hub will proxy DNS requests for the following realms:")
    for realm, ip in server.realm_to_ip_mappings.items():
        print("  {0} ==> {1}".format(realm, ip))

    t = server_thread(server)
    t.start()

    # Serve until stdin has a poll event (e.g. data, or the parent closing
    # it) or the timeout passes.
    p = select.poll()
    p.register(sys.stdin.fileno(), select.POLLIN)
    p.poll(timeout)
    print("dns_hub: after poll()")

    server.shutdown()
    t.join()
    # Release the listening UDP socket now that no thread is serving on it.
    server.server_close()
    print("dns_hub: before exit()")
    sys.exit(0)


if __name__ == "__main__":
    main()