3 # Unix SMB/CIFS implementation.
4 # Copyright (C) Volker Lendecke 2017
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 # GNU General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 # Used by selftest to proxy DNS queries to the correct testenv DC.
20 # See selftest/target/README for more details.
21 # Based on the EchoServer example from python docs
29 from samba.dcerpc import dns
30 import samba.ndr as ndr
# Python 2 calls the module SocketServer, Python 3 calls it socketserver;
# the rest of this file only refers to it through the alias `sserver`.
if sys.version_info.major >= 3:
    import socketserver
    sserver = socketserver
else:
    import SocketServer
    sserver = SocketServer

# Timeout in seconds for a forwarded DNS request; a fifth of it is also
# the threshold above which a handled request gets logged.
DNS_REQUEST_TIMEOUT = 10

# Put SIGINT back to its default action so that Ctrl-C terminates the
# process immediately instead of raising KeyboardInterrupt.
import signal
signal.signal(signal.SIGINT, signal.SIG_DFL)
class DnsHandler(sserver.BaseRequestHandler):
    """Proxy a single UDP DNS query to the testenv DC that owns its realm.

    The owning server must carry a ``realm_to_ip_mappings`` attribute
    (realm suffix -> PDC address); it is read through ``self.server``
    whenever a query has to be forwarded.
    """

    # Reverse lookup tables (numeric value -> constant name) built from the
    # DNS_QTYPE_* / DNS_RCODE_* constants of samba.dcerpc.dns.  They are
    # only used to make log lines readable.
    dns_qtype_strings = {v: k for k, v in vars(dns).items()
                         if k.startswith('DNS_QTYPE_')}
    dns_rcode_strings = {v: k for k, v in vars(dns).items()
                         if k.startswith('DNS_RCODE_')}

    def dns_qtype_string(self, qtype):
        """Return a readable qtype code.

        Values without a named constant are rendered numerically instead
        of raising KeyError: this is called while logging in handle(),
        before the reply is sent, so a KeyError would lose the reply.
        """
        return self.dns_qtype_strings.get(qtype, "DNS_QTYPE_UNKNOWN(%s)" % qtype)

    def dns_rcode_string(self, rcode):
        """Return a readable error code (numeric fallback for unknown ones)."""
        return self.dns_rcode_strings.get(rcode, "DNS_RCODE_UNKNOWN(%s)" % rcode)

    def dns_transaction_udp(self, packet, host):
        """Send a DNS query to host:53 over UDP and return the parsed reply.

        host must be a numeric IP address (AI_NUMERICHOST).  Socket errors,
        including a timeout after DNS_REQUEST_TIMEOUT seconds, are logged
        and re-raised to the caller.
        """
        flags = socket.AddressInfo.AI_NUMERICHOST
        flags |= socket.AddressInfo.AI_NUMERICSERV
        flags |= socket.AddressInfo.AI_PASSIVE
        addr_info = socket.getaddrinfo(host, 53,
                                       type=socket.SocketKind.SOCK_DGRAM,
                                       flags=flags)
        assert len(addr_info) == 1
        family, socktype, _, _, sockaddr = addr_info[0]

        send_packet = ndr.ndr_pack(packet)
        try:
            # The with-statement closes the socket on success and on error.
            with socket.socket(family, socktype, 0) as s:
                s.settimeout(DNS_REQUEST_TIMEOUT)
                s.connect(sockaddr)
                s.sendall(send_packet, 0)
                recv_packet = s.recv(2048, 0)
        except socket.error as err:
            # Log the exception itself: err.errno is None for timeouts.
            print("Error sending to host %s for name %s: %s\n" %
                  (host, packet.questions[0].name, err))
            raise
        return ndr.ndr_unpack(dns.name_packet, recv_packet)

    def get_pdc_ipv4_addr(self, lookup_name):
        """Maps a DNS realm to the IPv4 address of the PDC for that testenv

        Returns None when no configured realm is a suffix of lookup_name.
        """
        realm_to_ip_mappings = self.server.realm_to_ip_mappings

        # Try the longest realms first so the most specific match wins.
        for realm in sorted(realm_to_ip_mappings, key=len, reverse=True):
            if lookup_name.endswith(realm):
                # return the corresponding IP address for this realm's PDC
                return realm_to_ip_mappings[realm]
        return None

    def forwarder(self, name):
        """Classify a query name and decide how handle() treats it.

        Returns:
          'ignore'  - send no answer at all,
          'fail'    - answer SERVFAIL,
          'torture' - answer NXDOMAIN locally,
          None      - no realm matched; also answered NXDOMAIN,
          otherwise the address from realm_to_ip_mappings to forward to.
        """
        lname = name.lower()

        # check for special cases used by tests (e.g. dns_forwarder.py)
        if lname.endswith('an-address-that-will-not-resolve'):
            return 'ignore'
        if lname.endswith('dsfsdfs'):
            return 'fail'
        # lname minus its last two characters ends with "torture1":
        # catches TORTURE100, TORTURE101, ...
        if lname.endswith("torture1", 0, len(lname) - 2):
            return 'torture'
        if lname.endswith(('_none_.example.com',
                           'torturedom.samba.example.com')):
            return 'torture'

        # return the testenv PDC matching the realm being requested
        return self.get_pdc_ipv4_addr(lname)

    @staticmethod
    def _reply_with_rcode(query, rcode):
        """Turn query (in place) into a locally generated reply with rcode."""
        query.operation |= dns.DNS_FLAG_REPLY
        query.operation |= dns.DNS_FLAG_RECURSION_AVAIL
        query.operation |= rcode
        return query

    def handle(self):
        """Answer one request received by the UDP server.

        Depending on forwarder(), the query is dropped, answered locally
        (NXDOMAIN / SERVFAIL) or relayed to a testenv PDC.  Any failure to
        obtain an upstream answer is reported to the client as SERVFAIL;
        slow requests are logged.
        """
        start = time.monotonic()
        data, sock = self.request
        query = ndr.ndr_unpack(dns.name_packet, data)
        name = query.questions[0].name
        forwarder = self.forwarder(name)
        response = None

        if forwarder == 'ignore':
            # Deliberately send nothing back.
            return
        elif forwarder == 'fail':
            # Leave response unset: a SERVFAIL is generated below.
            pass
        elif forwarder in ['torture', None]:
            response = self._reply_with_rcode(query, dns.DNS_RCODE_NXDOMAIN)
        else:
            try:
                response = self.dns_transaction_udp(query, forwarder)
            except OSError as err:
                print("dns_hub: Error sending dns query to forwarder[%s] for name[%s]: %s" %
                      (forwarder, name, err))

        if response is None:
            response = self._reply_with_rcode(query, dns.DNS_RCODE_SERVFAIL)

        send_packet = ndr.ndr_pack(response)

        end = time.monotonic()
        tdiff = end - start
        errcode = response.operation & dns.DNS_RCODE
        # Only log requests that used more than a fifth of the timeout.
        if tdiff > (DNS_REQUEST_TIMEOUT / 5):
            print("dns_hub: forwarder[%s] client[%s] name[%s][%s] %s response.operation[0x%x] tdiff[%s]\n" %
                  (forwarder, self.client_address, name,
                   self.dns_qtype_string(query.questions[0].question_type),
                   self.dns_rcode_string(errcode), response.operation, tdiff))

        try:
            sock.sendto(send_packet, self.client_address)
        except socket.error as err:
            print("dns_hub: Error sending response to client[%s] for name[%s] tdiff[%s]: %s\n" %
                  (self.client_address, name, tdiff, err))
