testprogs/blackbox: PY3 bulk change for python scripts use correct python
[amitay/samba.git] / python / samba / join.py
index 803379746e14e4ec11419581409fcd0c3be4d54d..343b1a50934f0e44848e61a6c82f0119133864a8 100644 (file)
@@ -24,13 +24,13 @@ from samba.samdb import SamDB
 from samba import gensec, Ldb, drs_utils, arcfour_encrypt, string_to_byte_array
 import ldb
 import samba
-import sys
 import uuid
 from samba.ndr import ndr_pack, ndr_unpack
 from samba.dcerpc import security, drsuapi, misc, nbt, lsa, drsblobs, dnsserver, dnsp
 from samba.dsdb import DS_DOMAIN_FUNCTION_2003
 from samba.credentials import Credentials, DONT_USE_KERBEROS
-from samba.provision import secretsdb_self_join, provision, provision_fill, FILL_DRS, FILL_SUBDOMAIN
+from samba.provision import (secretsdb_self_join, provision, provision_fill,
+                             FILL_DRS, FILL_SUBDOMAIN, DEFAULTSITE)
 from samba.provision.common import setup_path
 from samba.schema import Schema
 from samba import descriptor
@@ -40,15 +40,16 @@ from samba import read_and_sub_file
 from samba import werror
 from base64 import b64encode
 from samba import WERRORError, NTSTATUSError
-from samba.dnsserver import ARecord, AAAARecord, PTRRecord, CNameRecord, NSRecord, MXRecord, SOARecord, SRVRecord, TXTRecord
 from samba import sd_utils
+from samba.dnsserver import ARecord, AAAARecord, CNameRecord
 import logging
-import talloc
 import random
 import time
 import re
 import os
 import tempfile
+from samba.compat import text_type
+from samba.compat import get_string
 
 
 class DCJoinException(Exception):
@@ -65,8 +66,6 @@ class DCJoinContext(object):
                  machinepass=None, use_ntvfs=False, dns_backend=None,
                  promote_existing=False, plaintext_secrets=False,
                  backend_store=None, forced_local_samdb=None):
-        if site is None:
-            site = "Default-First-Site-Name"
 
         ctx.logger = logger
         ctx.creds = creds
@@ -93,7 +92,13 @@ class DCJoinContext(object):
             ctx.samdb = forced_local_samdb
             ctx.server = ctx.samdb.url
         else:
-            if not ctx.server:
+            if ctx.server:
+                # work out the DC's site (if not already specified)
+                if site is None:
+                    ctx.site = ctx.find_dc_site(ctx.server)
+            else:
+                # work out the Primary DC for the domain (as well as an
+                # appropriate site for the new DC)
                 ctx.logger.info("Finding a writeable DC for domain '%s'" % domain)
                 ctx.server = ctx.find_dc(domain)
                 ctx.logger.info("Found DC %s" % ctx.server)
@@ -101,10 +106,13 @@ class DCJoinContext(object):
                               session_info=system_session(),
                               credentials=ctx.creds, lp=ctx.lp)
 
+        if ctx.site is None:
+            ctx.site = DEFAULTSITE
+
         try:
-            ctx.samdb.search(scope=ldb.SCOPE_ONELEVEL, attrs=["dn"])
-        except ldb.LdbError as e4:
-            (enum, estr) = e4.args
+            ctx.samdb.search(scope=ldb.SCOPE_BASE, attrs=[])
+        except ldb.LdbError as e:
+            (enum, estr) = e.args
             raise DCJoinException(estr)
 
         ctx.base_dn = str(ctx.samdb.get_default_basedn())
