s4-provision Add initial support for joining as a new subdomain
[amitay/samba.git] / source4 / scripting / python / samba / join.py
index 2967ddf7e78ff3c40ca293e3665491c77a2df60a..195dfc23120f2c81314a4d3f49add3c960d9f44f 100644 (file)
@@ -25,7 +25,7 @@ from samba.samdb import SamDB
 from samba import gensec, Ldb, drs_utils
 import ldb, samba, sys, os, uuid
 from samba.ndr import ndr_pack
-from samba.dcerpc import security, drsuapi, misc, nbt
+from samba.dcerpc import security, drsuapi, misc, nbt, lsa, drsblobs
 from samba.credentials import Credentials, DONT_USE_KERBEROS
 from samba.provision import secretsdb_self_join, provision, provision_fill, FILL_DRS, FILL_SUBDOMAIN
 from samba.schema import Schema
@@ -33,6 +33,8 @@ from samba.net import Net
 from samba.dcerpc import security
 import logging
 import talloc
+import random
+import time
 
 # this makes debugging easier
 talloc.enable_null_tracking()
@@ -82,6 +84,7 @@ class dc_join(object):
         ctx.config_dn = str(ctx.samdb.get_config_basedn())
         ctx.domsid = ctx.samdb.get_domain_sid()
         ctx.domain_name = ctx.get_domain_name()
+        ctx.invocation_id = misc.GUID(str(uuid.uuid4()))
 
         ctx.dc_ntds_dn = ctx.get_dsServiceName()
         ctx.dc_dnsHostName = ctx.get_dnsHostName()
@@ -120,6 +123,7 @@ class dc_join(object):
         ctx.krbtgt_dn = None
         ctx.drsuapi = None
         ctx.managedby = None
+        ctx.subdomain = False
 
 
     def del_noerror(ctx, dn, recursive=False):
@@ -141,11 +145,14 @@ class dc_join(object):
         try:
             # find the krbtgt link
             print("checking samaccountname")
-            res = ctx.samdb.search(base=ctx.samdb.get_default_basedn(),
-                                   expression='samAccountName=%s' % ldb.binary_encode(ctx.samname),
-                                   attrs=["msDS-krbTgtLink"])
-            if res:
-                ctx.del_noerror(res[0].dn, recursive=True)
+            if ctx.subdomain:
+                res = None
+            else:
+                res = ctx.samdb.search(base=ctx.samdb.get_default_basedn(),
+                                       expression='samAccountName=%s' % ldb.binary_encode(ctx.samname),
+                                       attrs=["msDS-krbTgtLink"])
+                if res:
+                    ctx.del_noerror(res[0].dn, recursive=True)
             if ctx.connection_dn is not None:
                 ctx.del_noerror(ctx.connection_dn)
             if ctx.krbtgt_dn is not None:
@@ -154,9 +161,35 @@ class dc_join(object):
             ctx.del_noerror(ctx.server_dn, recursive=True)
             if ctx.topology_dn:
                 ctx.del_noerror(ctx.topology_dn)
+            if ctx.partition_dn:
+                ctx.del_noerror(ctx.partition_dn)
             if res:
                 ctx.new_krbtgt_dn = res[0]["msDS-Krbtgtlink"][0]
                 ctx.del_noerror(ctx.new_krbtgt_dn)
+
+            if ctx.subdomain:
+                binding_options = "sign"
+                lsaconn = lsa.lsarpc("ncacn_ip_tcp:%s[%s]" % (ctx.server, binding_options),
+                                     ctx.lp, ctx.creds)
+
+                objectAttr = lsa.ObjectAttribute()
+                objectAttr.sec_qos = lsa.QosInfo()
+
+                pol_handle = lsaconn.OpenPolicy2(''.decode('utf-8'),
+                                                 objectAttr, security.SEC_FLAG_MAXIMUM_ALLOWED)
+
+                name = lsa.String()
+                name.string = ctx.realm
+                info = lsaconn.QueryTrustedDomainInfoByName(pol_handle, name, lsa.LSA_TRUSTED_DOMAIN_INFO_FULL_INFO)
+
+                lsaconn.DeleteTrustedDomain(pol_handle, info.info_ex.sid)
+
+                name = lsa.String()
+                name.string = ctx.domain_name
+                info = lsaconn.QueryTrustedDomainInfoByName(pol_handle, name, lsa.LSA_TRUSTED_DOMAIN_INFO_FULL_INFO)
+
+                lsaconn.DeleteTrustedDomain(pol_handle, info.info_ex.sid)
+
         except Exception:
             pass
 
