s4 dns: Fix TCP handling in the DNS server
[samba.git] / source4 / scripting / python / samba / tests / dns.py
index ca9edbf50009c7e5dca4b9eb60774985d15617ca..26f80898225d80d683c931c489750a98534b7797 100644 (file)
@@ -99,6 +99,22 @@ class DNSTest(TestCase):
             if s is not None:
                 s.close()
 
+    def dns_transaction_tcp(self, packet, host=os.getenv('DC_SERVER_IP')):
+        "send a DNS query and read the reply"
+        s = None
+        try:
+            send_packet = ndr.ndr_pack(packet)
+            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+            s.connect((host, 53))
+            tcp_packet = struct.pack('!H', len(send_packet))
+            tcp_packet += send_packet
+            s.send(tcp_packet, 0)
+            recv_packet = s.recv(0xffff + 2, 0)
+            return ndr.ndr_unpack(dns.name_packet, recv_packet[2:])
+        finally:
+                if s is not None:
+                    s.close()
+
     def test_one_a_query(self):
         "create a query packet containing one query record"
         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
@@ -117,6 +133,24 @@ class DNSTest(TestCase):
         self.assertEquals(response.answers[0].rdata,
                           os.getenv('DC_SERVER_IP'))
 
+    def test_one_a_query_tcp(self):
+        "create a query packet containing one query record via TCP"
+        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
+        questions = []
+
+        name = "%s.%s" % (os.getenv('DC_SERVER'), self.get_dns_domain())
+        q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
+        print "asking for ", q.name
+        questions.append(q)
+
+        self.finish_name_packet(p, questions)
+        response = self.dns_transaction_tcp(p)
+        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
+        self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
+        self.assertEquals(response.ancount, 1)
+        self.assertEquals(response.answers[0].rdata,
+                          os.getenv('DC_SERVER_IP'))
+
     def test_two_queries(self):
         "create a query packet containing two query records"
         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)