s4-python: Format to PEP8, simplify tests.
[idra/samba.git] / source4 / scripting / python / samba_external / dnspython / dns / renderer.py
1 # Copyright (C) 2001-2007, 2009, 2010 Nominum, Inc.
2 #
3 # Permission to use, copy, modify, and distribute this software and its
4 # documentation for any purpose with or without fee is hereby granted,
5 # provided that the above copyright notice and this permission notice
6 # appear in all copies.
7 #
8 # THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
9 # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10 # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
11 # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12 # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13 # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
14 # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15
16 """Help for building DNS wire format messages"""
17
18 import cStringIO
19 import struct
20 import random
21 import time
22
23 import dns.exception
24 import dns.tsig
25
26 QUESTION = 0
27 ANSWER = 1
28 AUTHORITY = 2
29 ADDITIONAL = 3
30
31 class Renderer(object):
32     """Helper class for building DNS wire-format messages.
33
34     Most applications can use the higher-level L{dns.message.Message}
35     class and its to_wire() method to generate wire-format messages.
36     This class is for those applications which need finer control
37     over the generation of messages.
38
39     Typical use::
40
41         r = dns.renderer.Renderer(id=1, flags=0x80, max_size=512)
42         r.add_question(qname, qtype, qclass)
43         r.add_rrset(dns.renderer.ANSWER, rrset_1)
44         r.add_rrset(dns.renderer.ANSWER, rrset_2)
45         r.add_rrset(dns.renderer.AUTHORITY, ns_rrset)
46         r.add_edns(0, 0, 4096)
47         r.add_rrset(dns.renderer.ADDTIONAL, ad_rrset_1)
48         r.add_rrset(dns.renderer.ADDTIONAL, ad_rrset_2)
49         r.write_header()
50         r.add_tsig(keyname, secret, 300, 1, 0, '', request_mac)
51         wire = r.get_wire()
52
53     @ivar output: where rendering is written
54     @type output: cStringIO.StringIO object
55     @ivar id: the message id
56     @type id: int
57     @ivar flags: the message flags
58     @type flags: int
59     @ivar max_size: the maximum size of the message
60     @type max_size: int
61     @ivar origin: the origin to use when rendering relative names
62     @type origin: dns.name.Name object
63     @ivar compress: the compression table
64     @type compress: dict
65     @ivar section: the section currently being rendered
66     @type section: int (dns.renderer.QUESTION, dns.renderer.ANSWER,
67     dns.renderer.AUTHORITY, or dns.renderer.ADDITIONAL)
68     @ivar counts: list of the number of RRs in each section
69     @type counts: int list of length 4
70     @ivar mac: the MAC of the rendered message (if TSIG was used)
71     @type mac: string
72     """
73
74     def __init__(self, id=None, flags=0, max_size=65535, origin=None):
75         """Initialize a new renderer.
76
77         @param id: the message id
78         @type id: int
79         @param flags: the DNS message flags
80         @type flags: int
81         @param max_size: the maximum message size; the default is 65535.
82         If rendering results in a message greater than I{max_size},
83         then L{dns.exception.TooBig} will be raised.
84         @type max_size: int
85         @param origin: the origin to use when rendering relative names
86         @type origin: dns.name.Namem or None.
87         """
88
89         self.output = cStringIO.StringIO()
90         if id is None:
91             self.id = random.randint(0, 65535)
92         else:
93             self.id = id
94         self.flags = flags
95         self.max_size = max_size
96         self.origin = origin
97         self.compress = {}
98         self.section = QUESTION
99         self.counts = [0, 0, 0, 0]
100         self.output.write('\x00' * 12)
101         self.mac = ''
102
103     def _rollback(self, where):
104         """Truncate the output buffer at offset I{where}, and remove any
105         compression table entries that pointed beyond the truncation
106         point.
107
108         @param where: the offset
109         @type where: int
110         """
111
112         self.output.seek(where)
113         self.output.truncate()
114         keys_to_delete = []
115         for k, v in self.compress.iteritems():
116             if v >= where:
117                 keys_to_delete.append(k)
118         for k in keys_to_delete:
119             del self.compress[k]
120
121     def _set_section(self, section):
122         """Set the renderer's current section.
123
124         Sections must be rendered order: QUESTION, ANSWER, AUTHORITY,
125         ADDITIONAL.  Sections may be empty.
126
127         @param section: the section
128         @type section: int
129         @raises dns.exception.FormError: an attempt was made to set
130         a section value less than the current section.
131         """
132
133         if self.section != section:
134             if self.section > section:
135                 raise dns.exception.FormError
136             self.section = section
137
138     def add_question(self, qname, rdtype, rdclass=dns.rdataclass.IN):
139         """Add a question to the message.
140
141         @param qname: the question name
142         @type qname: dns.name.Name
143         @param rdtype: the question rdata type
144         @type rdtype: int
145         @param rdclass: the question rdata class
146         @type rdclass: int
147         """
148
149         self._set_section(QUESTION)
150         before = self.output.tell()
151         qname.to_wire(self.output, self.compress, self.origin)
152         self.output.write(struct.pack("!HH", rdtype, rdclass))
153         after = self.output.tell()
154         if after >= self.max_size:
155             self._rollback(before)
156             raise dns.exception.TooBig
157         self.counts[QUESTION] += 1
158
159     def add_rrset(self, section, rrset, **kw):
160         """Add the rrset to the specified section.
161
162         Any keyword arguments are passed on to the rdataset's to_wire()
163         routine.
164
165         @param section: the section
166         @type section: int
167         @param rrset: the rrset
168         @type rrset: dns.rrset.RRset object
169         """
170
171         self._set_section(section)
172         before = self.output.tell()
173         n = rrset.to_wire(self.output, self.compress, self.origin, **kw)
174         after = self.output.tell()
175         if after >= self.max_size:
176             self._rollback(before)
177             raise dns.exception.TooBig
178         self.counts[section] += n
179
180     def add_rdataset(self, section, name, rdataset, **kw):
181         """Add the rdataset to the specified section, using the specified
182         name as the owner name.
183
184         Any keyword arguments are passed on to the rdataset's to_wire()
185         routine.
186
187         @param section: the section
188         @type section: int
189         @param name: the owner name
190         @type name: dns.name.Name object
191         @param rdataset: the rdataset
192         @type rdataset: dns.rdataset.Rdataset object
193         """
194
195         self._set_section(section)
196         before = self.output.tell()
197         n = rdataset.to_wire(name, self.output, self.compress, self.origin,
198                              **kw)
199         after = self.output.tell()
200         if after >= self.max_size:
201             self._rollback(before)
202             raise dns.exception.TooBig
203         self.counts[section] += n
204
205     def add_edns(self, edns, ednsflags, payload, options=None):
206         """Add an EDNS OPT record to the message.
207
208         @param edns: The EDNS level to use.
209         @type edns: int
210         @param ednsflags: EDNS flag values.
211         @type ednsflags: int
212         @param payload: The EDNS sender's payload field, which is the maximum
213         size of UDP datagram the sender can handle.
214         @type payload: int
215         @param options: The EDNS options list
216         @type options: list of dns.edns.Option instances
217         @see: RFC 2671
218         """
219
220         # make sure the EDNS version in ednsflags agrees with edns
221         ednsflags &= 0xFF00FFFFL
222         ednsflags |= (edns << 16)
223         self._set_section(ADDITIONAL)
224         before = self.output.tell()
225         self.output.write(struct.pack('!BHHIH', 0, dns.rdatatype.OPT, payload,
226                                       ednsflags, 0))
227         if not options is None:
228             lstart = self.output.tell()
229             for opt in options:
230                 stuff = struct.pack("!HH", opt.otype, 0)
231                 self.output.write(stuff)
232                 start = self.output.tell()
233                 opt.to_wire(self.output)
234                 end = self.output.tell()
235                 assert end - start < 65536
236                 self.output.seek(start - 2)
237                 stuff = struct.pack("!H", end - start)
238                 self.output.write(stuff)
239                 self.output.seek(0, 2)
240             lend = self.output.tell()
241             assert lend - lstart < 65536
242             self.output.seek(lstart - 2)
243             stuff = struct.pack("!H", lend - lstart)
244             self.output.write(stuff)
245             self.output.seek(0, 2)
246         after = self.output.tell()
247         if after >= self.max_size:
248             self._rollback(before)
249             raise dns.exception.TooBig
250         self.counts[ADDITIONAL] += 1
251
252     def add_tsig(self, keyname, secret, fudge, id, tsig_error, other_data,
253                  request_mac, algorithm=dns.tsig.default_algorithm):
254         """Add a TSIG signature to the message.
255
256         @param keyname: the TSIG key name
257         @type keyname: dns.name.Name object
258         @param secret: the secret to use
259         @type secret: string
260         @param fudge: TSIG time fudge
261         @type fudge: int
262         @param id: the message id to encode in the tsig signature
263         @type id: int
264         @param tsig_error: TSIG error code; default is 0.
265         @type tsig_error: int
266         @param other_data: TSIG other data.
267         @type other_data: string
268         @param request_mac: This message is a response to the request which
269         had the specified MAC.
270         @param algorithm: the TSIG algorithm to use
271         @type request_mac: string
272         """
273
274         self._set_section(ADDITIONAL)
275         before = self.output.tell()
276         s = self.output.getvalue()
277         (tsig_rdata, self.mac, ctx) = dns.tsig.sign(s,
278                                                     keyname,
279                                                     secret,
280                                                     int(time.time()),
281                                                     fudge,
282                                                     id,
283                                                     tsig_error,
284                                                     other_data,
285                                                     request_mac,
286                                                     algorithm=algorithm)
287         keyname.to_wire(self.output, self.compress, self.origin)
288         self.output.write(struct.pack('!HHIH', dns.rdatatype.TSIG,
289                                       dns.rdataclass.ANY, 0, 0))
290         rdata_start = self.output.tell()
291         self.output.write(tsig_rdata)
292         after = self.output.tell()
293         assert after - rdata_start < 65536
294         if after >= self.max_size:
295             self._rollback(before)
296             raise dns.exception.TooBig
297         self.output.seek(rdata_start - 2)
298         self.output.write(struct.pack('!H', after - rdata_start))
299         self.counts[ADDITIONAL] += 1
300         self.output.seek(10)
301         self.output.write(struct.pack('!H', self.counts[ADDITIONAL]))
302         self.output.seek(0, 2)
303
304     def write_header(self):
305         """Write the DNS message header.
306
307         Writing the DNS message header is done asfter all sections
308         have been rendered, but before the optional TSIG signature
309         is added.
310         """
311
312         self.output.seek(0)
313         self.output.write(struct.pack('!HHHHHH', self.id, self.flags,
314                                       self.counts[0], self.counts[1],
315                                       self.counts[2], self.counts[3]))
316         self.output.seek(0, 2)
317
318     def get_wire(self):
319         """Return the wire format message.
320
321         @rtype: string
322         """
323
324         return self.output.getvalue()