s4-python: Move dnspython to lib/, like the other Python modules
[idra/samba.git] / lib / dnspython / dns / query.py
1 # Copyright (C) 2003-2007, 2009, 2010 Nominum, Inc.
2 #
3 # Permission to use, copy, modify, and distribute this software and its
4 # documentation for any purpose with or without fee is hereby granted,
5 # provided that the above copyright notice and this permission notice
6 # appear in all copies.
7 #
8 # THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
9 # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10 # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
11 # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12 # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13 # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
14 # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15
16 """Talk to a DNS server."""
17
18 from __future__ import generators
19
20 import errno
21 import select
22 import socket
23 import struct
24 import sys
25 import time
26
27 import dns.exception
28 import dns.inet
29 import dns.name
30 import dns.message
31 import dns.rdataclass
32 import dns.rdatatype
33
34 class UnexpectedSource(dns.exception.DNSException):
35     """Raised if a query response comes from an unexpected address or port."""
36     pass
37
38 class BadResponse(dns.exception.FormError):
39     """Raised if a query response does not respond to the question asked."""
40     pass
41
42 def _compute_expiration(timeout):
43     if timeout is None:
44         return None
45     else:
46         return time.time() + timeout
47
48 def _wait_for(ir, iw, ix, expiration):
49     done = False
50     while not done:
51         if expiration is None:
52             timeout = None
53         else:
54             timeout = expiration - time.time()
55             if timeout <= 0.0:
56                 raise dns.exception.Timeout
57         try:
58             if timeout is None:
59                 (r, w, x) = select.select(ir, iw, ix)
60             else:
61                 (r, w, x) = select.select(ir, iw, ix, timeout)
62         except select.error, e:
63             if e.args[0] != errno.EINTR:
64                 raise e
65         done = True
66         if len(r) == 0 and len(w) == 0 and len(x) == 0:
67             raise dns.exception.Timeout
68
69 def _wait_for_readable(s, expiration):
70     _wait_for([s], [], [s], expiration)
71
72 def _wait_for_writable(s, expiration):
73     _wait_for([], [s], [s], expiration)
74
75 def _addresses_equal(af, a1, a2):
76     # Convert the first value of the tuple, which is a textual format
77     # address into binary form, so that we are not confused by different
78     # textual representations of the same address
79     n1 = dns.inet.inet_pton(af, a1[0])
80     n2 = dns.inet.inet_pton(af, a2[0])
81     return n1 == n2 and a1[1:] == a2[1:]
82
83 def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
84         ignore_unexpected=False, one_rr_per_rrset=False):
85     """Return the response obtained after sending a query via UDP.
86
87     @param q: the query
88     @type q: dns.message.Message
89     @param where: where to send the message
90     @type where: string containing an IPv4 or IPv6 address
91     @param timeout: The number of seconds to wait before the query times out.
92     If None, the default, wait forever.
93     @type timeout: float
94     @param port: The port to which to send the message.  The default is 53.
95     @type port: int
96     @param af: the address family to use.  The default is None, which
97     causes the address family to use to be inferred from the form of of where.
98     If the inference attempt fails, AF_INET is used.
99     @type af: int
100     @rtype: dns.message.Message object
101     @param source: source address.  The default is the IPv4 wildcard address.
102     @type source: string
103     @param source_port: The port from which to send the message.
104     The default is 0.
105     @type source_port: int
106     @param ignore_unexpected: If True, ignore responses from unexpected
107     sources.  The default is False.
108     @type ignore_unexpected: bool
109     @param one_rr_per_rrset: Put each RR into its own RRset
110     @type one_rr_per_rrset: bool
111     """
112
113     wire = q.to_wire()
114     if af is None:
115         try:
116             af = dns.inet.af_for_address(where)
117         except:
118             af = dns.inet.AF_INET
119     if af == dns.inet.AF_INET:
120         destination = (where, port)
121         if source is not None:
122             source = (source, source_port)
123     elif af == dns.inet.AF_INET6:
124         destination = (where, port, 0, 0)
125         if source is not None:
126             source = (source, source_port, 0, 0)
127     s = socket.socket(af, socket.SOCK_DGRAM, 0)
128     try:
129         expiration = _compute_expiration(timeout)
130         s.setblocking(0)
131         if source is not None:
132             s.bind(source)
133         _wait_for_writable(s, expiration)
134         s.sendto(wire, destination)
135         while 1:
136             _wait_for_readable(s, expiration)
137             (wire, from_address) = s.recvfrom(65535)
138             if _addresses_equal(af, from_address, destination) or \
139                     (dns.inet.is_multicast(where) and \
140                          from_address[1:] == destination[1:]):
141                 break
142             if not ignore_unexpected:
143                 raise UnexpectedSource('got a response from '
144                                        '%s instead of %s' % (from_address,
145                                                              destination))
146     finally:
147         s.close()
148     r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac,
149                               one_rr_per_rrset=one_rr_per_rrset)
150     if not q.is_response(r):
151         raise BadResponse
152     return r
153
154 def _net_read(sock, count, expiration):
155     """Read the specified number of bytes from sock.  Keep trying until we
156     either get the desired amount, or we hit EOF.
157     A Timeout exception will be raised if the operation is not completed
158     by the expiration time.
159     """
160     s = ''
161     while count > 0:
162         _wait_for_readable(sock, expiration)
163         n = sock.recv(count)
164         if n == '':
165             raise EOFError
166         count = count - len(n)
167         s = s + n
168     return s
169
170 def _net_write(sock, data, expiration):
171     """Write the specified data to the socket.
172     A Timeout exception will be raised if the operation is not completed
173     by the expiration time.
174     """
175     current = 0
176     l = len(data)
177     while current < l:
178         _wait_for_writable(sock, expiration)
179         current += sock.send(data[current:])
180
181 def _connect(s, address):
182     try:
183         s.connect(address)
184     except socket.error:
185         (ty, v) = sys.exc_info()[:2]
186         if v[0] != errno.EINPROGRESS and \
187                v[0] != errno.EWOULDBLOCK and \
188                v[0] != errno.EALREADY:
189             raise v
190
191 def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
192         one_rr_per_rrset=False):
193     """Return the response obtained after sending a query via TCP.
194
195     @param q: the query
196     @type q: dns.message.Message object
197     @param where: where to send the message
198     @type where: string containing an IPv4 or IPv6 address
199     @param timeout: The number of seconds to wait before the query times out.
200     If None, the default, wait forever.
201     @type timeout: float
202     @param port: The port to which to send the message.  The default is 53.
203     @type port: int
204     @param af: the address family to use.  The default is None, which
205     causes the address family to use to be inferred from the form of of where.
206     If the inference attempt fails, AF_INET is used.
207     @type af: int
208     @rtype: dns.message.Message object
209     @param source: source address.  The default is the IPv4 wildcard address.
210     @type source: string
211     @param source_port: The port from which to send the message.
212     The default is 0.
213     @type source_port: int
214     @param one_rr_per_rrset: Put each RR into its own RRset
215     @type one_rr_per_rrset: bool
216     """
217
218     wire = q.to_wire()
219     if af is None:
220         try:
221             af = dns.inet.af_for_address(where)
222         except:
223             af = dns.inet.AF_INET
224     if af == dns.inet.AF_INET:
225         destination = (where, port)
226         if source is not None:
227             source = (source, source_port)
228     elif af == dns.inet.AF_INET6:
229         destination = (where, port, 0, 0)
230         if source is not None:
231             source = (source, source_port, 0, 0)
232     s = socket.socket(af, socket.SOCK_STREAM, 0)
233     try:
234         expiration = _compute_expiration(timeout)
235         s.setblocking(0)
236         if source is not None:
237             s.bind(source)
238         _connect(s, destination)
239
240         l = len(wire)
241
242         # copying the wire into tcpmsg is inefficient, but lets us
243         # avoid writev() or doing a short write that would get pushed
244         # onto the net
245         tcpmsg = struct.pack("!H", l) + wire
246         _net_write(s, tcpmsg, expiration)
247         ldata = _net_read(s, 2, expiration)
248         (l,) = struct.unpack("!H", ldata)
249         wire = _net_read(s, l, expiration)
250     finally:
251         s.close()
252     r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac,
253                               one_rr_per_rrset=one_rr_per_rrset)
254     if not q.is_response(r):
255         raise BadResponse
256     return r
257
258 def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
259         timeout=None, port=53, keyring=None, keyname=None, relativize=True,
260         af=None, lifetime=None, source=None, source_port=0, serial=0,
261         use_udp=False, keyalgorithm=dns.tsig.default_algorithm):
262     """Return a generator for the responses to a zone transfer.
263
264     @param where: where to send the message
265     @type where: string containing an IPv4 or IPv6 address
266     @param zone: The name of the zone to transfer
267     @type zone: dns.name.Name object or string
268     @param rdtype: The type of zone transfer.  The default is
269     dns.rdatatype.AXFR.
270     @type rdtype: int or string
271     @param rdclass: The class of the zone transfer.  The default is
272     dns.rdatatype.IN.
273     @type rdclass: int or string
274     @param timeout: The number of seconds to wait for each response message.
275     If None, the default, wait forever.
276     @type timeout: float
277     @param port: The port to which to send the message.  The default is 53.
278     @type port: int
279     @param keyring: The TSIG keyring to use
280     @type keyring: dict
281     @param keyname: The name of the TSIG key to use
282     @type keyname: dns.name.Name object or string
283     @param relativize: If True, all names in the zone will be relativized to
284     the zone origin.  It is essential that the relativize setting matches
285     the one specified to dns.zone.from_xfr().
286     @type relativize: bool
287     @param af: the address family to use.  The default is None, which
288     causes the address family to use to be inferred from the form of of where.
289     If the inference attempt fails, AF_INET is used.
290     @type af: int
291     @param lifetime: The total number of seconds to spend doing the transfer.
292     If None, the default, then there is no limit on the time the transfer may
293     take.
294     @type lifetime: float
295     @rtype: generator of dns.message.Message objects.
296     @param source: source address.  The default is the IPv4 wildcard address.
297     @type source: string
298     @param source_port: The port from which to send the message.
299     The default is 0.
300     @type source_port: int
301     @param serial: The SOA serial number to use as the base for an IXFR diff
302     sequence (only meaningful if rdtype == dns.rdatatype.IXFR).
303     @type serial: int
304     @param use_udp: Use UDP (only meaningful for IXFR)
305     @type use_udp: bool
306     @param keyalgorithm: The TSIG algorithm to use; defaults to
307     dns.tsig.default_algorithm
308     @type keyalgorithm: string
309     """
310
311     if isinstance(zone, (str, unicode)):
312         zone = dns.name.from_text(zone)
313     if isinstance(rdtype, str):
314         rdtype = dns.rdatatype.from_text(rdtype)
315     q = dns.message.make_query(zone, rdtype, rdclass)
316     if rdtype == dns.rdatatype.IXFR:
317         rrset = dns.rrset.from_text(zone, 0, 'IN', 'SOA',
318                                     '. . %u 0 0 0 0' % serial)
319         q.authority.append(rrset)
320     if not keyring is None:
321         q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
322     wire = q.to_wire()
323     if af is None:
324         try:
325             af = dns.inet.af_for_address(where)
326         except:
327             af = dns.inet.AF_INET
328     if af == dns.inet.AF_INET:
329         destination = (where, port)
330         if source is not None:
331             source = (source, source_port)
332     elif af == dns.inet.AF_INET6:
333         destination = (where, port, 0, 0)
334         if source is not None:
335             source = (source, source_port, 0, 0)
336     if use_udp:
337         if rdtype != dns.rdatatype.IXFR:
338             raise ValueError('cannot do a UDP AXFR')
339         s = socket.socket(af, socket.SOCK_DGRAM, 0)
340     else:
341         s = socket.socket(af, socket.SOCK_STREAM, 0)
342     s.setblocking(0)
343     if source is not None:
344         s.bind(source)
345     expiration = _compute_expiration(lifetime)
346     _connect(s, destination)
347     l = len(wire)
348     if use_udp:
349         _wait_for_writable(s, expiration)
350         s.send(wire)
351     else:
352         tcpmsg = struct.pack("!H", l) + wire
353         _net_write(s, tcpmsg, expiration)
354     done = False
355     soa_rrset = None
356     soa_count = 0
357     if relativize:
358         origin = zone
359         oname = dns.name.empty
360     else:
361         origin = None
362         oname = zone
363     tsig_ctx = None
364     first = True
365     while not done:
366         mexpiration = _compute_expiration(timeout)
367         if mexpiration is None or mexpiration > expiration:
368             mexpiration = expiration
369         if use_udp:
370             _wait_for_readable(s, expiration)
371             (wire, from_address) = s.recvfrom(65535)
372         else:
373             ldata = _net_read(s, 2, mexpiration)
374             (l,) = struct.unpack("!H", ldata)
375             wire = _net_read(s, l, mexpiration)
376         r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac,
377                                   xfr=True, origin=origin, tsig_ctx=tsig_ctx,
378                                   multi=True, first=first,
379                                   one_rr_per_rrset=(rdtype==dns.rdatatype.IXFR))
380         tsig_ctx = r.tsig_ctx
381         first = False
382         answer_index = 0
383         delete_mode = False
384         expecting_SOA = False
385         if soa_rrset is None:
386             if not r.answer or r.answer[0].name != oname:
387                 raise dns.exception.FormError
388             rrset = r.answer[0]
389             if rrset.rdtype != dns.rdatatype.SOA:
390                 raise dns.exception.FormError("first RRset is not an SOA")
391             answer_index = 1
392             soa_rrset = rrset.copy()
393             if rdtype == dns.rdatatype.IXFR:
394                 if soa_rrset[0].serial == serial:
395                     #
396                     # We're already up-to-date.
397                     #
398                     done = True
399                 else:
400                     expecting_SOA = True
401         #
402         # Process SOAs in the answer section (other than the initial
403         # SOA in the first message).
404         #
405         for rrset in r.answer[answer_index:]:
406             if done:
407                 raise dns.exception.FormError("answers after final SOA")
408             if rrset.rdtype == dns.rdatatype.SOA and rrset.name == oname:
409                 if expecting_SOA:
410                     if rrset[0].serial != serial:
411                         raise dns.exception.FormError("IXFR base serial mismatch")
412                     expecting_SOA = False
413                 elif rdtype == dns.rdatatype.IXFR:
414                     delete_mode = not delete_mode
415                 if rrset == soa_rrset and not delete_mode:
416                     done = True
417             elif expecting_SOA:
418                 #
419                 # We made an IXFR request and are expecting another
420                 # SOA RR, but saw something else, so this must be an
421                 # AXFR response.
422                 #
423                 rdtype = dns.rdatatype.AXFR
424                 expecting_SOA = False
425         if done and q.keyring and not r.had_tsig:
426             raise dns.exception.FormError("missing TSIG")
427         yield r
428     s.close()