s4-join: enable cleanup on failed join
[mat/samba.git] / source4 / scripting / python / samba / join.py
index bd343fa4c550f348c0887ec613ad0027f42d7c86..38f5c8aac18ceebaebb25da8f8ac5c25367a457a 100644 (file)
@@ -34,6 +34,7 @@ from samba.dcerpc import security
 import logging
 import talloc
 import random
+import time
 
 # this makes debugging easier
 talloc.enable_null_tracking()
@@ -123,7 +124,6 @@ class dc_join(object):
         ctx.drsuapi = None
         ctx.managedby = None
         ctx.subdomain = False
-        ctx.domguid = None
 
 
     def del_noerror(ctx, dn, recursive=False):
@@ -145,11 +145,14 @@ class dc_join(object):
         try:
             # find the krbtgt link
             print("checking samaccountname")
-            res = ctx.samdb.search(base=ctx.samdb.get_default_basedn(),
-                                   expression='samAccountName=%s' % ldb.binary_encode(ctx.samname),
-                                   attrs=["msDS-krbTgtLink"])
-            if res:
-                ctx.del_noerror(res[0].dn, recursive=True)
+            if ctx.subdomain:
+                res = None
+            else:
+                res = ctx.samdb.search(base=ctx.samdb.get_default_basedn(),
+                                       expression='samAccountName=%s' % ldb.binary_encode(ctx.samname),
+                                       attrs=["msDS-krbTgtLink"])
+                if res:
+                    ctx.del_noerror(res[0].dn, recursive=True)
             if ctx.connection_dn is not None:
                 ctx.del_noerror(ctx.connection_dn)
             if ctx.krbtgt_dn is not None:
@@ -301,40 +304,47 @@ class dc_join(object):
         r.value_ctr = 1
 
 
-    def DsAddEntry(ctx, rec):
+    def DsAddEntry(ctx, recs):
         '''add a record via the DRSUAPI DsAddEntry call'''
         if ctx.drsuapi is None:
             ctx.drsuapi_connect()
         if ctx.tmp_samdb is None:
             ctx.create_tmp_samdb()
 
-        id = drsuapi.DsReplicaObjectIdentifier()
-        id.dn = rec['dn']
-
-        attrs = []
-        for a in rec:
-            if a == 'dn':
-                continue
-            if not isinstance(rec[a], list):
-                v = [rec[a]]
-            else:
-                v = rec[a]
-            rattr = ctx.tmp_samdb.dsdb_DsReplicaAttribute(ctx.tmp_samdb, a, v)
-            attrs.append(rattr)
-
-        attribute_ctr = drsuapi.DsReplicaAttributeCtr()
-        attribute_ctr.num_attributes = len(attrs)
-        attribute_ctr.attributes = attrs
-
-        object = drsuapi.DsReplicaObject()
-        object.identifier = id
-        object.attribute_ctr = attribute_ctr
-
-        first_object = drsuapi.DsReplicaObjectListItem()
-        first_object.object = object
+        objects = []
+        for rec in recs:
+            id = drsuapi.DsReplicaObjectIdentifier()
+            id.dn = rec['dn']
+
+            attrs = []
+            for a in rec:
+                if a == 'dn':
+                    continue
+                if not isinstance(rec[a], list):
+                    v = [rec[a]]
+                else:
+                    v = rec[a]
+                rattr = ctx.tmp_samdb.dsdb_DsReplicaAttribute(ctx.tmp_samdb, a, v)
+                attrs.append(rattr)
+
+            attribute_ctr = drsuapi.DsReplicaAttributeCtr()
+            attribute_ctr.num_attributes = len(attrs)
+            attribute_ctr.attributes = attrs
+
+            object = drsuapi.DsReplicaObject()
+            object.identifier = id
+            object.attribute_ctr = attribute_ctr
+
+            list_object = drsuapi.DsReplicaObjectListItem()
+            list_object.object = object
+            objects.append(list_object)
 
         req2 = drsuapi.DsAddEntryRequest2()
-        req2.first_object = first_object
+        req2.first_object = objects[0]
+        prev = req2.first_object
+        for o in objects[1:]:
+            prev.next_object = o
+            prev = o
 
         (level, ctr) = ctx.drsuapi.DsAddEntry(ctx.drsuapi_handle, 2, req2)
         if ctr.err_ver != 1:
