dnspython: Update to latest upstream snapshot.
[samba.git] / lib / dnspython / dns / resolver.py
index 30977f3a8bbb98164be057ce1952d426620317ce..90f95e8ed0504f42bf86f66c55e488149ef4c8ff 100644 (file)
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2007, 2009, 2010 Nominum, Inc.
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
 #
 # Permission to use, copy, modify, and distribute this software and its
 # documentation for any purpose with or without fee is hereby granted,
@@ -23,12 +23,15 @@ import sys
 import time
 
 import dns.exception
+import dns.ipv4
+import dns.ipv6
 import dns.message
 import dns.name
 import dns.query
 import dns.rcode
 import dns.rdataclass
 import dns.rdatatype
+import dns.reversename
 
 if sys.platform == 'win32':
     import _winreg
@@ -93,8 +96,11 @@ class Answer(object):
     @type rrset: dns.rrset.RRset object
     @ivar expiration: The time when the answer expires
     @type expiration: float (seconds since the epoch)
+    @ivar canonical_name: The canonical name of the query name
+    @type canonical_name: dns.name.Name object
     """
-    def __init__(self, qname, rdtype, rdclass, response):
+    def __init__(self, qname, rdtype, rdclass, response,
+                 raise_on_no_answer=True):
         self.qname = qname
         self.rdtype = rdtype
         self.rdclass = rdclass
@@ -122,11 +128,31 @@ class Answer(object):
                             break
                         continue
                     except KeyError:
-                        raise NoAnswer
-                raise NoAnswer
-        if rrset is None:
+                        if raise_on_no_answer:
+                            raise NoAnswer
+                if raise_on_no_answer:
+                    raise NoAnswer
+        if rrset is None and raise_on_no_answer:
             raise NoAnswer
+        self.canonical_name = qname
         self.rrset = rrset
+        if rrset is None:
+            while 1:
+                # Look for a SOA RR whose owner name is a superdomain
+                # of qname.
+                try:
+                    srrset = response.find_rrset(response.authority, qname,
+                                                rdclass, dns.rdatatype.SOA)
+                    if min_ttl == -1 or srrset.ttl < min_ttl:
+                        min_ttl = srrset.ttl
+                    if srrset[0].minimum < min_ttl:
+                        min_ttl = srrset[0].minimum
+                    break
+                except KeyError:
+                    try:
+                        qname = qname.parent()
+                    except dns.name.NoParent:
+                        break
         self.expiration = time.time() + min_ttl
 
     def __getattr__(self, attr):
@@ -244,6 +270,127 @@ class Cache(object):
             self.data = {}
             self.next_cleaning = time.time() + self.cleaning_interval
 
+class LRUCacheNode(object):
+    """LRUCache node.
+    """
+    def __init__(self, key, value):
+        self.key = key
+        self.value = value
+        self.prev = self
+        self.next = self
+
+    def link_before(self, node):
+        self.prev = node.prev
+        self.next = node
+        node.prev.next = self
+        node.prev = self
+
+    def link_after(self, node):
+        self.prev = node
+        self.next = node.next
+        node.next.prev = self
+        node.next = self
+
+    def unlink(self):
+        self.next.prev = self.prev
+        self.prev.next = self.next
+
+class LRUCache(object):
+    """Bounded least-recently-used DNS answer cache.
+
+    This cache is better than the simple cache (above) if you're
+    running a web crawler or other process that does a lot of
+    resolutions.  The LRUCache has a maximum number of nodes, and when
+    it is full, the least-recently used node is removed to make space
+    for a new one.
+
+    @ivar data: A dictionary of cached data
+    @type data: dict
+    @ivar sentinel: sentinel node for circular doubly linked list of nodes
+    @type sentinel: LRUCacheNode object
+    @ivar max_size: The maximum number of nodes
+    @type max_size: int
+    """
+
+    def __init__(self, max_size=100000):
+        """Initialize a DNS cache.
+
+        @param max_size: The maximum number of nodes to cache; the default is
+        100000.  Must be > 1.
+        @type max_size: int
+        """
+        self.data = {}
+        self.set_max_size(max_size)
+        self.sentinel = LRUCacheNode(None, None)
+
+    def set_max_size(self, max_size):
+        if max_size < 1:
+            max_size = 1
+        self.max_size = max_size
+
+    def get(self, key):
+        """Get the answer associated with I{key}.  Returns None if
+        no answer is cached for the key.
+        @param key: the key
+        @type key: (dns.name.Name, int, int) tuple whose values are the
+        query name, rdtype, and rdclass.
+        @rtype: dns.resolver.Answer object or None
+        """
+        node = self.data.get(key)
+        if node is None:
+            return None
+        # Unlink because we're either going to move the node to the front
+        # of the LRU list or we're going to free it.
+        node.unlink()
+        if node.value.expiration <= time.time():
+            del self.data[node.key]
+            return None
+        node.link_after(self.sentinel)
+        return node.value
+
+    def put(self, key, value):
+        """Associate key and value in the cache.
+        @param key: the key
+        @type key: (dns.name.Name, int, int) tuple whose values are the
+        query name, rdtype, and rdclass.
+        @param value: The answer being cached
+        @type value: dns.resolver.Answer object
+        """
+        node = self.data.get(key)
+        if not node is None:
+            node.unlink()
+            del self.data[node.key]
+        while len(self.data) >= self.max_size:
+            node = self.sentinel.prev
+            node.unlink()
+            del self.data[node.key]
+        node = LRUCacheNode(key, value)
+        node.link_after(self.sentinel)
+        self.data[key] = node
+
+    def flush(self, key=None):
+        """Flush the cache.
+
+        If I{key} is specified, only that item is flushed.  Otherwise
+        the entire cache is flushed.
+
+        @param key: the key to flush
+        @type key: (dns.name.Name, int, int) tuple or None
+        """
+        if not key is None:
+            node = self.data.get(key)
+            if not node is None:
+                node.unlink()
+                del self.data[node.key]
+        else:
+            node = self.sentinel.next
+            while node != self.sentinel:
+                next = node.next
+                node.prev = None
+                node.next = None
+                node = next
+            self.data = {}
+
 class Resolver(object):
     """DNS stub resolver
 