class server_thread(threading.Thread):
    """Thread that runs one socketserver server's serve_forever() loop.

    run() blocks inside serve_forever().  stop() must be called from a
    different thread: it asks the loop to exit, waits for it, and then
    closes the server's socket.
    """

    def __init__(self, server, name):
        threading.Thread.__init__(self, name=name)
        self.server = server

    def run(self):
        print("dns_hub[%s]: before serve_forever()" % self.name)
        self.server.serve_forever()
        print("dns_hub[%s]: after serve_forever()" % self.name)

    def stop(self):
        print("dns_hub[%s]: before shutdown()" % self.name)
        # shutdown() waits until serve_forever() has returned.
        self.server.shutdown()
        print("dns_hub[%s]: after shutdown()" % self.name)
        # Release the bound socket now that nothing is serving on it.
        self.server.server_close()
class UDPV4Server(sserver.UDPServer):
    """UDPServer variant whose listening socket is created as AF_INET."""
    address_family = socket.AddressFamily.AF_INET
class UDPV6Server(sserver.UDPServer):
    """UDPServer variant whose listening socket is created as AF_INET6."""
    address_family = socket.AddressFamily.AF_INET6
if len(sys.argv) < 4:
    print("Usage: dns_hub.py TIMEOUT LISTENADDRESS[,LISTENADDRESS,...] MAPPING[,MAPPING,...]")
    # Stop here: the sys.argv[1..3] accesses below would raise IndexError.
    sys.exit(1)