@@ -343,6 +353,11 @@ class dc_join(object):
             print("DsAddEntry failed with status %s info %s" % (ctr.err_data.status,
                                                                 ctr.err_data.info.extended_err))
             raise RuntimeError("DsAddEntry failed")
+        if ctr.err_data.dir_err != drsuapi.DRSUAPI_DIRERR_OK:
+            print("DsAddEntry failed with dir_err %u" % ctr.err_data.dir_err)
+            raise RuntimeError("DsAddEntry failed")
+        return ctr.objects
+
 
     def join_add_ntdsdsa(ctx):
         '''add the ntdsdsa object'''
@@ -354,17 +369,12 @@ class dc_join(object):
             "systemFlags" : str(samba.dsdb.SYSTEM_FLAG_DISALLOW_MOVE_ON_DELETE),
             "dMDLocation" : ctx.schema_dn}
 
-        if ctx.subdomain:
-            # the local subdomain NC doesn't exist at this time
-            # so we have to add the base_dn NC later
-            nc_list = [ ctx.config_dn, ctx.schema_dn ]
-        else:
-            nc_list = [ ctx.base_dn, ctx.config_dn, ctx.schema_dn ]
+        nc_list = [ ctx.base_dn, ctx.config_dn, ctx.schema_dn ]
 
         if ctx.behavior_version >= samba.dsdb.DS_DOMAIN_FUNCTION_2003:
             rec["msDS-Behavior-Version"] = str(ctx.behavior_version)
 
-        if ctx.behavior_version >= samba.dsdb.DS_DOMAIN_FUNCTION_2003 and not ctx.subdomain:
+        if ctx.behavior_version >= samba.dsdb.DS_DOMAIN_FUNCTION_2003:
             rec["msDS-HasDomainNCs"] = ctx.base_dn
 
         if ctx.RODC:
@@ -379,32 +389,13 @@ class dc_join(object):
                 rec["msDS-HasMasterNCs"] = nc_list
             rec["options"] = "1"
             rec["invocationId"] = ndr_pack(ctx.invocation_id)
-            if ctx.subdomain:
-                ctx.samdb.add(rec, ['relax:0'])
-            else:
-                ctx.DsAddEntry(rec)
+            ctx.DsAddEntry([rec])
 
         # find the GUID of our NTDS DN
         res = ctx.samdb.search(base=ctx.ntds_dn, scope=ldb.SCOPE_BASE, attrs=["objectGUID"])
         ctx.ntds_guid = misc.GUID(ctx.samdb.schema_format_value("objectGUID", res[0]["objectGUID"][0]))
 
 
-    def join_modify_ntdsdsa(ctx):
-        '''modify the ntdsdsa object to add local partitions'''
-        print "Modifying %s using system privileges" % ctx.ntds_dn
-
-        # this works around the Enterprise Admins ACL on the NTDSDSA object
-        system_session_info = system_session()
-        ctx.samdb.set_session_info(system_session_info)
-
-        m = ldb.Message()
-        m.dn = ldb.Dn(ctx.samdb, ctx.ntds_dn)
-        m["HasMasterNCs"] = ldb.MessageElement(ctx.base_dn, ldb.FLAG_MOD_ADD, "HasMasterNCs")
-        if ctx.behavior_version >= samba.dsdb.DS_DOMAIN_FUNCTION_2003:
-            m["msDS-HasDomainNCs"] = ldb.MessageElement(ctx.base_dn, ldb.FLAG_MOD_ADD, "msDS-HasDomainNCs")
-            m["msDS-HasMasterNCs"] = ldb.MessageElement(ctx.base_dn, ldb.FLAG_MOD_ADD, "msDS-HasMasterNCs")
-        ctx.samdb.modify(m, controls=['relax:0'])
-
     def join_add_objects(ctx):
         '''add the various objects needed for the join'''
         if ctx.acct_dn:
@@ -433,9 +424,11 @@ class dc_join(object):
         rec = {
             "dn": ctx.server_dn,
             "objectclass" : "server",
+            # windows uses 50000000 decimal for systemFlags. A windows hex/decimal mixup bug?
             "systemFlags" : str(samba.dsdb.SYSTEM_FLAG_CONFIG_ALLOW_RENAME |
                                 samba.dsdb.SYSTEM_FLAG_CONFIG_ALLOW_LIMITED_MOVE |
                                 samba.dsdb.SYSTEM_FLAG_DISALLOW_MOVE_ON_DELETE),
+            # windows seems to add the dnsHostName later
             "dnsHostName" : ctx.dnshostname}
 
         if ctx.acct_dn:
