Add samba.ensure_third_party_module() function, loading external python modules from...
[bbaumbach/samba-autobuild/.git] / lib / dnspython / dns / query.py
1 # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
2 #
3 # Permission to use, copy, modify, and distribute this software and its
4 # documentation for any purpose with or without fee is hereby granted,
5 # provided that the above copyright notice and this permission notice
6 # appear in all copies.
7 #
8 # THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
9 # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10 # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
11 # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12 # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13 # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
14 # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15
16 """Talk to a DNS server."""
17
18 from __future__ import generators
19
20 import errno
21 import select
22 import socket
23 import struct
24 import sys
25 import time
26
27 import dns.exception
28 import dns.inet
29 import dns.name
30 import dns.message
31 import dns.rdataclass
32 import dns.rdatatype
33
34 class UnexpectedSource(dns.exception.DNSException):
35     """Raised if a query response comes from an unexpected address or port."""
36     pass
37
38 class BadResponse(dns.exception.FormError):
39     """Raised if a query response does not respond to the question asked."""
40     pass
41
42 def _compute_expiration(timeout):
43     if timeout is None:
44         return None
45     else:
46         return time.time() + timeout
47
48 def _poll_for(fd, readable, writable, error, timeout):
49     """
50     @param fd: File descriptor (int).
51     @param readable: Whether to wait for readability (bool).
52     @param writable: Whether to wait for writability (bool).
53     @param expiration: Deadline timeout (expiration time, in seconds (float)).
54
55     @return True on success, False on timeout
56     """
57     event_mask = 0
58     if readable:
59         event_mask |= select.POLLIN
60     if writable:
61         event_mask |= select.POLLOUT
62     if error:
63         event_mask |= select.POLLERR
64
65     pollable = select.poll()
66     pollable.register(fd, event_mask)
67
68     if timeout:
69         event_list = pollable.poll(long(timeout * 1000))
70     else:
71         event_list = pollable.poll()
72
73     return bool(event_list)
74
75 def _select_for(fd, readable, writable, error, timeout):
76     """
77     @param fd: File descriptor (int).
78     @param readable: Whether to wait for readability (bool).
79     @param writable: Whether to wait for writability (bool).
80     @param expiration: Deadline timeout (expiration time, in seconds (float)).
81
82     @return True on success, False on timeout
83     """
84     rset, wset, xset = [], [], []
85
86     if readable:
87         rset = [fd]
88     if writable:
89         wset = [fd]
90     if error:
91         xset = [fd]
92
93     if timeout is None:
94         (rcount, wcount, xcount) = select.select(rset, wset, xset)
95     else:
96         (rcount, wcount, xcount) = select.select(rset, wset, xset, timeout)
97
98     return bool((rcount or wcount or xcount))
99
100 def _wait_for(fd, readable, writable, error, expiration):
101     done = False
102     while not done:
103         if expiration is None:
104             timeout = None
105         else:
106             timeout = expiration - time.time()
107             if timeout <= 0.0:
108                 raise dns.exception.Timeout
109         try:
110             if not _polling_backend(fd, readable, writable, error, timeout):
111                 raise dns.exception.Timeout
112         except select.error, e:
113             if e.args[0] != errno.EINTR:
114                 raise e
115         done = True
116
117 def _set_polling_backend(fn):
118     """
119     Internal API. Do not use.
120     """
121     global _polling_backend
122
123     _polling_backend = fn
124
125 if hasattr(select, 'poll'):
126     # Prefer poll() on platforms that support it because it has no
127     # limits on the maximum value of a file descriptor (plus it will
128     # be more efficient for high values).
129     _polling_backend = _poll_for
130 else:
131     _polling_backend = _select_for
132
133 def _wait_for_readable(s, expiration):
134     _wait_for(s, True, False, True, expiration)
135
136 def _wait_for_writable(s, expiration):
137     _wait_for(s, False, True, True, expiration)
138
139 def _addresses_equal(af, a1, a2):
140     # Convert the first value of the tuple, which is a textual format
141     # address into binary form, so that we are not confused by different
142     # textual representations of the same address
143     n1 = dns.inet.inet_pton(af, a1[0])
144     n2 = dns.inet.inet_pton(af, a2[0])
145     return n1 == n2 and a1[1:] == a2[1:]
146
147 def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
148         ignore_unexpected=False, one_rr_per_rrset=False):
149     """Return the response obtained after sending a query via UDP.
150
151     @param q: the query
152     @type q: dns.message.Message
153     @param where: where to send the message
154     @type where: string containing an IPv4 or IPv6 address
155     @param timeout: The number of seconds to wait before the query times out.
156     If None, the default, wait forever.
157     @type timeout: float
158     @param port: The port to which to send the message.  The default is 53.
159     @type port: int
160     @param af: the address family to use.  The default is None, which
161     causes the address family to use to be inferred from the form of of where.
162     If the inference attempt fails, AF_INET is used.
163     @type af: int
164     @rtype: dns.message.Message object
165     @param source: source address.  The default is the IPv4 wildcard address.
166     @type source: string
167     @param source_port: The port from which to send the message.
168     The default is 0.
169     @type source_port: int
170     @param ignore_unexpected: If True, ignore responses from unexpected
171     sources.  The default is False.
172     @type ignore_unexpected: bool
173     @param one_rr_per_rrset: Put each RR into its own RRset
174     @type one_rr_per_rrset: bool
175     """
176
177     wire = q.to_wire()
178     if af is None:
179         try:
180             af = dns.inet.af_for_address(where)
181         except:
182             af = dns.inet.AF_INET
183     if af == dns.inet.AF_INET:
184         destination = (where, port)
185         if source is not None:
186             source = (source, source_port)
187     elif af == dns.inet.AF_INET6:
188         destination = (where, port, 0, 0)
189         if source is not None:
190             source = (source, source_port, 0, 0)
191     s = socket.socket(af, socket.SOCK_DGRAM, 0)
192     try:
193         expiration = _compute_expiration(timeout)
194         s.setblocking(0)
195         if source is not None:
196             s.bind(source)
197         _wait_for_writable(s, expiration)
198         s.sendto(wire, destination)
199         while 1:
200             _wait_for_readable(s, expiration)
201             (wire, from_address) = s.recvfrom(65535)
202             if _addresses_equal(af, from_address, destination) or \
203                     (dns.inet.is_multicast(where) and \
204                          from_address[1:] == destination[1:]):
205                 break
206             if not ignore_unexpected:
207                 raise UnexpectedSource('got a response from '
208                                        '%s instead of %s' % (from_address,
209                                                              destination))
210     finally:
211         s.close()
212     r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac,
213                               one_rr_per_rrset=one_rr_per_rrset)
214     if not q.is_response(r):
215         raise BadResponse
216     return r
217
218 def _net_read(sock, count, expiration):
219     """Read the specified number of bytes from sock.  Keep trying until we
220     either get the desired amount, or we hit EOF.
221     A Timeout exception will be raised if the operation is not completed
222     by the expiration time.
223     """
224     s = ''
225     while count > 0:
226         _wait_for_readable(sock, expiration)
227         n = sock.recv(count)
228         if n == '':
229             raise EOFError
230         count = count - len(n)
231         s = s + n
232     return s
233
234 def _net_write(sock, data, expiration):
235     """Write the specified data to the socket.
236     A Timeout exception will be raised if the operation is not completed
237     by the expiration time.
238     """
239     current = 0
240     l = len(data)
241     while current < l:
242         _wait_for_writable(sock, expiration)
243         current += sock.send(data[current:])
244
245 def _connect(s, address):
246     try:
247         s.connect(address)
248     except socket.error:
249         (ty, v) = sys.exc_info()[:2]
250         if v[0] != errno.EINPROGRESS and \
251                v[0] != errno.EWOULDBLOCK and \
252                v[0] != errno.EALREADY:
253             raise v
254
255 def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
256         one_rr_per_rrset=False):
257     """Return the response obtained after sending a query via TCP.
258
259     @param q: the query
260     @type q: dns.message.Message object
261     @param where: where to send the message
262     @type where: string containing an IPv4 or IPv6 address
263     @param timeout: The number of seconds to wait before the query times out.
264     If None, the default, wait forever.
265     @type timeout: float
266     @param port: The port to which to send the message.  The default is 53.
267     @type port: int
268     @param af: the address family to use.  The default is None, which
269     causes the address family to use to be inferred from the form of of where.
270     If the inference attempt fails, AF_INET is used.
271     @type af: int
272     @rtype: dns.message.Message object
273     @param source: source address.  The default is the IPv4 wildcard address.
274     @type source: string
275     @param source_port: The port from which to send the message.
276     The default is 0.
277     @type source_port: int
278     @param one_rr_per_rrset: Put each RR into its own RRset
279     @type one_rr_per_rrset: bool
280     """
281
282     wire = q.to_wire()
283     if af is None:
284         try:
285             af = dns.inet.af_for_address(where)
286         except:
287             af = dns.inet.AF_INET
288     if af == dns.inet.AF_INET:
289         destination = (where, port)
290         if source is not None:
291             source = (source, source_port)
292     elif af == dns.inet.AF_INET6:
293         destination = (where, port, 0, 0)
294         if source is not None:
295             source = (source, source_port, 0, 0)
296     s = socket.socket(af, socket.SOCK_STREAM, 0)
297     try:
298         expiration = _compute_expiration(timeout)
299         s.setblocking(0)
300         if source is not None:
301             s.bind(source)
302         _connect(s, destination)
303
304         l = len(wire)
305
306         # copying the wire into tcpmsg is inefficient, but lets us
307         # avoid writev() or doing a short write that would get pushed
308         # onto the net
309         tcpmsg = struct.pack("!H", l) + wire
310         _net_write(s, tcpmsg, expiration)
311         ldata = _net_read(s, 2, expiration)
312         (l,) = struct.unpack("!H", ldata)
313         wire = _net_read(s, l, expiration)
314     finally:
315         s.close()
316     r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac,
317                               one_rr_per_rrset=one_rr_per_rrset)
318     if not q.is_response(r):
319         raise BadResponse
320     return r
321
322 def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
323         timeout=None, port=53, keyring=None, keyname=None, relativize=True,
324         af=None, lifetime=None, source=None, source_port=0, serial=0,
325         use_udp=False, keyalgorithm=dns.tsig.default_algorithm):
326     """Return a generator for the responses to a zone transfer.
327
328     @param where: where to send the message
329     @type where: string containing an IPv4 or IPv6 address
330     @param zone: The name of the zone to transfer
331     @type zone: dns.name.Name object or string
332     @param rdtype: The type of zone transfer.  The default is
333     dns.rdatatype.AXFR.
334     @type rdtype: int or string
335     @param rdclass: The class of the zone transfer.  The default is
336     dns.rdatatype.IN.
337     @type rdclass: int or string
338     @param timeout: The number of seconds to wait for each response message.
339     If None, the default, wait forever.
340     @type timeout: float
341     @param port: The port to which to send the message.  The default is 53.
342     @type port: int
343     @param keyring: The TSIG keyring to use
344     @type keyring: dict
345     @param keyname: The name of the TSIG key to use
346     @type keyname: dns.name.Name object or string
347     @param relativize: If True, all names in the zone will be relativized to
348     the zone origin.  It is essential that the relativize setting matches
349     the one specified to dns.zone.from_xfr().
350     @type relativize: bool
351     @param af: the address family to use.  The default is None, which
352     causes the address family to use to be inferred from the form of of where.
353     If the inference attempt fails, AF_INET is used.
354     @type af: int
355     @param lifetime: The total number of seconds to spend doing the transfer.
356     If None, the default, then there is no limit on the time the transfer may
357     take.
358     @type lifetime: float
359     @rtype: generator of dns.message.Message objects.
360     @param source: source address.  The default is the IPv4 wildcard address.
361     @type source: string
362     @param source_port: The port from which to send the message.
363     The default is 0.
364     @type source_port: int
365     @param serial: The SOA serial number to use as the base for an IXFR diff
366     sequence (only meaningful if rdtype == dns.rdatatype.IXFR).
367     @type serial: int
368     @param use_udp: Use UDP (only meaningful for IXFR)
369     @type use_udp: bool
370     @param keyalgorithm: The TSIG algorithm to use; defaults to
371     dns.tsig.default_algorithm
372     @type keyalgorithm: string
373     """
374
375     if isinstance(zone, (str, unicode)):
376         zone = dns.name.from_text(zone)
377     if isinstance(rdtype, (str, unicode)):
378         rdtype = dns.rdatatype.from_text(rdtype)
379     q = dns.message.make_query(zone, rdtype, rdclass)
380     if rdtype == dns.rdatatype.IXFR:
381         rrset = dns.rrset.from_text(zone, 0, 'IN', 'SOA',
382                                     '. . %u 0 0 0 0' % serial)
383         q.authority.append(rrset)
384     if not keyring is None:
385         q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
386     wire = q.to_wire()
387     if af is None:
388         try:
389             af = dns.inet.af_for_address(where)
390         except:
391             af = dns.inet.AF_INET
392     if af == dns.inet.AF_INET:
393         destination = (where, port)
394         if source is not None:
395             source = (source, source_port)
396     elif af == dns.inet.AF_INET6:
397         destination = (where, port, 0, 0)
398         if source is not None:
399             source = (source, source_port, 0, 0)
400     if use_udp:
401         if rdtype != dns.rdatatype.IXFR:
402             raise ValueError('cannot do a UDP AXFR')
403         s = socket.socket(af, socket.SOCK_DGRAM, 0)
404     else:
405         s = socket.socket(af, socket.SOCK_STREAM, 0)
406     s.setblocking(0)
407     if source is not None:
408         s.bind(source)
409     expiration = _compute_expiration(lifetime)
410     _connect(s, destination)
411     l = len(wire)
412     if use_udp:
413         _wait_for_writable(s, expiration)
414         s.send(wire)
415     else:
416         tcpmsg = struct.pack("!H", l) + wire
417         _net_write(s, tcpmsg, expiration)
418     done = False
419     soa_rrset = None
420     soa_count = 0
421     if relativize:
422         origin = zone
423         oname = dns.name.empty
424     else:
425         origin = None
426         oname = zone
427     tsig_ctx = None
428     first = True
429     while not done:
430         mexpiration = _compute_expiration(timeout)
431         if mexpiration is None or mexpiration > expiration:
432             mexpiration = expiration
433         if use_udp:
434             _wait_for_readable(s, expiration)
435             (wire, from_address) = s.recvfrom(65535)
436         else:
437             ldata = _net_read(s, 2, mexpiration)
438             (l,) = struct.unpack("!H", ldata)
439             wire = _net_read(s, l, mexpiration)
440         r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac,
441                                   xfr=True, origin=origin, tsig_ctx=tsig_ctx,
442                                   multi=True, first=first,
443                                   one_rr_per_rrset=(rdtype==dns.rdatatype.IXFR))
444         tsig_ctx = r.tsig_ctx
445         first = False
446         answer_index = 0
447         delete_mode = False
448         expecting_SOA = False
449         if soa_rrset is None:
450             if not r.answer or r.answer[0].name != oname:
451                 raise dns.exception.FormError
452             rrset = r.answer[0]
453             if rrset.rdtype != dns.rdatatype.SOA:
454                 raise dns.exception.FormError("first RRset is not an SOA")
455             answer_index = 1
456             soa_rrset = rrset.copy()
457             if rdtype == dns.rdatatype.IXFR:
458                 if soa_rrset[0].serial == serial:
459                     #
460                     # We're already up-to-date.
461                     #
462                     done = True
463                 else:
464                     expecting_SOA = True
465         #
466         # Process SOAs in the answer section (other than the initial
467         # SOA in the first message).
468         #
469         for rrset in r.answer[answer_index:]:
470             if done:
471                 raise dns.exception.FormError("answers after final SOA")
472             if rrset.rdtype == dns.rdatatype.SOA and rrset.name == oname:
473                 if expecting_SOA:
474                     if rrset[0].serial != serial:
475                         raise dns.exception.FormError("IXFR base serial mismatch")
476                     expecting_SOA = False
477                 elif rdtype == dns.rdatatype.IXFR:
478                     delete_mode = not delete_mode
479                 if rrset == soa_rrset and not delete_mode:
480                     done = True
481             elif expecting_SOA:
482                 #
483                 # We made an IXFR request and are expecting another
484                 # SOA RR, but saw something else, so this must be an
485                 # AXFR response.
486                 #
487                 rdtype = dns.rdatatype.AXFR
488                 expecting_SOA = False
489         if done and q.keyring and not r.had_tsig:
490             raise dns.exception.FormError("missing TSIG")
491         yield r
492     s.close()