testprogs/blackbox: PY3 bulk change for python scripts use correct python
[amitay/samba.git] / python / samba / join.py
index 68b2bfc1c68e10e936e59b0e8615b2c32ca64ccc..343b1a50934f0e44848e61a6c82f0119133864a8 100644 (file)
@@ -22,12 +22,15 @@ from __future__ import print_function
 from samba.auth import system_session
 from samba.samdb import SamDB
 from samba import gensec, Ldb, drs_utils, arcfour_encrypt, string_to_byte_array
-import ldb, samba, sys, uuid
+import ldb
+import samba
+import uuid
 from samba.ndr import ndr_pack, ndr_unpack
 from samba.dcerpc import security, drsuapi, misc, nbt, lsa, drsblobs, dnsserver, dnsp
 from samba.dsdb import DS_DOMAIN_FUNCTION_2003
 from samba.credentials import Credentials, DONT_USE_KERBEROS
-from samba.provision import secretsdb_self_join, provision, provision_fill, FILL_DRS, FILL_SUBDOMAIN
+from samba.provision import (secretsdb_self_join, provision, provision_fill,
+                             FILL_DRS, FILL_SUBDOMAIN, DEFAULTSITE)
 from samba.provision.common import setup_path
 from samba.schema import Schema
 from samba import descriptor
@@ -37,15 +40,17 @@ from samba import read_and_sub_file
 from samba import werror
 from base64 import b64encode
 from samba import WERRORError, NTSTATUSError
-from samba.dnsserver import ARecord, AAAARecord, PTRRecord, CNameRecord, NSRecord, MXRecord, SOARecord, SRVRecord, TXTRecord
 from samba import sd_utils
+from samba.dnsserver import ARecord, AAAARecord, CNameRecord
 import logging
-import talloc
 import random
 import time
 import re
 import os
 import tempfile
+from samba.compat import text_type
+from samba.compat import get_string
+
 
 class DCJoinException(Exception):
 
@@ -61,8 +66,6 @@ class DCJoinContext(object):
                  machinepass=None, use_ntvfs=False, dns_backend=None,
                  promote_existing=False, plaintext_secrets=False,
                  backend_store=None, forced_local_samdb=None):
-        if site is None:
-            site = "Default-First-Site-Name"
 
         ctx.logger = logger
         ctx.creds = creds
@@ -89,7 +92,13 @@ class DCJoinContext(object):
             ctx.samdb = forced_local_samdb
             ctx.server = ctx.samdb.url
         else:
-            if not ctx.server:
+            if ctx.server:
+                # work out the DC's site (if not already specified)
+                if site is None:
+                    ctx.site = ctx.find_dc_site(ctx.server)
+            else:
+                # work out the Primary DC for the domain (as well as an
+                # appropriate site for the new DC)
                 ctx.logger.info("Finding a writeable DC for domain '%s'" % domain)
                 ctx.server = ctx.find_dc(domain)
                 ctx.logger.info("Found DC %s" % ctx.server)
@@ -97,13 +106,15 @@ class DCJoinContext(object):
                               session_info=system_session(),
                               credentials=ctx.creds, lp=ctx.lp)
 
+        if ctx.site is None:
+            ctx.site = DEFAULTSITE
+
         try:
-            ctx.samdb.search(scope=ldb.SCOPE_ONELEVEL, attrs=["dn"])
-        except ldb.LdbError as e4:
-            (enum, estr) = e4.args
+            ctx.samdb.search(scope=ldb.SCOPE_BASE, attrs=[])
+        except ldb.LdbError as e:
+            (enum, estr) = e.args
             raise DCJoinException(estr)
 
-
         ctx.base_dn = str(ctx.samdb.get_default_basedn())
         ctx.root_dn = str(ctx.samdb.get_root_basedn())
         ctx.schema_dn = str(ctx.samdb.get_schema_basedn())
@@ -144,8 +155,8 @@ class DCJoinContext(object):
                 ctx.topology_dn = None
 
             ctx.SPNs = ["HOST/%s" % ctx.myname,
-                         "HOST/%s" % ctx.dnshostname,
-                         "GC/%s/%s" % (ctx.dnshostname, ctx.dnsforest)]
+                        "HOST/%s" % ctx.dnshostname,
+                        "GC/%s/%s" % (ctx.dnshostname, ctx.dnsforest)]
 
             res_rid_manager = ctx.samdb.search(scope=ldb.SCOPE_BASE,
                                                attrs=["rIDManagerReference"],
@@ -264,7 +275,6 @@ class DCJoinContext(object):
                                   (ldb.binary_encode("dns-%s" % ctx.myname),
                                    ldb.binary_encode("dns/%s" % ctx.dnshostname)))
 