@@ -193,6 +226,13 @@ class dc_join(object):
                                expression='ncName=%s' % ctx.samdb.get_default_basedn())
         return res[0]["nETBIOSName"][0]
 
+    def get_parent_partition_dn(ctx):
+        '''get the parent domain partition DN from parent DNS name'''
+        res = ctx.samdb.search(base=ctx.config_dn, attrs=[],
+                               expression='(&(objectclass=crossRef)(dnsRoot=%s)(systemFlags:%s:=%u))' %
+                               (ctx.parent_dnsdomain, ldb.OID_COMPARATOR_AND, samba.dsdb.SYSTEM_FLAG_CR_NTDS_DOMAIN))
+        return str(res[0].dn)
+
     def get_mysid(ctx):
         '''get the SID of the connected user. Only works with w2k8 and later,
            so only used for RODC join'''
@@ -264,40 +304,47 @@ class dc_join(object):
         r.value_ctr = 1
 
 
-    def DsAddEntry(ctx, rec):
+    def DsAddEntry(ctx, recs):
         '''add a record via the DRSUAPI DsAddEntry call'''
         if ctx.drsuapi is None:
             ctx.drsuapi_connect()
         if ctx.tmp_samdb is None:
             ctx.create_tmp_samdb()
 
-        id = drsuapi.DsReplicaObjectIdentifier()
-        id.dn = rec['dn']
-
-        attrs = []
-        for a in rec:
-            if a == 'dn':
-                continue
-            if not isinstance(rec[a], list):
-                v = [rec[a]]
-            else:
-                v = rec[a]
-            rattr = ctx.tmp_samdb.dsdb_DsReplicaAttribute(ctx.tmp_samdb, a, v)
-            attrs.append(rattr)
-
-        attribute_ctr = drsuapi.DsReplicaAttributeCtr()
-        attribute_ctr.num_attributes = len(attrs)
-        attribute_ctr.attributes = attrs
-
-        object = drsuapi.DsReplicaObject()
-        object.identifier = id
-        object.attribute_ctr = attribute_ctr
-
-        first_object = drsuapi.DsReplicaObjectListItem()
-        first_object.object = object
+        objects = []
+        for rec in recs:
+            id = drsuapi.DsReplicaObjectIdentifier()
+            id.dn = rec['dn']
+
+            attrs = []
+            for a in rec:
+                if a == 'dn':
+                    continue
+                if not isinstance(rec[a], list):
+                    v = [rec[a]]
+                else:
+                    v = rec[a]
+                rattr = ctx.tmp_samdb.dsdb_DsReplicaAttribute(ctx.tmp_samdb, a, v)
+                attrs.append(rattr)
+
+            attribute_ctr = drsuapi.DsReplicaAttributeCtr()
+            attribute_ctr.num_attributes = len(attrs)
+            attribute_ctr.attributes = attrs
+
+            object = drsuapi.DsReplicaObject()
+            object.identifier = id
+            object.attribute_ctr = attribute_ctr
+
+            list_object = drsuapi.DsReplicaObjectListItem()
+            list_object.object = object
+            objects.append(list_object)
 
         req2 = drsuapi.DsAddEntryRequest2()
-        req2.first_object = first_object
+        req2.first_object = objects[0]
+        prev = req2.first_object
+        for o in objects[1:]:
+            prev.next_object = o
+            prev = o
 
         (level, ctr) = ctx.drsuapi.DsAddEntry(ctx.drsuapi_handle, 2, req2)
         if ctr.err_ver != 1:
@@ -306,6 +353,48 @@ class dc_join(object):
             print("DsAddEntry failed with status %s info %s" % (ctr.err_data.status,
                                                                 ctr.err_data.info.extended_err))
             raise RuntimeError("DsAddEntry failed")
+        if ctr.err_data.dir_err != drsuapi.DRSUAPI_DIRERR_OK:
+            print("DsAddEntry failed with dir_err %u" % ctr.err_data.dir_err)
+            raise RuntimeError("DsAddEntry failed")
+        return ctr.objects
+
+
+    def join_add_ntdsdsa(ctx):
+        '''add the ntdsdsa object'''
+        # FIXME: the partition (NC) assignment has to be made dynamic
+        print "Adding %s" % ctx.ntds_dn
+        rec = {
+            "dn" : ctx.ntds_dn,
+            "objectclass" : "nTDSDSA",
+            "systemFlags" : str(samba.dsdb.SYSTEM_FLAG_DISALLOW_MOVE_ON_DELETE),
+            "dMDLocation" : ctx.schema_dn}
+
+        nc_list = [ ctx.base_dn, ctx.config_dn, ctx.schema_dn ]
+
+        if ctx.behavior_version >= samba.dsdb.DS_DOMAIN_FUNCTION_2003:
+            rec["msDS-Behavior-Version"] = str(ctx.behavior_version)
+
+        if ctx.behavior_version >= samba.dsdb.DS_DOMAIN_FUNCTION_2003:
+            rec["msDS-HasDomainNCs"] = ctx.base_dn
+
+        if ctx.RODC:
+            rec["objectCategory"] = "CN=NTDS-DSA-RO,%s" % ctx.schema_dn
+            rec["msDS-HasFullReplicaNCs"] = nc_list
+            rec["options"] = "37"
+            ctx.samdb.add(rec, ["rodc_join:1:1"])
+        else:
+            rec["objectCategory"] = "CN=NTDS-DSA,%s" % ctx.schema_dn
+            rec["HasMasterNCs"]      = nc_list
+            if ctx.behavior_version >= samba.dsdb.DS_DOMAIN_FUNCTION_2003:
+                rec["msDS-HasMasterNCs"] = nc_list
+            rec["options"] = "1"
+            rec["invocationId"] = ndr_pack(ctx.invocation_id)
+            ctx.DsAddEntry([rec])
+
+        # find the GUID of our NTDS DN
+        res = ctx.samdb.search(base=ctx.ntds_dn, scope=ldb.SCOPE_BASE, attrs=["objectGUID"])
+        ctx.ntds_guid = misc.GUID(ctx.samdb.schema_format_value("objectGUID", res[0]["objectGUID"][0]))
+
 
     def join_add_objects(ctx):
         '''add the various objects needed for the join'''
@@ -335,9 +424,11 @@ class dc_join(object):
         rec = {
             "dn": ctx.server_dn,
             "objectclass" : "server",
+            # windows uses 50000000 decimal for systemFlags. A windows hex/decimal mixup bug?
             "systemFlags" : str(samba.dsdb.SYSTEM_FLAG_CONFIG_ALLOW_RENAME |
                                 samba.dsdb.SYSTEM_FLAG_CONFIG_ALLOW_LIMITED_MOVE |
                                 samba.dsdb.SYSTEM_FLAG_DISALLOW_MOVE_ON_DELETE),
+            # windows seems to add the dnsHostName later
             "dnsHostName" : ctx.dnshostname}
 
         if ctx.acct_dn:
@@ -345,37 +436,12 @@ class dc_join(object):
 
         ctx.samdb.add(rec)
 
