libndr: Avoid assigning duplicate versions to symbols
[amitay/samba.git] / selftest / target / dns_hub.py
1 #!/usr/bin/env python3
2 #
3 # Unix SMB/CIFS implementation.
4 # Copyright (C) Volker Lendecke 2017
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19 # Used by selftest to proxy DNS queries to the correct testenv DC.
20 # See selftest/target/README for more details.
21 # Based on the EchoServer example from python docs
22
23 import threading
24 import sys
25 import select
26 import socket
27 import collections
28 import time
29 from samba.dcerpc import dns
30 import samba.ndr as ndr
31
# This script only runs on Python 3 (it relies on socket.AddressInfo,
# socket.SocketKind and time.monotonic), so the former Python 2
# SocketServer fallback was dead code.  The 'sserver' alias is kept
# because the server/handler classes below refer to it.
import socketserver
sserver = socketserver

# Timeout (seconds) for a forwarded DNS request; handle() also logs any
# request that took longer than a fifth of this.
DNS_REQUEST_TIMEOUT = 10

# Make sure the script dies immediately when hitting control-C, rather
# than raising KeyboardInterrupt.  The hub keeps no on-disk state (it only
# answers or forwards packets), so terminating abruptly is safe.
import signal
signal.signal(signal.SIGINT, signal.SIG_DFL)
46
class DnsHandler(sserver.BaseRequestHandler):
    """Answer one UDP DNS query, either locally or by proxying it.

    socketserver creates one handler instance per request; for UDP servers
    self.request is the (data, socket) pair.  The realm -> IP table used for
    forwarding is read from self.server.realm_to_ip_mappings, which main()
    attaches to each server it creates.
    """

    # Reverse lookup tables (numeric value -> constant name) built from the
    # DNS_QTYPE_* and DNS_RCODE_* constants of samba.dcerpc.dns.  They are
    # only used to make the slow-request log line readable.
    dns_qtype_strings = {v: k for k, v in vars(dns).items()
                         if k.startswith('DNS_QTYPE_')}
    dns_rcode_strings = {v: k for k, v in vars(dns).items()
                         if k.startswith('DNS_RCODE_')}

    def dns_qtype_string(self, qtype):
        """Return a readable name for a DNS query type code.

        Unknown codes are rendered as UNKNOWN_QTYPE_<n> instead of raising
        KeyError.  handle() calls this before sending the reply, so an
        exception here would drop the reply for any query type that
        samba.dcerpc.dns has no constant for.
        """
        return self.dns_qtype_strings.get(qtype, "UNKNOWN_QTYPE_%s" % qtype)

    def dns_rcode_string(self, rcode):
        """Return a readable name for a DNS response code.

        Unknown codes are rendered as UNKNOWN_RCODE_<n> instead of raising
        KeyError, for the same reason as dns_qtype_string().  A forwarded
        reply may carry any 4-bit rcode.
        """
        return self.dns_rcode_strings.get(rcode, "UNKNOWN_RCODE_%s" % rcode)

    def dns_transaction_udp(self, packet, host):
        """Send a DNS query to host port 53 over UDP and return the reply.

        :param packet: the dns.name_packet to send.
        :param host: numeric IP address of the server (AI_NUMERICHOST means
            no name resolution is attempted).
        :return: the reply, unpacked as a dns.name_packet.
        :raises OSError: on any socket error, including the
            DNS_REQUEST_TIMEOUT receive timeout.  The error is printed
            before it is re-raised.
        """
        flags = (socket.AddressInfo.AI_NUMERICHOST |
                 socket.AddressInfo.AI_NUMERICSERV |
                 socket.AddressInfo.AI_PASSIVE)
        addr_info = socket.getaddrinfo(host, 53,
                                       type=socket.SocketKind.SOCK_DGRAM,
                                       flags=flags)
        assert len(addr_info) == 1
        family, sock_type, _, _, sockaddr = addr_info[0]
        try:
            send_packet = ndr.ndr_pack(packet)
            # The with-statement closes the socket on every exit path.
            with socket.socket(family, sock_type, 0) as s:
                s.settimeout(DNS_REQUEST_TIMEOUT)
                s.connect(sockaddr)
                s.sendall(send_packet, 0)
                recv_packet = s.recv(2048, 0)
            return ndr.ndr_unpack(dns.name_packet, recv_packet)
        except socket.error as err:
            print("Error sending to host %s for name %s: %s\n" %
                  (host, packet.questions[0].name, err.errno))
            raise

    def get_pdc_ipv4_addr(self, lookup_name):
        """Map a DNS name to the PDC address of the matching testenv realm.

        The longest realm that is a suffix of lookup_name wins, so a
        sub-domain realm takes precedence over its parent.  The comparison
        is case-insensitive on both sides.  Returns None if no realm matches.

        NOTE(review): this is a plain string-suffix test and is not aligned
        to DNS label boundaries.
        """
        realm_to_ip_mappings = self.server.realm_to_ip_mappings
        lookup_name = lookup_name.lower()

        # Longest realm first.  Two distinct realms of equal length can never
        # both be suffixes of the same name, so tie order does not matter.
        for realm in sorted(realm_to_ip_mappings, key=len, reverse=True):
            if lookup_name.endswith(realm.lower()):
                return realm_to_ip_mappings[realm]

        return None

    def forwarder(self, name):
        """Decide how handle() should treat a query for name.

        Returns one of:
          'ignore'  - handle() sends no reply at all;
          'fail'    - handle() replies SERVFAIL;
          'torture' - handle() replies NXDOMAIN;
          otherwise the result of get_pdc_ipv4_addr(): the address to
          forward to, or None (handle() replies NXDOMAIN).
        """
        lname = name.lower()

        # Special cases used by tests (e.g. dns_forwarder.py).
        if lname.endswith('an-address-that-will-not-resolve'):
            return 'ignore'
        if lname.endswith('dsfsdfs'):
            return 'fail'
        # "torture1" followed by any two characters: TORTURE100,
        # TORTURE101, ...
        if lname.endswith("torture1", 0, len(lname) - 2):
            return 'torture'
        if lname.endswith(('_none_.example.com',
                           'torturedom.samba.example.com')):
            return 'torture'

        # Return the testenv PDC matching the realm being requested.
        return self.get_pdc_ipv4_addr(lname)

    @staticmethod
    def _make_error_reply(query, rcode):
        """Turn query into a reply with the given rcode (mutates query)."""
        query.operation |= dns.DNS_FLAG_REPLY
        query.operation |= dns.DNS_FLAG_RECURSION_AVAIL
        query.operation |= rcode
        return query

    def handle(self):
        """Process one received datagram and send the reply (if any)."""
        start = time.monotonic()
        data, sock = self.request
        query = ndr.ndr_unpack(dns.name_packet, data)
        name = query.questions[0].name
        forwarder = self.forwarder(name)
        response = None

        if forwarder == 'ignore':
            # Deliberately send no reply at all.
            return
        elif forwarder == 'fail':
            pass  # falls through to the SERVFAIL reply below
        elif forwarder in ('torture', None):
            response = self._make_error_reply(query, dns.DNS_RCODE_NXDOMAIN)
        else:
            try:
                response = self.dns_transaction_udp(query, forwarder)
            except OSError as err:
                print("dns_hub: Error sending dns query to forwarder[%s] for name[%s]: %s" %
                      (forwarder, name, err))

        if response is None:
            response = self._make_error_reply(query, dns.DNS_RCODE_SERVFAIL)

        send_packet = ndr.ndr_pack(response)

        tdiff = time.monotonic() - start
        # Only log requests that took more than a fifth of the timeout.
        # The name lookups below cannot raise, so logging never prevents
        # the reply from being sent.
        if tdiff > (DNS_REQUEST_TIMEOUT / 5):
            errcode = response.operation & dns.DNS_RCODE
            print("dns_hub: forwarder[%s] client[%s] name[%s][%s] %s response.operation[0x%x] tdiff[%s]\n" %
                  (forwarder, self.client_address, name,
                   self.dns_qtype_string(query.questions[0].question_type),
                   self.dns_rcode_string(errcode), response.operation, tdiff))

        try:
            sock.sendto(send_packet, self.client_address)
        except socket.error as err:
            print("dns_hub: Error sending response to client[%s] for name[%s] tdiff[%s]: %s\n" %
                  (self.client_address, name, tdiff, err))