-
     def cleanup_old_join(ctx, force=False):
         """Remove any DNs from a previous join."""
         # find the krbtgt link
@@ -311,8 +321,6 @@ class DCJoinContext(object):
         if ctx.dns_cname_dn:
             ctx.del_noerror(ctx.dns_cname_dn)
 
-
-
     def promote_possible(ctx):
         """confirm that the account is just a bare NT4 BDC or a member server, so can be safely promoted"""
         if ctx.subdomain:
@@ -326,12 +334,12 @@ class DCJoinContext(object):
             raise Exception("Could not find domain member account '%s' to promote to a DC, use 'samba-tool domain join' instead'" % ctx.samname)
         if "msDS-krbTgtLink" in res[0] or "serverReferenceBL" in res[0] or "rIDSetReferences" in res[0]:
             raise Exception("Account '%s' appears to be an active DC, use 'samba-tool domain join' if you must re-create this account" % ctx.samname)
-        if (int(res[0]["userAccountControl"][0]) & (samba.dsdb.UF_WORKSTATION_TRUST_ACCOUNT|samba.dsdb.UF_SERVER_TRUST_ACCOUNT) == 0):
+        if (int(res[0]["userAccountControl"][0]) & (samba.dsdb.UF_WORKSTATION_TRUST_ACCOUNT |
+                                                    samba.dsdb.UF_SERVER_TRUST_ACCOUNT) == 0):
             raise Exception("Account %s is not a domain member or a bare NT4 BDC, use 'samba-tool domain join' instead'" % ctx.samname)
 
         ctx.promote_from_dn = res[0].dn
 
-
     def find_dc(ctx, domain):
         """find a writeable DC for the given domain"""
         try:
@@ -345,6 +353,13 @@ class DCJoinContext(object):
             ctx.site = ctx.cldap_ret.client_site
         return ctx.cldap_ret.pdc_dns_name
 
+    def find_dc_site(ctx, server):
+        site = None
+        cldap_ret = ctx.net.finddc(address=server,
+                                   flags=nbt.NBT_SERVER_LDAP | nbt.NBT_SERVER_DS)
+        if cldap_ret.client_site is not None and cldap_ret.client_site != "":
+            site = cldap_ret.client_site
+        return site
 
     def get_behavior_version(ctx):
         res = ctx.samdb.search(base=ctx.base_dn, scope=ldb.SCOPE_BASE, attrs=["msDS-Behavior-Version"])
@@ -355,21 +370,21 @@ class DCJoinContext(object):
 
     def get_dnsHostName(ctx):
         res = ctx.samdb.search(base="", scope=ldb.SCOPE_BASE, attrs=["dnsHostName"])
-        return res[0]["dnsHostName"][0]
+        return str(res[0]["dnsHostName"][0])
 
     def get_domain_name(ctx):
         '''get netbios name of the domain from the partitions record'''
         partitions_dn = ctx.samdb.get_partitions_dn()
         res = ctx.samdb.search(base=partitions_dn, scope=ldb.SCOPE_ONELEVEL, attrs=["nETBIOSName"],
                                expression='ncName=%s' % ldb.binary_encode(str(ctx.samdb.get_default_basedn())))
-        return res[0]["nETBIOSName"][0]
+        return str(res[0]["nETBIOSName"][0])
 
     def get_forest_domain_name(ctx):
         '''get netbios name of the domain from the partitions record'''
         partitions_dn = ctx.samdb.get_partitions_dn()
         res = ctx.samdb.search(base=partitions_dn, scope=ldb.SCOPE_ONELEVEL, attrs=["nETBIOSName"],
                                expression='ncName=%s' % ldb.binary_encode(str(ctx.samdb.get_root_basedn())))
-        return res[0]["nETBIOSName"][0]
+        return str(res[0]["nETBIOSName"][0])
 
     def get_parent_partition_dn(ctx):
         '''get the parent domain partition DN from parent DNS name'''
@@ -383,7 +398,7 @@ class DCJoinContext(object):
         '''get the parent domain partition DN from parent DNS name'''
         res = ctx.samdb.search(base='CN=Partitions,%s' % ctx.config_dn, attrs=['fSMORoleOwner'],
                                scope=ldb.SCOPE_BASE, controls=["extended_dn:1:1"])