-        # FIXME: the partition (NC) assignment has to be made dynamic
-        print "Adding %s" % ctx.ntds_dn
-        rec = {
-            "dn" : ctx.ntds_dn,
-            "objectclass" : "nTDSDSA",
-            "systemFlags" : str(samba.dsdb.SYSTEM_FLAG_DISALLOW_MOVE_ON_DELETE),
-            "dMDLocation" : ctx.schema_dn}
-
-        if ctx.behavior_version >= samba.dsdb.DS_DOMAIN_FUNCTION_2003:
-            rec["msDS-Behavior-Version"] = str(ctx.behavior_version)
-
-        if ctx.behavior_version >= samba.dsdb.DS_DOMAIN_FUNCTION_2003:
-            rec["msDS-HasDomainNCs"] = ctx.base_dn
-
-        if ctx.RODC:
-            rec["objectCategory"] = "CN=NTDS-DSA-RO,%s" % ctx.schema_dn
-            rec["msDS-HasFullReplicaNCs"] = [ ctx.base_dn, ctx.config_dn, ctx.schema_dn ]
-            rec["options"] = "37"
-            ctx.samdb.add(rec, ["rodc_join:1:1"])
-        else:
-            rec["objectCategory"] = "CN=NTDS-DSA,%s" % ctx.schema_dn
-            rec["HasMasterNCs"]      = [ ctx.base_dn, ctx.config_dn, ctx.schema_dn ]
-            if ctx.behavior_version >= samba.dsdb.DS_DOMAIN_FUNCTION_2003:
-                rec["msDS-HasMasterNCs"] = [ ctx.base_dn, ctx.config_dn, ctx.schema_dn ]
-            rec["options"] = "1"
-            rec["invocationId"] = ndr_pack(misc.GUID(str(uuid.uuid4())))
-            ctx.DsAddEntry(rec)
+        if ctx.subdomain:
+            # the rest is done after replication
+            ctx.ntds_guid = None
+            return
 
-        # find the GUID of our NTDS DN
-        res = ctx.samdb.search(base=ctx.ntds_dn, scope=ldb.SCOPE_BASE, attrs=["objectGUID"])
-        ctx.ntds_guid = misc.GUID(ctx.samdb.schema_format_value("objectGUID", res[0]["objectGUID"][0]))
+        ctx.join_add_ntdsdsa()
 
         if ctx.connection_dn is not None:
             print "Adding %s" % ctx.connection_dn
@@ -423,6 +489,66 @@ class dc_join(object):
                                                          "userAccountControl")
             ctx.samdb.modify(m)
 
+
+    def join_add_objects2(ctx):
+        '''add the various objects needed for the join, for subdomains post replication'''
+
+        print "Adding %s" % ctx.partition_dn
+        # NOTE: windows sends a ntSecurityDescriptor here, we
+        # let it default
+        rec = {
+            "dn" : ctx.partition_dn,
+            "objectclass" : "crossRef",
+            "objectCategory" : "CN=Cross-Ref,%s" % ctx.schema_dn,
+            "nCName" : ctx.base_dn,
+            "nETBIOSName" : ctx.domain_name,
+            "dnsRoot": ctx.dnsdomain,
+            "trustParent" : ctx.parent_partition_dn,
+            "systemFlags" : str(samba.dsdb.SYSTEM_FLAG_CR_NTDS_NC|samba.dsdb.SYSTEM_FLAG_CR_NTDS_DOMAIN)}
+        if ctx.behavior_version >= samba.dsdb.DS_DOMAIN_FUNCTION_2003:
+            rec["msDS-Behavior-Version"] = str(ctx.behavior_version)
+
+        rec2 = {
+            "dn" : ctx.ntds_dn,
+            "objectclass" : "nTDSDSA",
+            "systemFlags" : str(samba.dsdb.SYSTEM_FLAG_DISALLOW_MOVE_ON_DELETE),
+            "dMDLocation" : ctx.schema_dn}
+
+        nc_list = [ ctx.base_dn, ctx.config_dn, ctx.schema_dn ]
+
+        if ctx.behavior_version >= samba.dsdb.DS_DOMAIN_FUNCTION_2003:
+            rec2["msDS-Behavior-Version"] = str(ctx.behavior_version)
+
+        if ctx.behavior_version >= samba.dsdb.DS_DOMAIN_FUNCTION_2003:
+            rec2["msDS-HasDomainNCs"] = ctx.base_dn
+
+        rec2["objectCategory"] = "CN=NTDS-DSA,%s" % ctx.schema_dn
+        rec2["HasMasterNCs"]      = nc_list
+        if ctx.behavior_version >= samba.dsdb.DS_DOMAIN_FUNCTION_2003:
+            rec2["msDS-HasMasterNCs"] = nc_list
+        rec2["options"] = "1"
+        rec2["invocationId"] = ndr_pack(ctx.invocation_id)
+
+        objects = ctx.DsAddEntry([rec, rec2])
+        if len(objects) != 2:
+            raise DCJoinException("Expected 2 objects from DsAddEntry")
+
+        ctx.ntds_guid = objects[1].guid
+
+        print("Replicating partition DN")
+        ctx.repl.replicate(ctx.partition_dn,
+                           misc.GUID("00000000-0000-0000-0000-000000000000"),
+                           ctx.ntds_guid,
+                           exop=drsuapi.DRSUAPI_EXOP_REPL_OBJ,
+                           replica_flags=drsuapi.DRSUAPI_DRS_WRIT_REP)
+
+        print("Replicating NTDS DN")
+        ctx.repl.replicate(ctx.ntds_dn,
+                           misc.GUID("00000000-0000-0000-0000-000000000000"),
+                           ctx.ntds_guid,
+                           exop=drsuapi.DRSUAPI_EXOP_REPL_OBJ,
+                           replica_flags=drsuapi.DRSUAPI_DRS_WRIT_REP)
+
     def join_provision(ctx):
         '''provision the local SAM'''
 
