ctdb-scripts: Do not de-duplicate the interfaces list
[samba.git] / python / samba / tests / dns_packet.py
1 # Tests of malformed DNS packets
2 # Copyright (C) Catalyst.NET ltd
3 #
4 # written by Douglas Bagnall <douglas.bagnall@catalyst.net.nz>
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18
19 """Sanity tests for DNS and NBT server parsing.
20
21 We don't use a proper client library so we can make improper packets.
22 """
23
24 import os
25 import struct
26 import socket
27 import select
28 from samba.dcerpc import dns, nbt
29
30 from samba.tests import TestCase
31
32
33 def _msg_id():
34     while True:
35         for i in range(1, 0xffff):
36             yield i
37
38
39 SERVER = os.environ['SERVER_IP']
40 SERVER_NAME = f"{os.environ['SERVER']}.{os.environ['REALM']}"
41 TIMEOUT = 0.5
42
43
44 def encode_netbios_bytes(chars):
45     """Even RFC 1002 uses distancing quotes when calling this "compression"."""
46     out = []
47     chars = (chars + b'                   ')[:16]
48     for c in chars:
49         out.append((c >> 4) + 65)
50         out.append((c & 15) + 65)
51     return bytes(out)
52
53
54 class TestDnsPacketBase(TestCase):
55     msg_id = _msg_id()
56
57     def tearDown(self):
58         # we need to ensure the DNS server is responsive before
59         # continuing.
60         for i in range(40):
61             ok = self._known_good_query()
62             if ok:
63                 return
64         print(f"the server is STILL unresponsive after {40 * TIMEOUT} seconds")
65
66     def decode_reply(self, data):
67         header = data[:12]
68         id, flags, n_q, n_a, n_rec, n_exta = struct.unpack('!6H',
69                                                            header)
70         return {
71             'rcode': flags & 0xf
72         }
73
74     def construct_query(self, names):
75         """Create a query packet containing one query record.
76
77         *names* is either a single string name in the usual dotted
78         form, or a list of names. In the latter case, each name can
79         be a dotted string or a list of byte components, which allows
80         dots in components. Where I say list, I mean non-string
81         iterable.
82
83         Examples:
84
85         # these 3 are all the same
86         "example.com"
87         ["example.com"]
88         [[b"example", b"com"]]
89
90         # this is three names in the same request
91         ["example.com",
92          [b"example", b"com", b"..!"],
93          (b"first component", b" 2nd component")]
94         """
95         header = struct.pack('!6H',
96                              next(self.msg_id),
97                              0x0100,       # query, with recursion
98                              len(names),   # number of queries
99                              0x0000,       # no answers
100                              0x0000,       # no records
101                              0x0000,       # no extra records
102         )
103         tail = struct.pack('!BHH',
104                            0x00,         # root node
105                            self.qtype,
106                            0x0001,       # class IN-ternet
107         )
108         encoded_bits = []
109         for name in names:
110             if isinstance(name, str):
111                 bits = name.encode('utf8').split(b'.')
112             else:
113                 bits = name
114
115             for b in bits:
116                 encoded_bits.append(b'%c%s' % (len(b), b))
117             encoded_bits.append(tail)
118
119         return header + b''.join(encoded_bits)
120
121     def _test_query(self, names=(), expected_rcode=None):
122
123         if isinstance(names, str):
124             names = [names]
125
126         packet = self.construct_query(names)
127         s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
128         s.sendto(packet, self.server)
129         r, _, _ = select.select([s], [], [], TIMEOUT)
130         s.close()
131         # It is reasonable to not reply to these packets (Windows
132         # doesn't), but it is not reasonable to render the server
133         # unresponsive.
134         if r != [s]:
135             ok = self._known_good_query()
136             self.assertTrue(ok, f"the server is unresponsive")
137
138     def _known_good_query(self):
139         if self.server[1] == 53:
140             name = SERVER_NAME
141             expected_rcode = dns.DNS_RCODE_OK
142         else:
143             name = [encode_netbios_bytes(b'nxdomain'), b'nxdomain']
144             expected_rcode = nbt.NBT_RCODE_NAM
145
146         packet = self.construct_query([name])
147         s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
148         s.sendto(packet, self.server)
149         r, _, _ = select.select([s], [], [], TIMEOUT)
150         if r != [s]:
151             s.close()
152             return False
153
154         data, addr = s.recvfrom(4096)
155         s.close()
156         rcode = self.decode_reply(data)['rcode']
157         return expected_rcode == rcode
158
159     def _test_empty_packet(self):
160
161         packet = b""
162         s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
163         s.sendto(packet, self.server)
164         s.close()
165
166         # It is reasonable not to reply to an empty packet
167         # but it is not reasonable to render the server
168         # unresponsive.
169         ok = self._known_good_query()
170         self.assertTrue(ok, f"the server is unresponsive")
171
172
173 class TestDnsPackets(TestDnsPacketBase):
174     server = (SERVER, 53)
175     qtype = 1     # dns type A
176
177     def _test_many_repeated_components(self, label, n, expected_rcode=None):
178         name = [label] * n
179         self._test_query([name],
180                          expected_rcode=expected_rcode)
181
182     def test_127_very_dotty_components(self):
183         label = b'.' * 63
184         self._test_many_repeated_components(label, 127)
185
186     def test_127_half_dotty_components(self):
187         label = b'x.' * 31 + b'x'
188         self._test_many_repeated_components(label, 127)
189
190     def test_empty_packet(self):
191         self._test_empty_packet()
192
193
194 class TestNbtPackets(TestDnsPacketBase):
195     server = (SERVER, 137)
196     qtype = 0x20  # NBT_QTYPE_NETBIOS
197
198     def _test_nbt_encode_query(self, names, *args, **kwargs):
199         if isinstance(names, str):
200             names = [names]
201
202         nbt_names = []
203         for name in names:
204             if isinstance(name, str):
205                 bits = name.encode('utf8').split(b'.')
206             else:
207                 bits = name
208
209             encoded = [encode_netbios_bytes(bits[0])]
210             encoded.extend(bits[1:])
211             nbt_names.append(encoded)
212
213         self._test_query(nbt_names, *args, **kwargs)
214
215     def _test_many_repeated_components(self, label, n, expected_rcode=None):
216         name = [label] * n
217         name[0] = encode_netbios_bytes(label)
218         self._test_query([name],
219                          expected_rcode=expected_rcode)
220
221     def test_127_very_dotty_components(self):
222         label = b'.' * 63
223         self._test_many_repeated_components(label, 127)
224
225     def test_127_half_dotty_components(self):
226         label = b'x.' * 31 + b'x'
227         self._test_many_repeated_components(label, 127)
228
229     def test_empty_packet(self):
230         self._test_empty_packet()