# TIMEOUT is given in seconds; scale it to milliseconds.
timeout = int(sys.argv[1]) * 1000
timeout = min(timeout, 2**31 - 1)  # poll with 32-bit int can't take more
# We pass in the listen addresses as a comma-separated string.  Empty
# entries (e.g. from a trailing comma) are skipped rather than being
# handed to getaddrinfo() as an empty host.
listenaddresses = [addr for addr in sys.argv[2].split(',') if addr]
# We pass in the realm-to-IP mappings as a comma-separated key=value
# string.  Convert this back into a dictionary that the DnsHandler can
# use.  Split on the first '=' only and skip empty entries.
realm_mappings = collections.OrderedDict(
    kv.split('=', 1) for kv in sys.argv[3].split(',') if kv)
def prepare_server_thread(listenaddress, realm_mappings):
    """Create (but do not start) a DNS proxy thread for one listen address.

    listenaddress must be a numeric IPv4 or IPv6 address; the server is
    bound to UDP port 53 on it.  realm_mappings is attached to the server
    so each DnsHandler can read it as self.server.realm_to_ip_mappings.

    Returns a server_thread named "UDP[<listenaddress>]".
    Raises ValueError if the address does not resolve to exactly one
    socket address.
    """
    flags = socket.AddressInfo.AI_NUMERICHOST
    flags |= socket.AddressInfo.AI_NUMERICSERV
    flags |= socket.AddressInfo.AI_PASSIVE
    addr_info = socket.getaddrinfo(listenaddress, 53,
                                   type=socket.SocketKind.SOCK_DGRAM,
                                   flags=flags)
    # Explicit check instead of assert, so it is not stripped by python -O.
    if len(addr_info) != 1:
        raise ValueError("listen address %r resolved to %d addresses, expected 1"
                         % (listenaddress, len(addr_info)))
    family, _, _, _, sockaddr = addr_info[0]
    if family == socket.AddressFamily.AF_INET6:
        server = UDPV6Server(sockaddr, DnsHandler)
    else:
        server = UDPV4Server(sockaddr, DnsHandler)

    # Make the realm -> PDC address table reachable from the handlers,
    # which look it up through self.server.
    server.realm_to_ip_mappings = realm_mappings
    return server_thread(server, name="UDP[%s]" % listenaddress)
# Log the realm -> PDC address table parsed from the command line.
226 print("dns_hub will proxy DNS requests for the following realms:")
227 for realm, ip in realm_mappings.items():
228 print(" {0} ==> {1}".format(realm, ip))
# Build one proxy server thread per listen address via prepare_server_thread().
230 print("dns_hub will listen on the following UDP addresses:")
232 for listenaddress in listenaddresses:
233 print(" %s" % listenaddress)
234 t = prepare_server_thread(listenaddress, realm_mappings)
# NOTE(review): the lines that collect and start the threads and create the
# poll object `p` are not visible in this view.
# Register stdin for POLLIN; presumably the hub keeps serving until stdin
# becomes readable/closed or `timeout` (milliseconds) expires -- confirm
# against the poll() call, which is not visible here.
240 stdin = sys.stdin.fileno()
241 p.register(stdin, select.POLLIN)
243 print("dns_hub: after poll()")
# NOTE(review): thread shutdown/join and the process exit are not visible here.
248 print("dns_hub: before exit()")