CVE-2018-14629: Tests to expose regression from dns cname loop fix
[samba.git] / python / samba / tests / dns.py
index e35cded7b8c0ce4c777d7ae795164dbfcfc09c07..1bf57e2aad046ed7deaf3ed8dde7479bd1e2d683 100644 (file)
@@ -869,6 +869,107 @@ class TestComplexQueries(DNSTest):
         max_recursion_depth = 20
         self.assertEquals(len(response.answers), max_recursion_depth)
 
+    # Make sure cname limit doesn't count other records.  This is a generic
+    # test called in tests below
+    def max_rec_test(self, rtype, rec_gen):
+        name = "limittestrec{0}.{1}".format(rtype, self.get_dns_domain())
+        limit = 20
+        num_recs_to_enter = limit + 5
+
+        for i in range(1, num_recs_to_enter+1):
+            ip = rec_gen(i)
+            self.make_dns_update(name, ip, rtype)
+
+        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
+        questions = []
+
+        q = self.make_name_question(name,
+                                    rtype,
+                                    dns.DNS_QCLASS_IN)
+        questions.append(q)
+        self.finish_name_packet(p, questions)
+
+        (response, response_packet) =\
+            self.dns_transaction_udp(p, host=self.server_ip)
+
+        self.assertEqual(len(response.answers), num_recs_to_enter)
+
+    def test_record_limit_A(self):
+        def ip4_gen(i):
+            return "127.0.0." + str(i)
+        self.max_rec_test(rtype=dns.DNS_QTYPE_A, rec_gen=ip4_gen)
+
+    def test_record_limit_AAAA(self):
+        def ip6_gen(i):
+            return "AAAA:0:0:0:0:0:0:" + str(i)
+        self.max_rec_test(rtype=dns.DNS_QTYPE_AAAA, rec_gen=ip6_gen)
+
+    def test_record_limit_SRV(self):
+        def srv_gen(i):
+            rec = dns.srv_record()
+            rec.priority = 1
+            rec.weight = 1
+            rec.port = 92
+            rec.target = "srvtestrec" + str(i)
+            return rec
+        self.max_rec_test(rtype=dns.DNS_QTYPE_SRV, rec_gen=srv_gen)
+
+    # Same as test_record_limit_A but with a preceding CNAME follow
+    def test_cname_limit(self):
+        cname1 = "cnamelimittestrec." + self.get_dns_domain()
+        cname2 = "cnamelimittestrec2." + self.get_dns_domain()
+        cname3 = "cnamelimittestrec3." + self.get_dns_domain()
+        ip_prefix = '127.0.0.'
+        limit = 20
+        num_recs_to_enter = limit + 5
+
+        self.make_dns_update(cname1, cname2, dnsp.DNS_TYPE_CNAME)
+        self.make_dns_update(cname2, cname3, dnsp.DNS_TYPE_CNAME)
+        num_arecs_to_enter = num_recs_to_enter - 2
+        for i in range(1, num_arecs_to_enter+1):
+            ip = ip_prefix + str(i)
+            self.make_dns_update(cname3, ip, dns.DNS_QTYPE_A)
+
+        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
+        questions = []
+
+        q = self.make_name_question(cname1,
+                                    dns.DNS_QTYPE_A,
+                                    dns.DNS_QCLASS_IN)
+        questions.append(q)
+        self.finish_name_packet(p, questions)
+
+        (response, response_packet) =\
+            self.dns_transaction_udp(p, host=self.server_ip)
+
+        self.assertEqual(len(response.answers), num_recs_to_enter)
+
+    # ANY query on cname record shouldn't follow the link
+    def test_cname_any_query(self):
+        cname1 = "cnameanytestrec." + self.get_dns_domain()
+        cname2 = "cnameanytestrec2." + self.get_dns_domain()
+        cname3 = "cnameanytestrec3." + self.get_dns_domain()
+
+        self.make_dns_update(cname1, cname2, dnsp.DNS_TYPE_CNAME)
+        self.make_dns_update(cname2, cname3, dnsp.DNS_TYPE_CNAME)
+
+        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
+        questions = []
+
+        q = self.make_name_question(cname1,
+                                    dns.DNS_QTYPE_ALL,
+                                    dns.DNS_QCLASS_IN)
+        questions.append(q)
+        self.finish_name_packet(p, questions)
+
+        (response, response_packet) =\
+            self.dns_transaction_udp(p, host=self.server_ip)
+
+        self.assertEqual(len(response.answers), 1)
+        self.assertEqual(response.answers[0].name, cname1)
+        self.assertEqual(response.answers[0].rdata, cname2)
+
+
 class TestInvalidQueries(DNSTest):
     def setUp(self):
         super(TestInvalidQueries, self).setUp()