@@ -500,9 +493,6 @@ class dc_join(object):
     def join_add_objects2(ctx):
         '''add the various objects needed for the join, for subdomains post replication'''
 
-        if not ctx.subdomain:
-            return
-
         print "Adding %s" % ctx.partition_dn
         # NOTE: windows sends a ntSecurityDescriptor here, we
         # let it default
@@ -517,8 +507,47 @@ class dc_join(object):
             "systemFlags" : str(samba.dsdb.SYSTEM_FLAG_CR_NTDS_NC|samba.dsdb.SYSTEM_FLAG_CR_NTDS_DOMAIN)}
         if ctx.behavior_version >= samba.dsdb.DS_DOMAIN_FUNCTION_2003:
             rec["msDS-Behavior-Version"] = str(ctx.behavior_version)
-        ctx.DsAddEntry(rec)
 
+        rec2 = {
+            "dn" : ctx.ntds_dn,
+            "objectclass" : "nTDSDSA",
+            "systemFlags" : str(samba.dsdb.SYSTEM_FLAG_DISALLOW_MOVE_ON_DELETE),
+            "dMDLocation" : ctx.schema_dn}
+
+        nc_list = [ ctx.base_dn, ctx.config_dn, ctx.schema_dn ]
+
+        if ctx.behavior_version >= samba.dsdb.DS_DOMAIN_FUNCTION_2003:
+            rec2["msDS-Behavior-Version"] = str(ctx.behavior_version)
+
+        if ctx.behavior_version >= samba.dsdb.DS_DOMAIN_FUNCTION_2003:
+            rec2["msDS-HasDomainNCs"] = ctx.base_dn
+
+        rec2["objectCategory"] = "CN=NTDS-DSA,%s" % ctx.schema_dn
+        rec2["HasMasterNCs"]      = nc_list
+        if ctx.behavior_version >= samba.dsdb.DS_DOMAIN_FUNCTION_2003:
+            rec2["msDS-HasMasterNCs"] = nc_list
+        rec2["options"] = "1"
+        rec2["invocationId"] = ndr_pack(ctx.invocation_id)
+
+        objects = ctx.DsAddEntry([rec, rec2])
+        if len(objects) != 2:
+            raise DCJoinException("Expected 2 objects from DsAddEntry")
+
+        ctx.ntds_guid = objects[1].guid
+
+        print("Replicating partition DN")
+        ctx.repl.replicate(ctx.partition_dn,
+                           misc.GUID("00000000-0000-0000-0000-000000000000"),
+                           ctx.ntds_guid,
+                           exop=drsuapi.DRSUAPI_EXOP_REPL_OBJ,
+                           replica_flags=drsuapi.DRSUAPI_DRS_WRIT_REP)
+
+        print("Replicating NTDS DN")
+        ctx.repl.replicate(ctx.ntds_dn,
+                           misc.GUID("00000000-0000-0000-0000-000000000000"),
+                           ctx.ntds_guid,
+                           exop=drsuapi.DRSUAPI_EXOP_REPL_OBJ,
+                           replica_flags=drsuapi.DRSUAPI_DRS_WRIT_REP)
 
     def join_provision(ctx):
         '''provision the local SAM'''
@@ -536,7 +565,6 @@ class dc_join(object):
                             configdn=ctx.config_dn,
                             serverdn=ctx.server_dn, domain=ctx.domain_name,
                             hostname=ctx.myname, domainsid=ctx.domsid,
-                            domainguid=ctx.domguid,
                             machinepass=ctx.acct_pass, serverrole="domain controller",
                             sitename=ctx.site, lp=ctx.lp, ntdsguid=ctx.ntds_guid)
         print "Provision OK for domain DN %s" % presult.domaindn
@@ -558,7 +586,11 @@ class dc_join(object):
         ctx.samdb.set_invocation_id(str(ctx.invocation_id))
         ctx.local_samdb = ctx.samdb
 
-        ctx.join_add_ntdsdsa()
+        print("Finding domain GUID from ncName")
+        res = ctx.local_samdb.search(base=ctx.partition_dn, scope=ldb.SCOPE_BASE, attrs=['ncName'],
+                                     controls=["extended_dn:1:1"])
+        domguid = str(misc.GUID(ldb.Dn(ctx.samdb, res[0]['ncName'][0]).get_extended_component('GUID')))
+        print("Got domain GUID %s" % domguid)
 
         print("Calling own domain provision")
 
