pyldb: avoid segfault when adding an element with no name
[kai/samba-autobuild/.git] / selftest / target / dns_hub.py
index d2d1f39e752aee68e1cb221b18b1c475e833ef5c..49fbeff7b9921dd0e98f7a62fa6fce7045459b24 100755 (executable)
@@ -1,4 +1,4 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
 #
 # Unix SMB/CIFS implementation.
 # Copyright (C) Volker Lendecke 2017
@@ -24,6 +24,7 @@ import threading
 import sys
 import select
 import socket
+import time
 from samba.dcerpc import dns
 import samba.ndr as ndr
 
@@ -34,15 +35,27 @@ else:
     import socketserver
     sserver = socketserver
 
+DNS_REQUEST_TIMEOUT = 10
+
 
 class DnsHandler(sserver.BaseRequestHandler):
+    dns_qtype_strings = dict((v, k) for k, v in vars(dns).items() if k.startswith('DNS_QTYPE_'))
+    def dns_qtype_string(self, qtype):
+        "Return a readable qtype code"
+        return self.dns_qtype_strings[qtype]
+
+    dns_rcode_strings = dict((v, k) for k, v in vars(dns).items() if k.startswith('DNS_RCODE_'))
+    def dns_rcode_string(self, rcode):
+        "Return a readable error code"
+        return self.dns_rcode_strings[rcode]
+
     def dns_transaction_udp(self, packet, host):
         "send a DNS query and read the reply"
         s = None
         try:
             send_packet = ndr.ndr_pack(packet)
             s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
-            s.settimeout(5)
+            s.settimeout(DNS_REQUEST_TIMEOUT)
             s.connect((host, 53))
             s.sendall(send_packet, 0)
             recv_packet = s.recv(2048, 0)
@@ -56,40 +69,43 @@ class DnsHandler(sserver.BaseRequestHandler):
                 s.close()
         return None
 
+    def get_pdc_ipv4_addr(self, lookup_name):
+        """Maps a DNS realm to the IPv4 address of the PDC for that testenv"""
+
+        realm_to_ip_mappings = self.server.realm_to_ip_mappings
+
+        # sort the realms so we find the longest-match first
+        testenv_realms = sorted(realm_to_ip_mappings.keys(), key=len)
+        testenv_realms.reverse()
+
+        for realm in testenv_realms:
+            if lookup_name.endswith(realm):
+                # return the corresponding IP address for this realm's PDC
+                return realm_to_ip_mappings[realm]
+
+        return None
+
     def forwarder(self, name):
         lname = name.lower()
 
+        # check for special cases used by tests (e.g. dns_forwarder.py)
         if lname.endswith('an-address-that-will-not-resolve'):
             return 'ignore'
         if lname.endswith('dsfsdfs'):
             return 'fail'
-        if lname.endswith('adnonssdom.samba.example.com'):
-            return '127.0.0.17'
-        if lname.endswith('adnontlmdom.samba.example.com'):
-            return '127.0.0.18'
-        if lname.endswith('samba2000.example.com'):
-            return '127.0.0.25'
-        if lname.endswith('samba2003.example.com'):
-            return '127.0.0.26'
-        if lname.endswith('samba2008r2.example.com'):
-            return '127.0.0.27'
-        if lname.endswith('addom.samba.example.com'):
-            return '127.0.0.30'
-        if lname.endswith('sub.samba.example.com'):
-            return '127.0.0.31'
-        if lname.endswith('chgdcpassword.samba.example.com'):
-            return '127.0.0.32'
-        if lname.endswith('backupdom.samba.example.com'):
-            return '127.0.0.40'
-        if lname.endswith('renamedom.samba.example.com'):
-            return '127.0.0.42'
-        if lname.endswith('labdom.samba.example.com'):
-            return '127.0.0.43'
-        if lname.endswith('samba.example.com'):
-            return '127.0.0.21'
-        return None
+        if lname.endswith("torture1", 0, len(lname)-2):
+            # CATCH TORTURE100, TORTURE101, ...
+            return 'torture'
+        if lname.endswith('_none_.example.com'):
+            return 'torture'
+        if lname.endswith('torturedom.samba.example.com'):
+            return 'torture'
+
+        # return the testenv PDC matching the realm being requested
+        return self.get_pdc_ipv4_addr(lname)
 
     def handle(self):
+        start = time.monotonic()
         data, sock = self.request
         query = ndr.ndr_unpack(dns.name_packet, data)
         name = query.questions[0].name
@@ -100,13 +116,13 @@ class DnsHandler(sserver.BaseRequestHandler):
             return
         elif forwarder is 'fail':
             pass
-        elif forwarder is not None:
-            response = self.dns_transaction_udp(query, forwarder)
-        else:
+        elif forwarder in ['torture', None]:
             response = query
             response.operation |= dns.DNS_FLAG_REPLY
             response.operation |= dns.DNS_FLAG_RECURSION_AVAIL
             response.operation |= dns.DNS_RCODE_NXDOMAIN
+        else:
+            response = self.dns_transaction_udp(query, forwarder)
 
         if response is None:
             response = query
@@ -116,14 +132,24 @@ class DnsHandler(sserver.BaseRequestHandler):
 
         send_packet = ndr.ndr_pack(response)
 
-        print("dns_hub: sending %s to address %s for name %s\n" %
-            (forwarder, self.client_address, name))
+        end = time.monotonic()
+        tdiff = end - start
+        errcode = response.operation & dns.DNS_RCODE
+        if tdiff > (DNS_REQUEST_TIMEOUT/5):
+            debug = True
+        else:
+            debug = False
+        if debug:
+            print("dns_hub: forwarder[%s] client[%s] name[%s][%s] %s response.operation[0x%x] tdiff[%s]\n" %
+                (forwarder, self.client_address, name,
+                 self.dns_qtype_string(query.questions[0].question_type),
+                 self.dns_rcode_string(errcode), response.operation, tdiff))
 
         try:
             sock.sendto(send_packet, self.client_address)
         except socket.error as err:
-            print("Error sending %s to address %s for name %s: %s\n" %
-                (forwarder, self.client_address, name, err))
+            print("dns_hub: Error sending response to client[%s] for name[%s] tdiff[%s]: %s\n" %
+                (self.client_address, name, tdiff, err))
 
 
 class server_thread(threading.Thread):
@@ -140,7 +166,18 @@ def main():
     timeout = int(sys.argv[1]) * 1000
     timeout = min(timeout, 2**31 - 1)  # poll with 32-bit int can't take more
     host = sys.argv[2]
+
     server = sserver.UDPServer((host, int(53)), DnsHandler)
+
+    # we pass in the realm-to-IP mappings as a comma-separated key=value
+    # string. Convert this back into a dictionary that the DnsHandler can use
+    realm_mapping = dict(kv.split('=') for kv in sys.argv[3].split(','))
+    server.realm_to_ip_mappings = realm_mapping
+
+    print("dns_hub will proxy DNS requests for the following realms:")
+    for realm, ip in server.realm_to_ip_mappings.items():
+        print("  {0} ==> {1}".format(realm, ip))
+
     t = server_thread(server)
     t.start()
     p = select.poll()