@@ -147,8 +155,8 @@ class DCJoinContext(object):
                 ctx.topology_dn = None
 
             ctx.SPNs = ["HOST/%s" % ctx.myname,
-                         "HOST/%s" % ctx.dnshostname,
-                         "GC/%s/%s" % (ctx.dnshostname, ctx.dnsforest)]
+                        "HOST/%s" % ctx.dnshostname,
+                        "GC/%s/%s" % (ctx.dnshostname, ctx.dnsforest)]
 
             res_rid_manager = ctx.samdb.search(scope=ldb.SCOPE_BASE,
                                                attrs=["rIDManagerReference"],
@@ -326,7 +334,8 @@ class DCJoinContext(object):
             raise Exception("Could not find domain member account '%s' to promote to a DC, use 'samba-tool domain join' instead'" % ctx.samname)
         if "msDS-krbTgtLink" in res[0] or "serverReferenceBL" in res[0] or "rIDSetReferences" in res[0]:
             raise Exception("Account '%s' appears to be an active DC, use 'samba-tool domain join' if you must re-create this account" % ctx.samname)
-        if (int(res[0]["userAccountControl"][0]) & (samba.dsdb.UF_WORKSTATION_TRUST_ACCOUNT |samba.dsdb.UF_SERVER_TRUST_ACCOUNT) == 0):
+        if (int(res[0]["userAccountControl"][0]) & (samba.dsdb.UF_WORKSTATION_TRUST_ACCOUNT |
+                                                    samba.dsdb.UF_SERVER_TRUST_ACCOUNT) == 0):
             raise Exception("Account %s is not a domain member or a bare NT4 BDC, use 'samba-tool domain join' instead'" % ctx.samname)
 
         ctx.promote_from_dn = res[0].dn
@@ -344,6 +353,14 @@ class DCJoinContext(object):
             ctx.site = ctx.cldap_ret.client_site
         return ctx.cldap_ret.pdc_dns_name
 
+    def find_dc_site(ctx, server):
+        site = None
+        cldap_ret = ctx.net.finddc(address=server,
+                                   flags=nbt.NBT_SERVER_LDAP | nbt.NBT_SERVER_DS)
+        if cldap_ret.client_site is not None and cldap_ret.client_site != "":
+            site = cldap_ret.client_site
+        return site
+
     def get_behavior_version(ctx):
         res = ctx.samdb.search(base=ctx.base_dn, scope=ldb.SCOPE_BASE, attrs=["msDS-Behavior-Version"])
         if "msDS-Behavior-Version" in res[0]:
@@ -353,21 +370,21 @@ class DCJoinContext(object):
 
     def get_dnsHostName(ctx):
         res = ctx.samdb.search(base="", scope=ldb.SCOPE_BASE, attrs=["dnsHostName"])
-        return res[0]["dnsHostName"][0]
+        return str(res[0]["dnsHostName"][0])
 
     def get_domain_name(ctx):
         '''get netbios name of the domain from the partitions record'''
         partitions_dn = ctx.samdb.get_partitions_dn()
         res = ctx.samdb.search(base=partitions_dn, scope=ldb.SCOPE_ONELEVEL, attrs=["nETBIOSName"],
                                expression='ncName=%s' % ldb.binary_encode(str(ctx.samdb.get_default_basedn())))
-        return res[0]["nETBIOSName"][0]
+        return str(res[0]["nETBIOSName"][0])
 
     def get_forest_domain_name(ctx):
         '''get netbios name of the domain from the partitions record'''
         partitions_dn = ctx.samdb.get_partitions_dn()
         res = ctx.samdb.search(base=partitions_dn, scope=ldb.SCOPE_ONELEVEL, attrs=["nETBIOSName"],
                                expression='ncName=%s' % ldb.binary_encode(str(ctx.samdb.get_root_basedn())))
-        return res[0]["nETBIOSName"][0]
+        return str(res[0]["nETBIOSName"][0])
 
     def get_parent_partition_dn(ctx):
         '''get the parent domain partition DN from parent DNS name'''
@@ -381,7 +398,7 @@ class DCJoinContext(object):
         '''get the parent domain partition DN from parent DNS name'''
         res = ctx.samdb.search(base='CN=Partitions,%s' % ctx.config_dn, attrs=['fSMORoleOwner'],
                                scope=ldb.SCOPE_BASE, controls=["extended_dn:1:1"])
-        if not 'fSMORoleOwner' in res[0]:
+        if 'fSMORoleOwner' not in res[0]:
             raise DCJoinException("Can't find naming master on partition DN %s in %s" % (ctx.partition_dn, ctx.samdb.url))
         try:
             master_guid = str(misc.GUID(ldb.Dn(ctx.samdb, res[0]['fSMORoleOwner'][0].decode('utf8')).get_extended_component('GUID')))
@@ -396,7 +413,7 @@ class DCJoinContext(object):
            so only used for RODC join'''
         res = ctx.samdb.search(base="", scope=ldb.SCOPE_BASE, attrs=["tokenGroups"])
         binsid = res[0]["tokenGroups"][0]
-        return ctx.samdb.schema_format_value("objectSID", binsid)
+        return get_string(ctx.samdb.schema_format_value("objectSID", binsid))
 
     def dn_exists(ctx, dn):
         '''check if a DN exists'''
@@ -416,7 +433,7 @@ class DCJoinContext(object):
             "dn": ctx.krbtgt_dn,
             "objectclass": "user",
             "useraccountcontrol": str(samba.dsdb.UF_NORMAL_ACCOUNT |
-                                       samba.dsdb.UF_ACCOUNTDISABLE),
+                                      samba.dsdb.UF_ACCOUNTDISABLE),
             "showinadvancedviewonly": "TRUE",
             "description": "krbtgt for %s" % ctx.samname}
         ctx.samdb.add(rec, ["rodc_join:1:1"])
@@ -482,6 +499,7 @@ class DCJoinContext(object):
                     v = [rec[a]]
                 else:
                     v = rec[a]
+                v = [x.encode('utf8') if isinstance(x, text_type) else x for x in v]
                 rattr = ctx.tmp_samdb.dsdb_DsReplicaAttribute(ctx.tmp_samdb, a, v)
                 attrs.append(rattr)
 
@@ -631,8 +649,8 @@ class DCJoinContext(object):
                 "objectclass": "server",
                 # windows uses 50000000 decimal for systemFlags. A windows hex/decimal mixup bug?
                 "systemFlags": str(samba.dsdb.SYSTEM_FLAG_CONFIG_ALLOW_RENAME |
-                                    samba.dsdb.SYSTEM_FLAG_CONFIG_ALLOW_LIMITED_MOVE |
-                                    samba.dsdb.SYSTEM_FLAG_DISALLOW_MOVE_ON_DELETE),
+                                   samba.dsdb.SYSTEM_FLAG_CONFIG_ALLOW_LIMITED_MOVE |
+                                   samba.dsdb.SYSTEM_FLAG_DISALLOW_MOVE_ON_DELETE),
                 # windows seems to add the dnsHostName later
                 "dnsHostName": ctx.dnshostname}
 
@@ -975,7 +993,7 @@ class DCJoinContext(object):
                 repl.replicate(ctx.new_krbtgt_dn, source_dsa_invocation_id,
                                destination_dsa_guid,
                                exop=drsuapi.DRSUAPI_EXOP_REPL_SECRET, rodc=True)
-            elif ctx.rid_manager_dn != None:
+            elif ctx.rid_manager_dn is not None:
                 # Try and get a RID Set if we can.  This is only possible against the RID Master.  Warn otherwise.
                 try:
                     repl.replicate(ctx.rid_manager_dn, source_dsa_invocation_id,
@@ -1000,6 +1018,28 @@ class DCJoinContext(object):
         else:
             ctx.local_samdb.transaction_commit()
 
+        # A large replication may have caused our LDB connection to the
+        # remote DC to timeout, so check the connection is still alive
+        ctx.refresh_ldb_connection()
+
+    def refresh_ldb_connection(ctx):
+        try:
+            # query the rootDSE to check the connection
+            ctx.samdb.search(scope=ldb.SCOPE_BASE, attrs=[])
+        except ldb.LdbError as e:
+            (enum, estr) = e.args
+
+            # if the connection was disconnected, then reconnect
+            if (enum == ldb.ERR_OPERATIONS_ERROR and
+                ('NT_STATUS_CONNECTION_DISCONNECTED' in estr or
+                 'NT_STATUS_CONNECTION_RESET' in estr)):
+                ctx.logger.warning("LDB connection disconnected. Reconnecting")
+                ctx.samdb = SamDB(url="ldap://%s" % ctx.server,
+                                  session_info=system_session(),
+                                  credentials=ctx.creds, lp=ctx.lp)
+            else:
+                raise DCJoinException(estr)
+
     def send_DsReplicaUpdateRefs(ctx, dn):
         r = drsuapi.DsReplicaUpdateRefsRequest1()
         r.naming_context = drsuapi.DsReplicaObjectIdentifier()
@@ -1042,9 +1082,8 @@ class DCJoinContext(object):
         """
 
         client_version = dnsserver.DNS_CLIENT_VERSION_LONGHORN
-        record_type = dnsp.DNS_TYPE_A
         select_flags = dnsserver.DNS_RPC_VIEW_AUTHORITY_DATA |\
-        dnsserver.DNS_RPC_VIEW_NO_CHILDREN
+            dnsserver.DNS_RPC_VIEW_NO_CHILDREN
 
         zone = ctx.dnsdomain
         msdcs_zone = "_msdcs.%s" % ctx.dnsforest
@@ -1053,7 +1092,7 @@ class DCJoinContext(object):
         cname_target = "%s.%s" % (name, zone)
         IPs = samba.interface_ips(ctx.lp, ctx.force_all_ips)
 
-        ctx.logger.info("Adding %d remote DNS records for %s.%s" % \
+        ctx.logger.info("Adding %d remote DNS records for %s.%s" %
                         (len(IPs), name, zone))
 
         binding_options = "sign"
@@ -1062,7 +1101,7 @@ class DCJoinContext(object):
 
         name_found = True
 
-        sd_helper = samba.sd_utils.SDUtils(ctx.samdb)
+        sd_helper = sd_utils.SDUtils(ctx.samdb)
 
         change_owner_sd = security.descriptor()
         change_owner_sd.owner_sid = ctx.new_dc_account_sid
@@ -1402,6 +1441,10 @@ class DCJoinContext(object):
                 print("Join failed - cleaning up")
             except IOError:
                 pass
+
+            # cleanup the failed join (checking we still have a live LDB
+            # connection to the remote DC first)
+            ctx.refresh_ldb_connection()
             ctx.cleanup_old_join()
             raise
 
@@ -1444,13 +1487,13 @@ def join_RODC(logger=None, server=None, creds=None, lp=None, site=None, netbios_
                               samba.dsdb.UF_PARTIAL_SECRETS_ACCOUNT)
 
     ctx.SPNs.extend(["RestrictedKrbHost/%s" % ctx.myname,
-                      "RestrictedKrbHost/%s" % ctx.dnshostname])
+                     "RestrictedKrbHost/%s" % ctx.dnshostname])
 
     ctx.connection_dn = "CN=RODC Connection (FRS),%s" % ctx.ntds_dn
     ctx.secure_channel_type = misc.SEC_CHAN_RODC
     ctx.RODC = True
     ctx.replica_flags |= (drsuapi.DRSUAPI_DRS_SPECIAL_SECRET_PROCESSING |
-                           drsuapi.DRSUAPI_DRS_GET_ALL_GROUP_MEMBERSHIP)
+                          drsuapi.DRSUAPI_DRS_GET_ALL_GROUP_MEMBERSHIP)
     ctx.domain_replica_flags = ctx.replica_flags
     if domain_critical_only:
         ctx.domain_replica_flags |= drsuapi.DRSUAPI_DRS_CRITICAL_ONLY
@@ -1494,11 +1537,12 @@ def join_DC(logger=None, server=None, creds=None, lp=None, site=None, netbios_na
 
 def join_clone(logger=None, server=None, creds=None, lp=None,
                targetdir=None, domain=None, include_secrets=False,
-               dns_backend="NONE"):
+               dns_backend="NONE", backend_store=None):
     """Creates a local clone of a remote DC."""
     ctx = DCCloneContext(logger, server, creds, lp, targetdir=targetdir,
                          domain=domain, dns_backend=dns_backend,
-                         include_secrets=include_secrets)
+                         include_secrets=include_secrets,
+                         backend_store=backend_store)
 
     lp.set("workgroup", ctx.domain_name)
     logger.info("workgroup is %s" % ctx.domain_name)
@@ -1571,10 +1615,11 @@ class DCCloneContext(DCJoinContext):
 
     def __init__(ctx, logger=None, server=None, creds=None, lp=None,
                  targetdir=None, domain=None, dns_backend=None,
-                 include_secrets=False):
+                 include_secrets=False, backend_store=None):
         super(DCCloneContext, ctx).__init__(logger, server, creds, lp,
                                             targetdir=targetdir, domain=domain,
-                                            dns_backend=dns_backend)
+                                            dns_backend=dns_backend,
+                                            backend_store=backend_store)
 
         # As we don't want to create or delete these DNs, we set them to None
         ctx.server_dn = None
@@ -1624,12 +1669,13 @@ class DCCloneAndRenameContext(DCCloneContext):
 
     def __init__(ctx, new_base_dn, new_domain_name, new_realm, logger=None,
                  server=None, creds=None, lp=None, targetdir=None, domain=None,
-                 dns_backend=None, include_secrets=True):
+                 dns_backend=None, include_secrets=True, backend_store=None):
         super(DCCloneAndRenameContext, ctx).__init__(logger, server, creds, lp,
                                                      targetdir=targetdir,
                                                      domain=domain,
                                                      dns_backend=dns_backend,
-                                                     include_secrets=include_secrets)
+                                                     include_secrets=include_secrets,
+                                                     backend_store=backend_store)
         # store the new DN (etc) that we want the cloned DB to use
         ctx.new_base_dn = new_base_dn
         ctx.new_domain_name = new_domain_name
@@ -1696,7 +1742,8 @@ class DCCloneAndRenameContext(DCCloneContext):
                             configdn=ctx.rename_dn(ctx.config_dn),
                             domain=ctx.new_domain_name, domainsid=ctx.domsid,
                             serverrole="active directory domain controller",
-                            dns_backend=ctx.dns_backend)
+                            dns_backend=ctx.dns_backend,
+                            backend_store=ctx.backend_store)
 
         print("Provision OK for renamed domain DN %s" % presult.domaindn)
         ctx.local_samdb = presult.samdb