-        if not 'fSMORoleOwner' in res[0]:
+        if 'fSMORoleOwner' not in res[0]:
             raise DCJoinException("Can't find naming master on partition DN %s in %s" % (ctx.partition_dn, ctx.samdb.url))
         try:
             master_guid = str(misc.GUID(ldb.Dn(ctx.samdb, res[0]['fSMORoleOwner'][0].decode('utf8')).get_extended_component('GUID')))
@@ -398,7 +413,7 @@ class DCJoinContext(object):
            so only used for RODC join'''
         res = ctx.samdb.search(base="", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
         binsid = res[0]["tokenGroups"][0]
-        return ctx.samdb.schema_format_value("objectSID", binsid)
+        return get_string(ctx.samdb.schema_format_value("objectSID", binsid))
 
     def dn_exists(ctx, dn):
         '''check if a DN exists'''
@@ -418,7 +433,7 @@ class DCJoinContext(object):
             "dn": ctx.krbtgt_dn,
             "objectclass": "user",
             "useraccountcontrol": str(samba.dsdb.UF_NORMAL_ACCOUNT |
-                                       samba.dsdb.UF_ACCOUNTDISABLE),
+                                      samba.dsdb.UF_ACCOUNTDISABLE),
             "showinadvancedviewonly": "TRUE",
             "description": "krbtgt for %s" % ctx.samname}
         ctx.samdb.add(rec, ["rodc_join:1:1"])
@@ -464,7 +479,6 @@ class DCJoinContext(object):
         r.attid = ctx.tmp_samdb.get_attid_from_lDAPDisplayName(attrname)
         r.value_ctr = 1
 
-
     def DsAddEntry(ctx, recs):
         '''add a record via the DRSUAPI DsAddEntry call'''
         if ctx.drsuapi is None:
@@ -485,6 +499,7 @@ class DCJoinContext(object):
                     v = [rec[a]]
                 else:
                     v = rec[a]
+                v = [x.encode('utf8') if isinstance(x, text_type) else x for x in v]
                 rattr = ctx.tmp_samdb.dsdb_DsReplicaAttribute(ctx.tmp_samdb, a, v)
                 attrs.append(rattr)
 
@@ -634,8 +649,8 @@ class DCJoinContext(object):
                 "objectclass": "server",
                 # windows uses 50000000 decimal for systemFlags. A windows hex/decimal mixup bug?
                 "systemFlags": str(samba.dsdb.SYSTEM_FLAG_CONFIG_ALLOW_RENAME |
-                                    samba.dsdb.SYSTEM_FLAG_CONFIG_ALLOW_LIMITED_MOVE |
-                                    samba.dsdb.SYSTEM_FLAG_DISALLOW_MOVE_ON_DELETE),
+                                   samba.dsdb.SYSTEM_FLAG_CONFIG_ALLOW_LIMITED_MOVE |
+                                   samba.dsdb.SYSTEM_FLAG_DISALLOW_MOVE_ON_DELETE),
                 # windows seems to add the dnsHostName later
                 "dnsHostName": ctx.dnshostname}
 
@@ -810,7 +825,7 @@ class DCJoinContext(object):
             "nETBIOSName": ctx.domain_name,
             "dnsRoot": ctx.dnsdomain,
             "trustParent": ctx.parent_partition_dn,
-            "systemFlags": str(samba.dsdb.SYSTEM_FLAG_CR_NTDS_NC|samba.dsdb.SYSTEM_FLAG_CR_NTDS_DOMAIN),
+            "systemFlags": str(samba.dsdb.SYSTEM_FLAG_CR_NTDS_NC |samba.dsdb.SYSTEM_FLAG_CR_NTDS_DOMAIN),
             "ntSecurityDescriptor": sd_binary,
         }
 
@@ -978,7 +993,7 @@ class DCJoinContext(object):
                 repl.replicate(ctx.new_krbtgt_dn, source_dsa_invocation_id,
                                destination_dsa_guid,
                                exop=drsuapi.DRSUAPI_EXOP_REPL_SECRET, rodc=True)
-            elif ctx.rid_manager_dn != None:
+            elif ctx.rid_manager_dn is not None:
                 # Try and get a RID Set if we can.  This is only possible against the RID Master.  Warn otherwise.
                 try:
                     repl.replicate(ctx.rid_manager_dn, source_dsa_invocation_id,
@@ -1003,6 +1018,28 @@ class DCJoinContext(object):
         else:
             ctx.local_samdb.transaction_commit()
 
+        # A large replication may have caused our LDB connection to the
+        # remote DC to timeout, so check the connection is still alive
+        ctx.refresh_ldb_connection()
+
+    def refresh_ldb_connection(ctx):
+        try:
+            # query the rootDSE to check the connection
+            ctx.samdb.search(scope=ldb.SCOPE_BASE, attrs=[])
+        except ldb.LdbError as e:
+            (enum, estr) = e.args
+
+            # if the connection was disconnected, then reconnect
+            if (enum == ldb.ERR_OPERATIONS_ERROR and
+                ('NT_STATUS_CONNECTION_DISCONNECTED' in estr or
+                 'NT_STATUS_CONNECTION_RESET' in estr)):
+                ctx.logger.warning("LDB connection disconnected. Reconnecting")
+                ctx.samdb = SamDB(url="ldap://%s" % ctx.server,
+                                  session_info=system_session(),
+                                  credentials=ctx.creds, lp=ctx.lp)
+            else:
+                raise DCJoinException(estr)
+
     def send_DsReplicaUpdateRefs(ctx, dn):
         r = drsuapi.DsReplicaUpdateRefsRequest1()
         r.naming_context = drsuapi.DsReplicaObjectIdentifier()
@@ -1045,9 +1082,8 @@ class DCJoinContext(object):
         """
 
         client_version = dnsserver.DNS_CLIENT_VERSION_LONGHORN
-        record_type = dnsp.DNS_TYPE_A
         select_flags = dnsserver.DNS_RPC_VIEW_AUTHORITY_DATA |\
-        dnsserver.DNS_RPC_VIEW_NO_CHILDREN
+            dnsserver.DNS_RPC_VIEW_NO_CHILDREN
 
         zone = ctx.dnsdomain
         msdcs_zone = "_msdcs.%s" % ctx.dnsforest
@@ -1056,17 +1092,16 @@ class DCJoinContext(object):
         cname_target = "%s.%s" % (name, zone)
         IPs = samba.interface_ips(ctx.lp, ctx.force_all_ips)
 
-        ctx.logger.info("Adding %d remote DNS records for %s.%s" % \
+        ctx.logger.info("Adding %d remote DNS records for %s.%s" %
                         (len(IPs), name, zone))
 
         binding_options = "sign"
         dns_conn = dnsserver.dnsserver("ncacn_ip_tcp:%s[%s]" % (ctx.server, binding_options),
                                        ctx.lp, ctx.creds)
 
-
         name_found = True
 
-        sd_helper = samba.sd_utils.SDUtils(ctx.samdb)
+        sd_helper = sd_utils.SDUtils(ctx.samdb)
 
         change_owner_sd = security.descriptor()
         change_owner_sd.owner_sid = ctx.new_dc_account_sid
@@ -1147,7 +1182,6 @@ class DCJoinContext(object):
                                                 % (security.SECINFO_OWNER
                                                    | security.SECINFO_GROUP)])
 
-
             # Add record
             ctx.logger.info("Adding DNS CNAME record %s.%s for %s"
                             % (msdcs_cname, msdcs_zone, cname_target))
@@ -1177,7 +1211,6 @@ class DCJoinContext(object):
         ctx.logger.info("All other DNS records (like _ldap SRV records) " +
                         "will be created samba_dnsupdate on first startup")
 
-
     def join_replicate_new_dns_records(ctx):
         for nc in (ctx.domaindns_zone, ctx.forestdns_zone):
             if nc in ctx.nc_list:
@@ -1187,8 +1220,6 @@ class DCJoinContext(object):
                                    replica_flags=ctx.replica_flags,
                                    full_sync=False)
 
-
-
     def join_finalise(ctx):
         """Finalise the join, mark us synchronised and setup secrets db."""
 
@@ -1362,7 +1393,6 @@ class DCJoinContext(object):
         }
         ctx.local_samdb.add(rec)
 
-
     def build_nc_lists(ctx):
         # nc_list is the list of naming context (NC) for which we will
         # replicate in and send a updateRef command to the partner DC
@@ -1411,6 +1441,10 @@ class DCJoinContext(object):
                 print("Join failed - cleaning up")
             except IOError:
                 pass
+
+            # cleanup the failed join (checking we still have a live LDB
+            # connection to the remote DC first)
+            ctx.refresh_ldb_connection()
             ctx.cleanup_old_join()
             raise
 
@@ -1453,13 +1487,13 @@ def join_RODC(logger=None, server=None, creds=None, lp=None, site=None, netbios_
                               samba.dsdb.UF_PARTIAL_SECRETS_ACCOUNT)
 
     ctx.SPNs.extend(["RestrictedKrbHost/%s" % ctx.myname,
-                      "RestrictedKrbHost/%s" % ctx.dnshostname])
+                     "RestrictedKrbHost/%s" % ctx.dnshostname])
 
     ctx.connection_dn = "CN=RODC Connection (FRS),%s" % ctx.ntds_dn
     ctx.secure_channel_type = misc.SEC_CHAN_RODC
     ctx.RODC = True
     ctx.replica_flags |= (drsuapi.DRSUAPI_DRS_SPECIAL_SECRET_PROCESSING |
-                           drsuapi.DRSUAPI_DRS_GET_ALL_GROUP_MEMBERSHIP)
+                          drsuapi.DRSUAPI_DRS_GET_ALL_GROUP_MEMBERSHIP)
     ctx.domain_replica_flags = ctx.replica_flags
     if domain_critical_only:
         ctx.domain_replica_flags |= drsuapi.DRSUAPI_DRS_CRITICAL_ONLY
@@ -1500,13 +1534,15 @@ def join_DC(logger=None, server=None, creds=None, lp=None, site=None, netbios_na
     ctx.do_join()
     logger.info("Joined domain %s (SID %s) as a DC" % (ctx.domain_name, ctx.domsid))
 
+
 def join_clone(logger=None, server=None, creds=None, lp=None,
                targetdir=None, domain=None, include_secrets=False,
-               dns_backend="NONE"):
+               dns_backend="NONE", backend_store=None):
     """Creates a local clone of a remote DC."""
     ctx = DCCloneContext(logger, server, creds, lp, targetdir=targetdir,
                          domain=domain, dns_backend=dns_backend,
-                         include_secrets=include_secrets)
+                         include_secrets=include_secrets,
+                         backend_store=backend_store)
 
     lp.set("workgroup", ctx.domain_name)
     logger.info("workgroup is %s" % ctx.domain_name)
@@ -1518,6 +1554,7 @@ def join_clone(logger=None, server=None, creds=None, lp=None,
     logger.info("Cloned domain %s (SID %s)" % (ctx.domain_name, ctx.domsid))
     return ctx
 
+
 def join_subdomain(logger=None, server=None, creds=None, lp=None, site=None,
                    netbios_name=None, targetdir=None, parent_domain=None, dnsdomain=None,
                    netbios_domain=None, machinepass=None, adminpass=None, use_ntvfs=False,
@@ -1578,10 +1615,11 @@ class DCCloneContext(DCJoinContext):
 
     def __init__(ctx, logger=None, server=None, creds=None, lp=None,
                  targetdir=None, domain=None, dns_backend=None,
-                 include_secrets=False):
+                 include_secrets=False, backend_store=None):
         super(DCCloneContext, ctx).__init__(logger, server, creds, lp,
                                             targetdir=targetdir, domain=domain,
-                                            dns_backend=dns_backend)
+                                            dns_backend=dns_backend,
+                                            backend_store=backend_store)
 
         # As we don't want to create or delete these DNs, we set them to None
         ctx.server_dn = None
@@ -1631,12 +1669,13 @@ class DCCloneAndRenameContext(DCCloneContext):
 
     def __init__(ctx, new_base_dn, new_domain_name, new_realm, logger=None,
                  server=None, creds=None, lp=None, targetdir=None, domain=None,
-                 dns_backend=None, include_secrets=True):
+                 dns_backend=None, include_secrets=True, backend_store=None):
         super(DCCloneAndRenameContext, ctx).__init__(logger, server, creds, lp,
                                                      targetdir=targetdir,
                                                      domain=domain,
                                                      dns_backend=dns_backend,
-                                                     include_secrets=include_secrets)
+                                                     include_secrets=include_secrets,
+                                                     backend_store=backend_store)
         # store the new DN (etc) that we want the cloned DB to use
         ctx.new_base_dn = new_base_dn
         ctx.new_domain_name = new_domain_name
@@ -1703,7 +1742,8 @@ class DCCloneAndRenameContext(DCCloneContext):
                             configdn=ctx.rename_dn(ctx.config_dn),
                             domain=ctx.new_domain_name, domainsid=ctx.domsid,
                             serverrole="active directory domain controller",
-                            dns_backend=ctx.dns_backend)
+                            dns_backend=ctx.dns_backend,
+                            backend_store=ctx.backend_store)
 
         print("Provision OK for renamed domain DN %s" % presult.domaindn)
         ctx.local_samdb = presult.samdb