1f9fab1d72a61da942515480e40a90d932ac00ed
[nivanova/samba-autobuild/.git] / python / samba / tests / join.py
1 # Test joining as a DC and check the join was done right
2 #
3 # Copyright (C) Andrew Bartlett <abartlet@samba.org> 2017
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18
19 import samba
20 import sys
21 import shutil
22 import os
23 from samba.tests.dns_base import DNSTKeyTest
24 from samba.join import dc_join
25 from samba.dcerpc import drsuapi, misc, dns
26 from samba.credentials import Credentials
27
def get_logger(name="subunit"):
    """Get a logger object that writes to stderr.

    The logger is fetched with logging.getLogger(), so repeated calls with
    the same *name* return the same object.  A StreamHandler on sys.stderr
    is attached only when the logger does not already have one.  Without
    that check, every call (e.g. one per test setUp()) would add another
    handler and each log line would be emitted once per handler.

    :param name: name of the logger to fetch (default "subunit").
    :return: the logging.Logger for *name*.
    """
    import logging
    logger = logging.getLogger(name)
    # FileHandler subclasses StreamHandler, so also compare the stream itself.
    has_stderr_handler = any(
        isinstance(handler, logging.StreamHandler)
        and getattr(handler, "stream", None) is sys.stderr
        for handler in logger.handlers)
    if not has_stderr_handler:
        logger.addHandler(logging.StreamHandler(sys.stderr))
    return logger
34
class JoinTestCase(DNSTKeyTest):
    """Join a temporary DC against $SERVER and check its DNS records.

    setUp() performs a DC join of "jointest1" with targetdir=self.tempdir.
    tearDown() removes the local files the join produced and calls
    cleanup_old_join(force=True) to remove the joined account again.
    """

    def setUp(self):
        """Join the test DC "jointest1" to the domain served by $SERVER."""
        self.server = samba.tests.env_get_var_value("SERVER")
        self.server_ip = samba.tests.env_get_var_value("SERVER_IP")
        super(JoinTestCase, self).setUp()
        self.lp = samba.tests.env_loadparm()
        self.creds = self.get_credentials()
        self.netbios_name = "jointest1"
        logger = get_logger()

        self.join_ctx = dc_join(server=self.server, creds=self.creds, lp=self.get_loadparm(),
                                netbios_name=self.netbios_name,
                                targetdir=self.tempdir,
                                domain=None, logger=logger,
                                dns_backend="SAMBA_INTERNAL")
        # Server trust account, trusted for delegation.
        self.join_ctx.userAccountControl = (samba.dsdb.UF_SERVER_TRUST_ACCOUNT |
                                            samba.dsdb.UF_TRUSTED_FOR_DELEGATION)

        # Writable replica, replicated with a BDC secure channel.
        self.join_ctx.replica_flags |= (drsuapi.DRSUAPI_DRS_WRIT_REP |
                                        drsuapi.DRSUAPI_DRS_FULL_SYNC_IN_PROGRESS)
        self.join_ctx.domain_replica_flags = self.join_ctx.replica_flags
        self.join_ctx.secure_channel_type = misc.SEC_CHAN_BDC

        # Clear out anything left behind by an earlier run before joining.
        self.join_ctx.cleanup_old_join()

        # NOTE(review): presumably makes the join register every local IP.
        # The tests below expect one A record per samba.interface_ips() entry.
        self.join_ctx.force_all_ips = True

        self.join_ctx.do_join()

    def tearDown(self):
        """Remove the joined DC's local files and its server-side account.

        cleanup_old_join(force=True) and the base-class tearDown() run even
        if removing one of the local files fails.  Otherwise a local error
        would leave the joined account behind on $SERVER.
        """
        try:
            try:
                paths = self.join_ctx.paths
            except AttributeError:
                paths = None

            if paths is not None:
                # Files and directories the join created under
                # self.tempdir (its targetdir).
                shutil.rmtree(paths.private_dir)
                shutil.rmtree(paths.state_dir)
                shutil.rmtree(os.path.join(self.tempdir, "etc"))
                shutil.rmtree(os.path.join(self.tempdir, "msg.lock"))
                os.unlink(os.path.join(self.tempdir, "names.tdb"))
        finally:
            try:
                self.join_ctx.cleanup_old_join(force=True)
            finally:
                super(JoinTestCase, self).tearDown()

    def _query_a_records(self, name):
        """Send an A/IN query for *name* to $SERVER_IP over TCP.

        Asserts that the reply has rcode OK and opcode QUERY.

        :param name: DNS name to query.
        :return: the parsed DNS response packet.
        """
        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
        q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [q])
        (response, response_packet) = self.dns_transaction_tcp(p, host=self.server_ip)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
        self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
        return response

    def test_join_makes_records(self):
        """Check the join registered the new DC's A and _msdcs records."""
        # Expected IPs of the joined DC.
        IPs = samba.interface_ips(self.lp)

        # <dnshostname> A: one answer per expected IP.
        response = self._query_a_records(self.join_ctx.dnshostname)
        self.assertEqual(response.ancount, len(IPs))

        # <ntds_guid>._msdcs.<forest> A: a CNAME to dnshostname followed
        # by the A records it resolves to.
        name = "%s._msdcs.%s" % (self.join_ctx.ntds_guid, self.join_ctx.dnsforest)
        response = self._query_a_records(name)
        self.assertEqual(response.ancount, 1 + len(IPs))
        self.assertEqual(response.answers[0].rr_type, dns.DNS_QTYPE_CNAME)
        self.assertEqual(response.answers[0].rdata, self.join_ctx.dnshostname)
        self.assertEqual(response.answers[1].rr_type, dns.DNS_QTYPE_A)

    def test_join_records_can_update(self):
        """Check the joined DC's machine account can update its own records.

        After a TKEY exchange using the new DC's machine-account
        credentials, send a signed UPDATE that deletes every expected
        address except the first one.  Then check that the A query for
        dnshostname returns a single answer.
        """
        dc_creds = Credentials()
        dc_creds.guess(self.join_ctx.lp)
        dc_creds.set_machine_account(self.join_ctx.lp)

        self.tkey_trans(creds=dc_creds)

        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
        q = self.make_name_question(self.join_ctx.dnsdomain,
                                    dns.DNS_QTYPE_SOA,
                                    dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [q])

        # Delete the old expected IPs (all but the first).  In an UPDATE,
        # class NONE with TTL 0 deletes that specific RR (RFC 2136 2.5.4).
        updates = []
        IPs = samba.interface_ips(self.lp)
        for IP in IPs[1:]:
            r = dns.res_rec()
            r.name = self.join_ctx.dnshostname
            r.rr_type = dns.DNS_QTYPE_AAAA if ":" in IP else dns.DNS_QTYPE_A
            r.rr_class = dns.DNS_QCLASS_NONE
            r.ttl = 0
            r.length = 0xffff
            r.rdata = IP
            updates.append(r)

        p.nscount = len(updates)
        p.nsrecs = updates

        mac = self.sign_packet(p, self.key_name)
        (response, response_p) = self.dns_transaction_udp(p, self.server_ip)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
        self.verify_packet(response, response_p, mac)

        response = self._query_a_records(self.join_ctx.dnshostname)
        self.assertEqual(response.ancount, 1)