@@ -450,7 +576,23 @@ class dc_join(object):
     def join_provision_own_domain(ctx):
         '''provision the local SAM'''
 
-        print "Calling bare provision"
+        # we now operate exclusively on the local database, which
+        # we need to reopen in order to get the newly created schema
+        print("Reconnecting to local samdb")
+        ctx.samdb = SamDB(url=ctx.local_samdb.url,
+                          session_info=system_session(),
+                          lp=ctx.local_samdb.lp,
+                          global_schema=False)
+        ctx.samdb.set_invocation_id(str(ctx.invocation_id))
+        ctx.local_samdb = ctx.samdb
+
+        print("Finding domain GUID from ncName")
+        res = ctx.local_samdb.search(base=ctx.partition_dn, scope=ldb.SCOPE_BASE, attrs=['ncName'],
+                                     controls=["extended_dn:1:1"])
+        domguid = str(misc.GUID(ldb.Dn(ctx.samdb, res[0]['ncName'][0]).get_extended_component('GUID')))
+        print("Got domain GUID %s" % domguid)
+
+        print("Calling own domain provision")
 
         logger = logging.getLogger("provision")
         logger.addHandler(logging.StreamHandler(sys.stdout))
@@ -459,10 +601,11 @@ class dc_join(object):
 
         presult = provision_fill(ctx.local_samdb, secrets_ldb,
                                  logger, ctx.names, ctx.paths, domainsid=security.dom_sid(ctx.domsid),
+                                 domainguid=domguid,
                                  targetdir=ctx.targetdir, samdb_fill=FILL_SUBDOMAIN,
                                  machinepass=ctx.acct_pass, serverrole="domain controller",
-                                 lp=ctx.lp)
-        print "Provision OK for domain %s" % ctx.names.dnsdomain
+                                 lp=ctx.lp, hostip=ctx.names.hostip, hostip6=ctx.names.hostip6)
+        print("Provision OK for domain %s" % ctx.names.dnsdomain)
 
 
     def join_replicate(ctx):
@@ -472,7 +615,11 @@ class dc_join(object):
         ctx.local_samdb.transaction_start()
         try:
             source_dsa_invocation_id = misc.GUID(ctx.samdb.get_invocation_id())
-            destination_dsa_guid = ctx.ntds_guid
+            if ctx.ntds_guid is None:
+                print("Using DS_BIND_GUID_W2K3")
+                destination_dsa_guid = misc.GUID(drsuapi.DRSUAPI_DS_BIND_GUID_W2K3)
+            else:
+                destination_dsa_guid = ctx.ntds_guid
 
             if ctx.RODC:
                 repl_creds = Credentials()