@@ -546,7 +693,7 @@ class Resolver(object):
         return min(self.lifetime - duration, self.timeout)
 
     def query(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
-              tcp=False, source=None):
+              tcp=False, source=None, raise_on_no_answer=True):
         """Query nameservers to find the answer to the question.
 
         The I{qname}, I{rdtype}, and I{rdclass} parameters may be objects
@@ -564,10 +711,14 @@ class Resolver(object):
         @type tcp: bool
         @param source: bind to this IP address (defaults to machine default IP).
         @type source: IP address in dotted quad notation
+        @param raise_on_no_answer: raise NoAnswer if there's no answer
+        (defaults is True).
+        @type raise_on_no_answer: bool
         @rtype: dns.resolver.Answer instance
         @raises Timeout: no answers could be found in the specified lifetime
         @raises NXDOMAIN: the query name does not exist
-        @raises NoAnswer: the response did not contain an answer
+        @raises NoAnswer: the response did not contain an answer and
+        raise_on_no_answer is True.
         @raises NoNameservers: no non-broken nameservers are available to
         answer the question."""
 
@@ -597,8 +748,11 @@ class Resolver(object):
         for qname in qnames_to_try:
             if self.cache:
                 answer = self.cache.get((qname, rdtype, rdclass))
-                if answer:
-                    return answer
+                if not answer is None:
+                    if answer.rrset is None and raise_on_no_answer:
+                        raise NoAnswer
+                    else:
+                        return answer
             request = dns.message.make_query(qname, rdtype, rdclass)
             if not self.keyname is None:
                 request.use_tsig(self.keyring, self.keyname,
@@ -678,7 +832,8 @@ class Resolver(object):
             break
         if all_nxdomain:
             raise NXDOMAIN
-        answer = Answer(qname, rdtype, rdclass, response)
+        answer = Answer(qname, rdtype, rdclass, response,
+                        raise_on_no_answer)
         if self.cache:
             self.cache.put((qname, rdtype, rdclass), answer)
         return answer
@@ -731,14 +886,15 @@ def get_default_resolver():
     return default_resolver
 
 def query(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
-          tcp=False, source=None):
+          tcp=False, source=None, raise_on_no_answer=True):
     """Query nameservers to find the answer to the question.
 
     This is a convenience function that uses the default resolver
     object to make the query.
     @see: L{dns.resolver.Resolver.query} for more information on the
     parameters."""