@@ -569,7 +601,7 @@ class dc_join(object):
 
         presult = provision_fill(ctx.local_samdb, secrets_ldb,
                                  logger, ctx.names, ctx.paths, domainsid=security.dom_sid(ctx.domsid),
-                                 domainguid=ctx.domguid,
+                                 domainguid=domguid,
                                  targetdir=ctx.targetdir, samdb_fill=FILL_SUBDOMAIN,
                                  machinepass=ctx.acct_pass, serverrole="domain controller",
                                  lp=ctx.lp, hostip=ctx.names.hostip, hostip6=ctx.names.hostip6)
@@ -584,6 +616,7 @@ class dc_join(object):
         try:
             source_dsa_invocation_id = misc.GUID(ctx.samdb.get_invocation_id())
             if ctx.ntds_guid is None:
+                print("Using DS_BIND_GUID_W2K3")
                 destination_dsa_guid = misc.GUID(drsuapi.DRSUAPI_DS_BIND_GUID_W2K3)
             else:
                 destination_dsa_guid = ctx.ntds_guid
@@ -621,6 +654,9 @@ class dc_join(object):
                 repl.replicate(ctx.new_krbtgt_dn, source_dsa_invocation_id,
                         destination_dsa_guid,
                         exop=drsuapi.DRSUAPI_EXOP_REPL_SECRET, rodc=True)
+            ctx.repl = repl
+            ctx.source_dsa_invocation_id = source_dsa_invocation_id
+            ctx.destination_dsa_guid = destination_dsa_guid
 
             print "Committing SAM database"
         except:
@@ -708,21 +744,13 @@ class dc_join(object):
         clear_value.password = password_blob
 
         clear_authentication_information = drsblobs.AuthenticationInformation()
-        clear_authentication_information.LastUpdateTime = 0
+        clear_authentication_information.LastUpdateTime = samba.unix2nttime(int(time.time()))
         clear_authentication_information.AuthType = lsa.TRUST_AUTH_TYPE_CLEAR
         clear_authentication_information.AuthInfo = clear_value
 
-        version_value = drsblobs.AuthInfoVersion()
-        version_value.version = 1
-
-        version = drsblobs.AuthenticationInformation()
-        version.LastUpdateTime = 0
-        version.AuthType = lsa.TRUST_AUTH_TYPE_VERSION
-        version.AuthInfo = version_value
-
         authentication_information_array = drsblobs.AuthenticationInformationArray()
-        authentication_information_array.count = 2
-        authentication_information_array.array = [clear_authentication_information, version]
+        authentication_information_array.count = 1
+        authentication_information_array.array = [clear_authentication_information]
 
         outgoing = drsblobs.trustAuthInOutBlob()
         outgoing.count = 1
@@ -782,16 +810,15 @@ class dc_join(object):
         try:
             ctx.join_add_objects()
             ctx.join_provision()
-            ctx.join_add_objects2()
             ctx.join_replicate()
             if ctx.subdomain:
+                ctx.join_add_objects2()
                 ctx.join_provision_own_domain()
                 ctx.join_setup_trusts()
-                ctx.join_modify_ntdsdsa()
             ctx.join_finalise()
         except Exception:
             print "Join failed - cleaning up"
-            #ctx.cleanup_old_join()
+            ctx.cleanup_old_join()
             raise
 
 
@@ -889,10 +916,9 @@ def join_subdomain(server=None, creds=None, lp=None, site=None, netbios_name=Non
     ctx.partition_dn = "CN=%s,CN=Partitions,%s" % (ctx.domain_name, ctx.config_dn)
     ctx.base_dn = samba.dn_from_dns_name(dnsdomain)
     ctx.domsid = str(security.random_sid())
-    ctx.domguid = str(uuid.uuid4())
     ctx.acct_dn = None
     ctx.dnshostname = "%s.%s" % (ctx.myname, ctx.dnsdomain)
-    ctx.trustdom_pass = samba.generate_random_password(32, 40)
+    ctx.trustdom_pass = samba.generate_random_password(128, 128)
 
     ctx.userAccountControl = samba.dsdb.UF_SERVER_TRUST_ACCOUNT | samba.dsdb.UF_TRUSTED_FOR_DELEGATION