@@ -507,6 +654,9 @@ class dc_join(object):
                 repl.replicate(ctx.new_krbtgt_dn, source_dsa_invocation_id,
                         destination_dsa_guid,
                         exop=drsuapi.DRSUAPI_EXOP_REPL_SECRET, rodc=True)
+            ctx.repl = repl
+            ctx.source_dsa_invocation_id = source_dsa_invocation_id
+            ctx.destination_dsa_guid = destination_dsa_guid
 
             print "Committing SAM database"
         except:
@@ -527,6 +677,9 @@ class dc_join(object):
                                                 ldb.FLAG_MOD_REPLACE, "dsServiceName")
         ctx.local_samdb.modify(m)
 
+        if ctx.subdomain:
+            return
+
         secrets_ldb = Ldb(ctx.paths.secrets, session_info=system_session(), lp=ctx.lp)
 
         print "Setting up secrets database"
@@ -539,6 +692,119 @@ class dc_join(object):
                             secure_channel_type=ctx.secure_channel_type,
                             key_version_number=ctx.key_version_number)
 
+    def join_setup_trusts(ctx):
+        '''provision the local SAM'''
+
+        def arcfour_encrypt(key, data):
+            from Crypto.Cipher import ARC4
+            c = ARC4.new(key)
+            return c.encrypt(data)
+
+        def string_to_array(string):
+            blob = [0] * len(string)
+
+            for i in range(len(string)):
+                blob[i] = ord(string[i])
+
+            return blob
+
+        print "Setup domain trusts with server %s" % ctx.server
+        binding_options = ""  # why doesn't signing work gere? w2k8r2 claims no session key
+        lsaconn = lsa.lsarpc("ncacn_np:%s[%s]" % (ctx.server, binding_options),
+                             ctx.lp, ctx.creds)
+
+        objectAttr = lsa.ObjectAttribute()
+        objectAttr.sec_qos = lsa.QosInfo()
+
+        pol_handle = lsaconn.OpenPolicy2(''.decode('utf-8'),
+                                         objectAttr, security.SEC_FLAG_MAXIMUM_ALLOWED)
+
+        info = lsa.TrustDomainInfoInfoEx()
+        info.domain_name.string = ctx.dnsdomain
+        info.netbios_name.string = ctx.domain_name
+        info.sid = security.dom_sid(ctx.domsid)
+        info.trust_direction = lsa.LSA_TRUST_DIRECTION_INBOUND | lsa.LSA_TRUST_DIRECTION_OUTBOUND
+        info.trust_type = lsa.LSA_TRUST_TYPE_UPLEVEL
+        info.trust_attributes = lsa.LSA_TRUST_ATTRIBUTE_WITHIN_FOREST
+
+        try:
+            oldname = lsa.String()
+            oldname.string = ctx.dnsdomain
+            oldinfo = lsaconn.QueryTrustedDomainInfoByName(pol_handle, oldname,
+                                                           lsa.LSA_TRUSTED_DOMAIN_INFO_FULL_INFO)
+            print("Removing old trust record for %s (SID %s)" % (ctx.dnsdomain, oldinfo.info_ex.sid))
+            lsaconn.DeleteTrustedDomain(pol_handle, oldinfo.info_ex.sid)
+        except RuntimeError:
+            pass
+
+        password_blob = string_to_array(ctx.trustdom_pass.encode('utf-16-le'))
+
+        clear_value = drsblobs.AuthInfoClear()
+        clear_value.size = len(password_blob)
+        clear_value.password = password_blob
+
+        clear_authentication_information = drsblobs.AuthenticationInformation()
+        clear_authentication_information.LastUpdateTime = samba.unix2nttime(int(time.time()))
+        clear_authentication_information.AuthType = lsa.TRUST_AUTH_TYPE_CLEAR
+        clear_authentication_information.AuthInfo = clear_value
+
+        authentication_information_array = drsblobs.AuthenticationInformationArray()
+        authentication_information_array.count = 1
+        authentication_information_array.array = [clear_authentication_information]
+
+        outgoing = drsblobs.trustAuthInOutBlob()
+        outgoing.count = 1
+        outgoing.current = authentication_information_array
+
+        trustpass = drsblobs.trustDomainPasswords()
+        confounder = [3] * 512
+
+        for i in range(512):
+            confounder[i] = random.randint(0, 255)
+
+        trustpass.confounder = confounder
+
+        trustpass.outgoing = outgoing
+        trustpass.incoming = outgoing
+
+        trustpass_blob = ndr_pack(trustpass)
+
+        encrypted_trustpass = arcfour_encrypt(lsaconn.session_key, trustpass_blob)
+
+        auth_blob = lsa.DATA_BUF2()
+        auth_blob.size = len(encrypted_trustpass)
+        auth_blob.data = string_to_array(encrypted_trustpass)
+
+        auth_info = lsa.TrustDomainInfoAuthInfoInternal()
+        auth_info.auth_blob = auth_blob
+
+        trustdom_handle = lsaconn.CreateTrustedDomainEx2(pol_handle,
+                                                         info,
+                                                         auth_info,
+                                                         security.SEC_STD_DELETE)
+
+        rec = {
+            "dn" : "cn=%s,cn=system,%s" % (ctx.parent_dnsdomain, ctx.base_dn),
+            "objectclass" : "trustedDomain",
+            "trustType" : str(info.trust_type),
+            "trustAttributes" : str(info.trust_attributes),
+            "trustDirection" : str(info.trust_direction),
+            "flatname" : ctx.parent_domain_name,
+            "trustPartner" : ctx.parent_dnsdomain,
+            "trustAuthIncoming" : ndr_pack(outgoing),
+            "trustAuthOutgoing" : ndr_pack(outgoing)
+            }
+        ctx.local_samdb.add(rec)
+
+        rec = {
+            "dn" : "cn=%s$,cn=users,%s" % (ctx.parent_domain_name, ctx.base_dn),
+            "objectclass" : "user",
+            "userAccountControl" : str(samba.dsdb.UF_INTERDOMAIN_TRUST_ACCOUNT),
+            "clearTextPassword" : ctx.trustdom_pass.encode('utf-16-le')
+            }
+        ctx.local_samdb.add(rec)
+
+
     def do_join(ctx):
         ctx.cleanup_old_join()
         try:
@@ -546,12 +812,13 @@ class dc_join(object):
             ctx.join_provision()
             ctx.join_replicate()
             if ctx.subdomain:
+                ctx.join_add_objects2()
                 ctx.join_provision_own_domain()
-            else:
-                ctx.join_finalise()
+                ctx.join_setup_trusts()
+            ctx.join_finalise()
         except Exception:
             print "Join failed - cleaning up"
-            ctx.cleanup_old_join()
+            #ctx.cleanup_old_join()
             raise
 
 
@@ -640,13 +907,18 @@ def join_subdomain(server=None, creds=None, lp=None, site=None, netbios_name=Non
     """join as a DC"""
     ctx = dc_join(server, creds, lp, site, netbios_name, targetdir, parent_domain)
     ctx.subdomain = True
+    ctx.parent_domain_name = ctx.domain_name
     ctx.domain_name = netbios_domain
     ctx.realm = dnsdomain
+    ctx.parent_dnsdomain = ctx.dnsdomain
+    ctx.parent_partition_dn = ctx.get_parent_partition_dn()
     ctx.dnsdomain = dnsdomain
+    ctx.partition_dn = "CN=%s,CN=Partitions,%s" % (ctx.domain_name, ctx.config_dn)
     ctx.base_dn = samba.dn_from_dns_name(dnsdomain)
     ctx.domsid = str(security.random_sid())
     ctx.acct_dn = None
     ctx.dnshostname = "%s.%s" % (ctx.myname, ctx.dnsdomain)
+    ctx.trustdom_pass = samba.generate_random_password(128, 128)
 
     ctx.userAccountControl = samba.dsdb.UF_SERVER_TRUST_ACCOUNT | samba.dsdb.UF_TRUSTED_FOR_DELEGATION