-    return get_default_resolver().query(qname, rdtype, rdclass, tcp, source)
+    return get_default_resolver().query(qname, rdtype, rdclass, tcp, source,
+                                        raise_on_no_answer)
 
 def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None):
     """Find the name of the zone which contains the specified name.
@@ -771,3 +927,235 @@ def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None):
             name = name.parent()
         except dns.name.NoParent:
             raise NoRootSOA
+
+#
+# Support for overriding the system resolver for all python code in the
+# running process.
+#
+
+_protocols_for_socktype = {
+    socket.SOCK_DGRAM : [socket.SOL_UDP],
+    socket.SOCK_STREAM : [socket.SOL_TCP],
+    }
+
+_resolver = None
+_original_getaddrinfo = socket.getaddrinfo
+_original_getnameinfo = socket.getnameinfo
+_original_getfqdn = socket.getfqdn
+_original_gethostbyname = socket.gethostbyname
+_original_gethostbyname_ex = socket.gethostbyname_ex
+_original_gethostbyaddr = socket.gethostbyaddr
+
+def _getaddrinfo(host=None, service=None, family=socket.AF_UNSPEC, socktype=0,
+                 proto=0, flags=0):
+    if flags & (socket.AI_ADDRCONFIG|socket.AI_V4MAPPED) != 0:
+        raise NotImplementedError
+    if host is None and service is None:
+        raise socket.gaierror(socket.EAI_NONAME)
+    v6addrs = []
+    v4addrs = []
+    canonical_name = None
+    try:
+        # Is host None or a V6 address literal?
+        if host is None:
+            canonical_name = 'localhost'
+            if flags & socket.AI_PASSIVE != 0:
+                v6addrs.append('::')
+                v4addrs.append('0.0.0.0')
+            else:
+                v6addrs.append('::1')
+                v4addrs.append('127.0.0.1')
+        else:
+            parts = host.split('%')
+            if len(parts) == 2:
+                ahost = parts[0]
+            else:
+                ahost = host
+            addr = dns.ipv6.inet_aton(ahost)
+            v6addrs.append(host)
+            canonical_name = host
+    except:
+        try:
+            # Is it a V4 address literal?
+            addr = dns.ipv4.inet_aton(host)
+            v4addrs.append(host)
+            canonical_name = host
+        except:
+            if flags & socket.AI_NUMERICHOST == 0:
+                try:
+                    qname = None
+                    if family == socket.AF_INET6 or family == socket.AF_UNSPEC:
+                        v6 = _resolver.query(host, dns.rdatatype.AAAA,
+                                             raise_on_no_answer=False)
+                        # Note that setting host ensures we query the same name
+                        # for A as we did for AAAA.
+                        host = v6.qname
+                        canonical_name = v6.canonical_name.to_text(True)
+                        if v6.rrset is not None:
+                            for rdata in v6.rrset:
+                                v6addrs.append(rdata.address)
+                    if family == socket.AF_INET or family == socket.AF_UNSPEC:
+                        v4 = _resolver.query(host, dns.rdatatype.A,
+                                             raise_on_no_answer=False)
+                        host = v4.qname
+                        canonical_name = v4.canonical_name.to_text(True)
+                        if v4.rrset is not None:
+                            for rdata in v4.rrset:
+                                v4addrs.append(rdata.address)
+                except dns.resolver.NXDOMAIN:
+                    raise socket.gaierror(socket.EAI_NONAME)
+                except:
+                    raise socket.gaierror(socket.EAI_SYSTEM)
+    port = None
+    try:
+        # Is it a port literal?
+        if service is None:
+            port = 0
+        else:
+            port = int(service)
+    except:
+        if flags & socket.AI_NUMERICSERV == 0:
+            try:
+                port = socket.getservbyname(service)
+            except:
+                pass
+    if port is None:
+        raise socket.gaierror(socket.EAI_NONAME)
+    tuples = []
+    if socktype == 0:
+        socktypes = [socket.SOCK_DGRAM, socket.SOCK_STREAM]
+    else:
+        socktypes = [socktype]
+    if flags & socket.AI_CANONNAME != 0:
+        cname = canonical_name
+    else:
+        cname = ''
+    if family == socket.AF_INET6 or family == socket.AF_UNSPEC:
+        for addr in v6addrs:
+            for socktype in socktypes:
+                for proto in _protocols_for_socktype[socktype]:
+                    tuples.append((socket.AF_INET6, socktype, proto,
+                                   cname, (addr, port, 0, 0)))
+    if family == socket.AF_INET or family == socket.AF_UNSPEC:
+        for addr in v4addrs:
+            for socktype in socktypes:
+                for proto in _protocols_for_socktype[socktype]:
+                    tuples.append((socket.AF_INET, socktype, proto,
+                                   cname, (addr, port)))
+    if len(tuples) == 0:
+        raise socket.gaierror(socket.EAI_NONAME)
+    return tuples
+
+def _getnameinfo(sockaddr, flags=0):
+    host = sockaddr[0]
+    port = sockaddr[1]
+    if len(sockaddr) == 4:
+        scope = sockaddr[3]
+        family = socket.AF_INET6
+    else:
+        scope = None
+        family = socket.AF_INET
+    tuples = _getaddrinfo(host, port, family, socket.SOCK_STREAM,
+                          socket.SOL_TCP, 0)
+    if len(tuples) > 1:
+        raise socket.error('sockaddr resolved to multiple addresses')
+    addr = tuples[0][4][0]
+    if flags & socket.NI_DGRAM:
+        pname = 'udp'
+    else:
+        pname = 'tcp'
+    qname = dns.reversename.from_address(addr)
+    if flags & socket.NI_NUMERICHOST == 0:
+        try:
+            answer = _resolver.query(qname, 'PTR')
+            hostname = answer.rrset[0].target.to_text(True)
+        except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
+            if flags & socket.NI_NAMEREQD:
+                raise socket.gaierror(socket.EAI_NONAME)
+            hostname = addr
+            if scope is not None:
+                hostname += '%' + str(scope)
+    else:
+        hostname = addr
+        if scope is not None:
+            hostname += '%' + str(scope)
+    if flags & socket.NI_NUMERICSERV:
+        service = str(port)
+    else:
+        service = socket.getservbyport(port, pname)
+    return (hostname, service)
+
+def _getfqdn(name=None):
+    if name is None:
+        name = socket.gethostname()
+    return _getnameinfo(_getaddrinfo(name, 80)[0][4])[0]
+
+def _gethostbyname(name):
+    return _gethostbyname_ex(name)[2][0]
+
+def _gethostbyname_ex(name):
+    aliases = []
+    addresses = []
+    tuples = _getaddrinfo(name, 0, socket.AF_INET, socket.SOCK_STREAM,
+                         socket.SOL_TCP, socket.AI_CANONNAME)
+    canonical = tuples[0][3]
+    for item in tuples:
+        addresses.append(item[4][0])
+    # XXX we just ignore aliases
+    return (canonical, aliases, addresses)
+
+def _gethostbyaddr(ip):
+    try:
+        addr = dns.ipv6.inet_aton(ip)
+        sockaddr = (ip, 80, 0, 0)
+        family = socket.AF_INET6
+    except:
+        sockaddr = (ip, 80)
+        family = socket.AF_INET
+    (name, port) = _getnameinfo(sockaddr, socket.NI_NAMEREQD)
+    aliases = []
+    addresses = []
+    tuples = _getaddrinfo(name, 0, family, socket.SOCK_STREAM, socket.SOL_TCP,
+                          socket.AI_CANONNAME)
+    canonical = tuples[0][3]
+    for item in tuples:
+        addresses.append(item[4][0])
+    # XXX we just ignore aliases
+    return (canonical, aliases, addresses)
+
+def override_system_resolver(resolver=None):
+    """Override the system resolver routines in the socket module with
+    versions which use dnspython's resolver.
+
+    This can be useful in testing situations where you want to control
+    the resolution behavior of python code without having to change
+    the system's resolver settings (e.g. /etc/resolv.conf).
+
+    The resolver to use may be specified; if it's not, the default
+    resolver will be used.
+
+    @param resolver: the resolver to use
+    @type resolver: dns.resolver.Resolver object or None
+    """
+    if resolver is None:
+        resolver = get_default_resolver()
+    global _resolver
+    _resolver = resolver
+    socket.getaddrinfo = _getaddrinfo
+    socket.getnameinfo = _getnameinfo
+    socket.getfqdn = _getfqdn
+    socket.gethostbyname = _gethostbyname
+    socket.gethostbyname_ex = _gethostbyname_ex
+    socket.gethostbyaddr = _gethostbyaddr
+
+def restore_system_resolver():
+    """Undo the effects of override_system_resolver().
+    """
+    global _resolver
+    _resolver = None
+    socket.getaddrinfo = _original_getaddrinfo
+    socket.getnameinfo = _original_getnameinfo
+    socket.getfqdn = _original_getfqdn
+    socket.gethostbyname = _original_gethostbyname
+    socket.gethostbyname_ex = _original_gethostbyname_ex
+    socket.gethostbyaddr = _original_gethostbyaddr