170
171
class server_thread(threading.Thread):
    """Thread that runs one socketserver's serve_forever() loop.

    stop() calls server.shutdown(), which (per the socketserver docs) blocks
    until serve_forever() has returned.  It must therefore be called from a
    thread other than this one.
    """

    def __init__(self, server, name):
        super().__init__(name=name)
        self.server = server

    def _announce(self, what):
        # Every progress line is prefixed with this thread's name.
        print("dns_hub[%s]: %s" % (self.name, what))

    def run(self):
        self._announce("before serve_forever()")
        self.server.serve_forever()
        self._announce("after serve_forever()")

    def stop(self):
        self._announce("before shutdown()")
        self.server.shutdown()
        self._announce("after shutdown()")
186
class UDPV4Server(sserver.UDPServer):
    """UDPServer whose listening socket is created in the IPv4 family."""
    address_family = socket.AddressFamily.AF_INET


class UDPV6Server(sserver.UDPServer):
    """UDPServer whose listening socket is created in the IPv6 family."""
    address_family = socket.AddressFamily.AF_INET6
192
def main():
    """Run the DNS hub until stdin becomes readable or TIMEOUT expires.

    Usage: dns_hub.py TIMEOUT LISTENADDRESS[,LISTENADDRESS,...] MAPPING[,MAPPING,...]

    TIMEOUT is in seconds.  Each LISTENADDRESS must be a numeric IP address;
    a UDP DNS server is bound to port 53 on each one.  Each MAPPING is a
    REALM=IP pair telling DnsHandler where to forward queries for REALM.
    Exits with status 1 on bad usage, otherwise with status 0.
    """
    if len(sys.argv) < 4:
        print("Usage: dns_hub.py TIMEOUT LISTENADDRESS[,LISTENADDRESS,...] MAPPING[,MAPPING,...]")
        sys.exit(1)

    # poll() takes milliseconds; clamp because poll with a 32-bit int
    # can't take more.
    timeout = min(int(sys.argv[1]) * 1000, 2**31 - 1)
    # The listen addresses are passed as a comma-separated string.
    listenaddresses = sys.argv[2].split(',')
    # The realm-to-IP mappings are passed as comma-separated REALM=IP pairs.
    # Split only on the first '=' so that an '=' inside the value cannot
    # produce a 3-element item and break the dict construction.
    realm_mappings = collections.OrderedDict(
        kv.split('=', 1) for kv in sys.argv[3].split(','))

    def prepare_server_thread(listenaddress, realm_mappings):
        """Bind a UDP DNS server on listenaddress:53 and wrap it in a
        server_thread that has not been started yet."""
        flags = (socket.AddressInfo.AI_NUMERICHOST |
                 socket.AddressInfo.AI_NUMERICSERV |
                 socket.AddressInfo.AI_PASSIVE)
        addr_info = socket.getaddrinfo(listenaddress, 53,
                                       type=socket.SocketKind.SOCK_DGRAM,
                                       flags=flags)
        assert len(addr_info) == 1
        family, sockaddr = addr_info[0][0], addr_info[0][4]
        if family == socket.AddressFamily.AF_INET6:
            server = UDPV6Server(sockaddr, DnsHandler)
        else:
            server = UDPV4Server(sockaddr, DnsHandler)

        # DnsHandler looks the forwarding table up via self.server.
        server.realm_to_ip_mappings = realm_mappings
        return server_thread(server, name="UDP[%s]" % listenaddress)

    print("dns_hub will proxy DNS requests for the following realms:")
    for realm, ip in realm_mappings.items():
        print("  {0} ==> {1}".format(realm, ip))

    print("dns_hub will listen on the following UDP addresses:")
    threads = []
    for listenaddress in listenaddresses:
        print("  %s" % listenaddress)
        threads.append(prepare_server_thread(listenaddress, realm_mappings))

    for t in threads:
        t.start()

    # Serve until stdin becomes readable or the timeout expires.
    # NOTE(review): presumably stdin becomes readable (EOF) when the parent
    # selftest process goes away -- confirm against the caller.
    p = select.poll()
    p.register(sys.stdin.fileno(), select.POLLIN)
    p.poll(timeout)
    print("dns_hub: after poll()")
    for t in threads:
        t.stop()
    for t in threads:
        t.join()
        # serve_forever() has returned, so the listening socket can be
        # released.
        t.server.server_close()
    print("dns_hub: before exit()")
    sys.exit(0)
250
# Only start the hub when executed as a script (selftest runs it directly),
# not as a side effect of importing the module.
if __name__ == "__main__":
    main()