Move dnspython to third_party.
[bbaumbach/samba-autobuild/.git] / third_party / dnspython / dns / name.py
1 # Copyright (C) 2001-2007, 2009-2011 Nominum, Inc.
2 #
3 # Permission to use, copy, modify, and distribute this software and its
4 # documentation for any purpose with or without fee is hereby granted,
5 # provided that the above copyright notice and this permission notice
6 # appear in all copies.
7 #
8 # THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
9 # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10 # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
11 # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12 # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13 # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
14 # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15
16 """DNS Names.
17
18 @var root: The DNS root name.
19 @type root: dns.name.Name object
20 @var empty: The empty DNS name.
21 @type empty: dns.name.Name object
22 """
23
24 import cStringIO
25 import struct
26 import sys
27
28 if sys.hexversion >= 0x02030000:
29     import encodings.idna
30
31 import dns.exception
32 import dns.wiredata
33
34 NAMERELN_NONE = 0
35 NAMERELN_SUPERDOMAIN = 1
36 NAMERELN_SUBDOMAIN = 2
37 NAMERELN_EQUAL = 3
38 NAMERELN_COMMONANCESTOR = 4
39
40 class EmptyLabel(dns.exception.SyntaxError):
41     """Raised if a label is empty."""
42     pass
43
44 class BadEscape(dns.exception.SyntaxError):
45     """Raised if an escaped code in a text format name is invalid."""
46     pass
47
48 class BadPointer(dns.exception.FormError):
49     """Raised if a compression pointer points forward instead of backward."""
50     pass
51
52 class BadLabelType(dns.exception.FormError):
53     """Raised if the label type of a wire format name is unknown."""
54     pass
55
56 class NeedAbsoluteNameOrOrigin(dns.exception.DNSException):
57     """Raised if an attempt is made to convert a non-absolute name to
58     wire when there is also a non-absolute (or missing) origin."""
59     pass
60
61 class NameTooLong(dns.exception.FormError):
62     """Raised if a name is > 255 octets long."""
63     pass
64
65 class LabelTooLong(dns.exception.SyntaxError):
66     """Raised if a label is > 63 octets long."""
67     pass
68
69 class AbsoluteConcatenation(dns.exception.DNSException):
70     """Raised if an attempt is made to append anything other than the
71     empty name to an absolute name."""
72     pass
73
74 class NoParent(dns.exception.DNSException):
75     """Raised if an attempt is made to get the parent of the root name
76     or the empty name."""
77     pass
78
79 _escaped = {
80     '"' : True,
81     '(' : True,
82     ')' : True,
83     '.' : True,
84     ';' : True,
85     '\\' : True,
86     '@' : True,
87     '$' : True
88     }
89
90 def _escapify(label):
91     """Escape the characters in label which need it.
92     @returns: the escaped string
93     @rtype: string"""
94     text = ''
95     for c in label:
96         if c in _escaped:
97             text += '\\' + c
98         elif ord(c) > 0x20 and ord(c) < 0x7F:
99             text += c
100         else:
101             text += '\\%03d' % ord(c)
102     return text
103
104 def _validate_labels(labels):
105     """Check for empty labels in the middle of a label sequence,
106     labels that are too long, and for too many labels.
107     @raises NameTooLong: the name as a whole is too long
108     @raises LabelTooLong: an individual label is too long
109     @raises EmptyLabel: a label is empty (i.e. the root label) and appears
110     in a position other than the end of the label sequence"""
111
112     l = len(labels)
113     total = 0
114     i = -1
115     j = 0
116     for label in labels:
117         ll = len(label)
118         total += ll + 1
119         if ll > 63:
120             raise LabelTooLong
121         if i < 0 and label == '':
122             i = j
123         j += 1
124     if total > 255:
125         raise NameTooLong
126     if i >= 0 and i != l - 1:
127         raise EmptyLabel
128
129 class Name(object):
130     """A DNS name.
131
132     The dns.name.Name class represents a DNS name as a tuple of labels.
133     Instances of the class are immutable.
134
135     @ivar labels: The tuple of labels in the name. Each label is a string of
136     up to 63 octets."""
137
138     __slots__ = ['labels']
139
140     def __init__(self, labels):
141         """Initialize a domain name from a list of labels.
142         @param labels: the labels
143         @type labels: any iterable whose values are strings
144         """
145
146         super(Name, self).__setattr__('labels', tuple(labels))
147         _validate_labels(self.labels)
148
149     def __setattr__(self, name, value):
150         raise TypeError("object doesn't support attribute assignment")
151
152     def is_absolute(self):
153         """Is the most significant label of this name the root label?
154         @rtype: bool
155         """
156
157         return len(self.labels) > 0 and self.labels[-1] == ''
158
159     def is_wild(self):
160         """Is this name wild?  (I.e. Is the least significant label '*'?)
161         @rtype: bool
162         """
163
164         return len(self.labels) > 0 and self.labels[0] == '*'
165
166     def __hash__(self):
167         """Return a case-insensitive hash of the name.
168         @rtype: int
169         """
170
171         h = 0L
172         for label in self.labels:
173             for c in label:
174                 h += ( h << 3 ) + ord(c.lower())
175         return int(h % sys.maxint)
176
177     def fullcompare(self, other):
178         """Compare two names, returning a 3-tuple (relation, order, nlabels).
179
180         I{relation} describes the relation ship beween the names,
181         and is one of: dns.name.NAMERELN_NONE,
182         dns.name.NAMERELN_SUPERDOMAIN, dns.name.NAMERELN_SUBDOMAIN,
183         dns.name.NAMERELN_EQUAL, or dns.name.NAMERELN_COMMONANCESTOR
184
185         I{order} is < 0 if self < other, > 0 if self > other, and ==
186         0 if self == other.  A relative name is always less than an
187         absolute name.  If both names have the same relativity, then
188         the DNSSEC order relation is used to order them.
189
190         I{nlabels} is the number of significant labels that the two names
191         have in common.
192         """
193
194         sabs = self.is_absolute()
195         oabs = other.is_absolute()
196         if sabs != oabs:
197             if sabs:
198                 return (NAMERELN_NONE, 1, 0)
199             else:
200                 return (NAMERELN_NONE, -1, 0)
201         l1 = len(self.labels)
202         l2 = len(other.labels)
203         ldiff = l1 - l2
204         if ldiff < 0:
205             l = l1
206         else:
207             l = l2
208
209         order = 0
210         nlabels = 0
211         namereln = NAMERELN_NONE
212         while l > 0:
213             l -= 1
214             l1 -= 1
215             l2 -= 1
216             label1 = self.labels[l1].lower()
217             label2 = other.labels[l2].lower()
218             if label1 < label2:
219                 order = -1
220                 if nlabels > 0:
221                     namereln = NAMERELN_COMMONANCESTOR
222                 return (namereln, order, nlabels)
223             elif label1 > label2:
224                 order = 1
225                 if nlabels > 0:
226                     namereln = NAMERELN_COMMONANCESTOR
227                 return (namereln, order, nlabels)
228             nlabels += 1
229         order = ldiff
230         if ldiff < 0:
231             namereln = NAMERELN_SUPERDOMAIN
232         elif ldiff > 0:
233             namereln = NAMERELN_SUBDOMAIN
234         else:
235             namereln = NAMERELN_EQUAL
236         return (namereln, order, nlabels)
237
238     def is_subdomain(self, other):
239         """Is self a subdomain of other?
240
241         The notion of subdomain includes equality.
242         @rtype: bool
243         """
244
245         (nr, o, nl) = self.fullcompare(other)
246         if nr == NAMERELN_SUBDOMAIN or nr == NAMERELN_EQUAL:
247             return True
248         return False
249
250     def is_superdomain(self, other):
251         """Is self a superdomain of other?
252
253         The notion of subdomain includes equality.
254         @rtype: bool
255         """
256
257         (nr, o, nl) = self.fullcompare(other)
258         if nr == NAMERELN_SUPERDOMAIN or nr == NAMERELN_EQUAL:
259             return True
260         return False
261
262     def canonicalize(self):
263         """Return a name which is equal to the current name, but is in
264         DNSSEC canonical form.
265         @rtype: dns.name.Name object
266         """
267
268         return Name([x.lower() for x in self.labels])
269
270     def __eq__(self, other):
271         if isinstance(other, Name):
272             return self.fullcompare(other)[1] == 0
273         else:
274             return False
275
276     def __ne__(self, other):
277         if isinstance(other, Name):
278             return self.fullcompare(other)[1] != 0
279         else:
280             return True
281
282     def __lt__(self, other):
283         if isinstance(other, Name):
284             return self.fullcompare(other)[1] < 0
285         else:
286             return NotImplemented
287
288     def __le__(self, other):
289         if isinstance(other, Name):
290             return self.fullcompare(other)[1] <= 0
291         else:
292             return NotImplemented
293
294     def __ge__(self, other):
295         if isinstance(other, Name):
296             return self.fullcompare(other)[1] >= 0
297         else:
298             return NotImplemented
299
300     def __gt__(self, other):
301         if isinstance(other, Name):
302             return self.fullcompare(other)[1] > 0
303         else:
304             return NotImplemented
305
306     def __repr__(self):
307         return '<DNS name ' + self.__str__() + '>'
308
309     def __str__(self):
310         return self.to_text(False)
311
312     def to_text(self, omit_final_dot = False):
313         """Convert name to text format.
314         @param omit_final_dot: If True, don't emit the final dot (denoting the
315         root label) for absolute names.  The default is False.
316         @rtype: string
317         """
318
319         if len(self.labels) == 0:
320             return '@'
321         if len(self.labels) == 1 and self.labels[0] == '':
322             return '.'
323         if omit_final_dot and self.is_absolute():
324             l = self.labels[:-1]
325         else:
326             l = self.labels
327         s = '.'.join(map(_escapify, l))
328         return s
329
330     def to_unicode(self, omit_final_dot = False):
331         """Convert name to Unicode text format.
332
333         IDN ACE lables are converted to Unicode.
334
335         @param omit_final_dot: If True, don't emit the final dot (denoting the
336         root label) for absolute names.  The default is False.
337         @rtype: string
338         """
339
340         if len(self.labels) == 0:
341             return u'@'
342         if len(self.labels) == 1 and self.labels[0] == '':
343             return u'.'
344         if omit_final_dot and self.is_absolute():
345             l = self.labels[:-1]
346         else:
347             l = self.labels
348         s = u'.'.join([encodings.idna.ToUnicode(_escapify(x)) for x in l])
349         return s
350
351     def to_digestable(self, origin=None):
352         """Convert name to a format suitable for digesting in hashes.
353
354         The name is canonicalized and converted to uncompressed wire format.
355
356         @param origin: If the name is relative and origin is not None, then
357         origin will be appended to it.
358         @type origin: dns.name.Name object
359         @raises NeedAbsoluteNameOrOrigin: All names in wire format are
360         absolute.  If self is a relative name, then an origin must be supplied;
361         if it is missing, then this exception is raised
362         @rtype: string
363         """
364
365         if not self.is_absolute():
366             if origin is None or not origin.is_absolute():
367                 raise NeedAbsoluteNameOrOrigin
368             labels = list(self.labels)
369             labels.extend(list(origin.labels))
370         else:
371             labels = self.labels
372         dlabels = ["%s%s" % (chr(len(x)), x.lower()) for x in labels]
373         return ''.join(dlabels)
374
375     def to_wire(self, file = None, compress = None, origin = None):
376         """Convert name to wire format, possibly compressing it.
377
378         @param file: the file where the name is emitted (typically
379         a cStringIO file).  If None, a string containing the wire name
380         will be returned.
381         @type file: file or None
382         @param compress: The compression table.  If None (the default) names
383         will not be compressed.
384         @type compress: dict
385         @param origin: If the name is relative and origin is not None, then
386         origin will be appended to it.
387         @type origin: dns.name.Name object
388         @raises NeedAbsoluteNameOrOrigin: All names in wire format are
389         absolute.  If self is a relative name, then an origin must be supplied;
390         if it is missing, then this exception is raised
391         """
392
393         if file is None:
394             file = cStringIO.StringIO()
395             want_return = True
396         else:
397             want_return = False
398
399         if not self.is_absolute():
400             if origin is None or not origin.is_absolute():
401                 raise NeedAbsoluteNameOrOrigin
402             labels = list(self.labels)
403             labels.extend(list(origin.labels))
404         else:
405             labels = self.labels
406         i = 0
407         for label in labels:
408             n = Name(labels[i:])
409             i += 1
410             if not compress is None:
411                 pos = compress.get(n)
412             else:
413                 pos = None
414             if not pos is None:
415                 value = 0xc000 + pos
416                 s = struct.pack('!H', value)
417                 file.write(s)
418                 break
419             else:
420                 if not compress is None and len(n) > 1:
421                     pos = file.tell()
422                     if pos < 0xc000:
423                         compress[n] = pos
424                 l = len(label)
425                 file.write(chr(l))
426                 if l > 0:
427                     file.write(label)
428         if want_return:
429             return file.getvalue()
430
431     def __len__(self):
432         """The length of the name (in labels).
433         @rtype: int
434         """
435
436         return len(self.labels)
437
438     def __getitem__(self, index):
439         return self.labels[index]
440
441     def __getslice__(self, start, stop):
442         return self.labels[start:stop]
443
444     def __add__(self, other):
445         return self.concatenate(other)
446
447     def __sub__(self, other):
448         return self.relativize(other)
449
450     def split(self, depth):
451         """Split a name into a prefix and suffix at depth.
452
453         @param depth: the number of labels in the suffix
454         @type depth: int
455         @raises ValueError: the depth was not >= 0 and <= the length of the
456         name.
457         @returns: the tuple (prefix, suffix)
458         @rtype: tuple
459         """
460
461         l = len(self.labels)
462         if depth == 0:
463             return (self, dns.name.empty)
464         elif depth == l:
465             return (dns.name.empty, self)
466         elif depth < 0 or depth > l:
467             raise ValueError('depth must be >= 0 and <= the length of the name')
468         return (Name(self[: -depth]), Name(self[-depth :]))
469
470     def concatenate(self, other):
471         """Return a new name which is the concatenation of self and other.
472         @rtype: dns.name.Name object
473         @raises AbsoluteConcatenation: self is absolute and other is
474         not the empty name
475         """
476
477         if self.is_absolute() and len(other) > 0:
478             raise AbsoluteConcatenation
479         labels = list(self.labels)
480         labels.extend(list(other.labels))
481         return Name(labels)
482
483     def relativize(self, origin):
484         """If self is a subdomain of origin, return a new name which is self
485         relative to origin.  Otherwise return self.
486         @rtype: dns.name.Name object
487         """
488
489         if not origin is None and self.is_subdomain(origin):
490             return Name(self[: -len(origin)])
491         else:
492             return self
493
494     def derelativize(self, origin):
495         """If self is a relative name, return a new name which is the
496         concatenation of self and origin.  Otherwise return self.
497         @rtype: dns.name.Name object
498         """
499
500         if not self.is_absolute():
501             return self.concatenate(origin)
502         else:
503             return self
504
505     def choose_relativity(self, origin=None, relativize=True):
506         """Return a name with the relativity desired by the caller.  If
507         origin is None, then self is returned.  Otherwise, if
508         relativize is true the name is relativized, and if relativize is
509         false the name is derelativized.
510         @rtype: dns.name.Name object
511         """
512
513         if origin:
514             if relativize:
515                 return self.relativize(origin)
516             else:
517                 return self.derelativize(origin)
518         else:
519             return self
520
521     def parent(self):
522         """Return the parent of the name.
523         @rtype: dns.name.Name object
524         @raises NoParent: the name is either the root name or the empty name,
525         and thus has no parent.
526         """
527         if self == root or self == empty:
528             raise NoParent
529         return Name(self.labels[1:])
530
531 root = Name([''])
532 empty = Name([])
533
534 def from_unicode(text, origin = root):
535     """Convert unicode text into a Name object.
536
537     Lables are encoded in IDN ACE form.
538
539     @rtype: dns.name.Name object
540     """
541
542     if not isinstance(text, unicode):
543         raise ValueError("input to from_unicode() must be a unicode string")
544     if not (origin is None or isinstance(origin, Name)):
545         raise ValueError("origin must be a Name or None")
546     labels = []
547     label = u''
548     escaping = False
549     edigits = 0
550     total = 0
551     if text == u'@':
552         text = u''
553     if text:
554         if text == u'.':
555             return Name([''])   # no Unicode "u" on this constant!
556         for c in text:
557             if escaping:
558                 if edigits == 0:
559                     if c.isdigit():
560                         total = int(c)
561                         edigits += 1
562                     else:
563                         label += c
564                         escaping = False
565                 else:
566                     if not c.isdigit():
567                         raise BadEscape
568                     total *= 10
569                     total += int(c)
570                     edigits += 1
571                     if edigits == 3:
572                         escaping = False
573                         label += chr(total)
574             elif c == u'.' or c == u'\u3002' or \
575                  c == u'\uff0e' or c == u'\uff61':
576                 if len(label) == 0:
577                     raise EmptyLabel
578                 labels.append(encodings.idna.ToASCII(label))
579                 label = u''
580             elif c == u'\\':
581                 escaping = True
582                 edigits = 0
583                 total = 0
584             else:
585                 label += c
586         if escaping:
587             raise BadEscape
588         if len(label) > 0:
589             labels.append(encodings.idna.ToASCII(label))
590         else:
591             labels.append('')
592     if (len(labels) == 0 or labels[-1] != '') and not origin is None:
593         labels.extend(list(origin.labels))
594     return Name(labels)
595
596 def from_text(text, origin = root):
597     """Convert text into a Name object.
598     @rtype: dns.name.Name object
599     """
600
601     if not isinstance(text, str):
602         if isinstance(text, unicode) and sys.hexversion >= 0x02030000:
603             return from_unicode(text, origin)
604         else:
605             raise ValueError("input to from_text() must be a string")
606     if not (origin is None or isinstance(origin, Name)):
607         raise ValueError("origin must be a Name or None")
608     labels = []
609     label = ''
610     escaping = False
611     edigits = 0
612     total = 0
613     if text == '@':
614         text = ''
615     if text:
616         if text == '.':
617             return Name([''])
618         for c in text:
619             if escaping:
620                 if edigits == 0:
621                     if c.isdigit():
622                         total = int(c)
623                         edigits += 1
624                     else:
625                         label += c
626                         escaping = False
627                 else:
628                     if not c.isdigit():
629                         raise BadEscape
630                     total *= 10
631                     total += int(c)
632                     edigits += 1
633                     if edigits == 3:
634                         escaping = False
635                         label += chr(total)
636             elif c == '.':
637                 if len(label) == 0:
638                     raise EmptyLabel
639                 labels.append(label)
640                 label = ''
641             elif c == '\\':
642                 escaping = True
643                 edigits = 0
644                 total = 0
645             else:
646                 label += c
647         if escaping:
648             raise BadEscape
649         if len(label) > 0:
650             labels.append(label)
651         else:
652             labels.append('')
653     if (len(labels) == 0 or labels[-1] != '') and not origin is None:
654         labels.extend(list(origin.labels))
655     return Name(labels)
656
657 def from_wire(message, current):
658     """Convert possibly compressed wire format into a Name.
659     @param message: the entire DNS message
660     @type message: string
661     @param current: the offset of the beginning of the name from the start
662     of the message
663     @type current: int
664     @raises dns.name.BadPointer: a compression pointer did not point backwards
665     in the message
666     @raises dns.name.BadLabelType: an invalid label type was encountered.
667     @returns: a tuple consisting of the name that was read and the number
668     of bytes of the wire format message which were consumed reading it
669     @rtype: (dns.name.Name object, int) tuple
670     """
671
672     if not isinstance(message, str):
673         raise ValueError("input to from_wire() must be a byte string")
674     message = dns.wiredata.maybe_wrap(message)
675     labels = []
676     biggest_pointer = current
677     hops = 0
678     count = ord(message[current])
679     current += 1
680     cused = 1
681     while count != 0:
682         if count < 64:
683             labels.append(message[current : current + count].unwrap())
684             current += count
685             if hops == 0:
686                 cused += count
687         elif count >= 192:
688             current = (count & 0x3f) * 256 + ord(message[current])
689             if hops == 0:
690                 cused += 1
691             if current >= biggest_pointer:
692                 raise BadPointer
693             biggest_pointer = current
694             hops += 1
695         else:
696             raise BadLabelType
697         count = ord(message[current])
698         current += 1
699         if hops == 0:
700             cused += 1
701     labels.append('')
702     return (Name(labels), cused)