samba3-python: Add methods to get any entry (user/group) and its sid from idmap
[idra/samba.git] / source4 / scripting / python / samba / join.py
index 12df25a866a2cfcf27fea0e7401e5d1d0dcda64f..00f2c54211f8abe7bb7720a91840213d66a7f567 100644 (file)
@@ -27,7 +27,7 @@ import ldb, samba, sys, os, uuid
 from samba.ndr import ndr_pack
 from samba.dcerpc import security, drsuapi, misc, nbt
 from samba.credentials import Credentials, DONT_USE_KERBEROS
-from samba.provision import secretsdb_self_join, provision, FILL_DRS, find_setup_dir
+from samba.provision import secretsdb_self_join, provision, FILL_DRS
 from samba.schema import Schema
 from samba.net import Net
 import logging
@@ -36,11 +36,17 @@ import talloc
 # this makes debugging easier
 talloc.enable_null_tracking()
 
-class dc_join:
+class DCJoinException(Exception):
+
+    def __init__(self, msg):
+        super(DCJoinException, self).__init__("Can't join, error: %s" % msg)
+
+
+class dc_join(object):
     '''perform a DC join'''
 
-    def __init__(ctx, server=None, creds=None, lp=None, site=None, netbios_name=None,
-                 targetdir=None, domain=None):
+    def __init__(ctx, server=None, creds=None, lp=None, site=None,
+            netbios_name=None, targetdir=None, domain=None):
         ctx.creds = creds
         ctx.lp = lp
         ctx.site = site
@@ -61,6 +67,12 @@ class dc_join:
                           session_info=system_session(),
                           credentials=ctx.creds, lp=ctx.lp)
 
+        try:
+            ctx.samdb.search(scope=ldb.SCOPE_ONELEVEL, attrs=["dn"])
+        except ldb.LdbError, (enum, estr):
+            raise DCJoinException(estr)
+
+
         ctx.myname = netbios_name
         ctx.samname = "%s$" % ctx.myname
         ctx.base_dn = str(ctx.samdb.get_default_basedn())
@@ -99,7 +111,6 @@ class dc_join:
 
         ctx.acct_dn = "CN=%s,OU=Domain Controllers,%s" % (ctx.myname, ctx.base_dn)
 
-        ctx.setup_dir = find_setup_dir()
         ctx.tmp_samdb = None
 
         ctx.SPNs = [ "HOST/%s" % ctx.myname,
@@ -120,14 +131,14 @@ class dc_join:
         if recursive:
             try:
                 res = ctx.samdb.search(base=dn, scope=ldb.SCOPE_ONELEVEL, attrs=["dn"])
-            except:
+            except Exception:
                 return
             for r in res:
                 ctx.del_noerror(r.dn, recursive=True)
         try:
             ctx.samdb.delete(dn)
             print "Deleted %s" % dn
-        except:
+        except Exception:
             pass
 
     def cleanup_old_join(ctx):
@@ -136,7 +147,7 @@ class dc_join:
             # find the krbtgt link
             print("checking samaccountname")
             res = ctx.samdb.search(base=ctx.samdb.get_default_basedn(),
-                                   expression='samAccountName=%s' % ctx.samname,
+                                   expression='samAccountName=%s' % ldb.binary_encode(ctx.samname),
                                    attrs=["msDS-krbTgtLink"])
             if res:
                 ctx.del_noerror(res[0].dn, recursive=True)
@@ -151,16 +162,15 @@ class dc_join:
             if res:
                 ctx.new_krbtgt_dn = res[0]["msDS-Krbtgtlink"][0]
                 ctx.del_noerror(ctx.new_krbtgt_dn)
-        except:
+        except Exception:
             pass
 
     def find_dc(ctx, domain):
         '''find a writeable DC for the given domain'''
         try:
             ctx.cldap_ret = ctx.net.finddc(domain, nbt.NBT_SERVER_LDAP | nbt.NBT_SERVER_DS | nbt.NBT_SERVER_WRITABLE)
-        except Exception, reason:
-            print("Failed to find a writeable DC for domain '%s': %s" % (domain, reason))
-            sys.exit(1)
+        except Exception:
+            raise Exception("Failed to find a writeable DC for domain '%s'" % domain)
         if ctx.cldap_ret.client_site is not None and ctx.cldap_ret.client_site != "":
             ctx.site = ctx.cldap_ret.client_site
         return ctx.cldap_ret.pdc_dns_name
@@ -199,8 +209,10 @@ class dc_join:
         '''check if a DN exists'''
         try:
             res = ctx.samdb.search(base=dn, scope=ldb.SCOPE_BASE, attrs=[])
-        except ldb.LdbError, (ERR_NO_SUCH_OBJECT, _):
-            return False
+        except ldb.LdbError, (enum, estr):
+            if enum == ldb.ERR_NO_SUCH_OBJECT:
+                return False
+            raise
         return True
 
     def add_krbtgt_account(ctx):
@@ -235,7 +247,7 @@ class dc_join:
     def drsuapi_connect(ctx):
         '''make a DRSUAPI connection to the server'''
         binding_options = "seal"
-        if ctx.lp.get("log level") >= 5:
+        if int(ctx.lp.get("log level")) >= 5:
             binding_options += ",print"
         binding_string = "ncacn_ip_tcp:%s[%s]" % (ctx.server, binding_options)
         ctx.drsuapi = drsuapi.drsuapi(binding_string, ctx.lp, ctx.creds)
@@ -243,9 +255,7 @@ class dc_join:
 
     def create_tmp_samdb(ctx):
         '''create a temporary samdb object for schema queries'''
-        def setup_path(file):
-            return os.path.join(ctx.setup_dir, file)
-        ctx.tmp_schema = Schema(setup_path, security.dom_sid(ctx.domsid),
+        ctx.tmp_schema = Schema(security.dom_sid(ctx.domsid),
                                 schemadn=ctx.schema_dn)
         ctx.tmp_samdb = SamDB(session_info=system_session(), url=None, auto_connect=False,
                               credentials=ctx.creds, lp=ctx.lp, global_schema=False,
@@ -398,7 +408,7 @@ class dc_join:
         ctx.samdb.modify(m)
 
         print "Setting account password for %s" % ctx.samname
-        ctx.samdb.setpassword("(&(objectClass=user)(sAMAccountName=%s))" % ctx.samname,
+        ctx.samdb.setpassword("(&(objectClass=user)(sAMAccountName=%s))" % ldb.binary_encode(ctx.samname),
                               ctx.acct_pass,
                               force_change_at_next_login=False,
                               username=ctx.samname)
@@ -422,7 +432,7 @@ class dc_join:
         logger.addHandler(logging.StreamHandler(sys.stdout))
         smbconf = ctx.lp.configfile
 
-        presult = provision(ctx.setup_dir, logger, system_session(), None,
+        presult = provision(logger, system_session(), None,
                             smbconf=smbconf, targetdir=ctx.targetdir, samdb_fill=FILL_DRS,
                             realm=ctx.realm, rootdn=ctx.root_dn, domaindn=ctx.base_dn,
                             schemadn=ctx.schema_dn,
@@ -442,50 +452,61 @@ class dc_join:
 
         print "Starting replication"
         ctx.local_samdb.transaction_start()
-
-        source_dsa_invocation_id = misc.GUID(ctx.samdb.get_invocation_id())
-        destination_dsa_guid = ctx.ntds_guid
-
-        if ctx.RODC:
-            repl_creds = Credentials()
-            repl_creds.guess(ctx.lp)
-            repl_creds.set_kerberos_state(DONT_USE_KERBEROS)
-            repl_creds.set_username(ctx.samname)
-            repl_creds.set_password(ctx.acct_pass)
+        try:
+            source_dsa_invocation_id = misc.GUID(ctx.samdb.get_invocation_id())
+            destination_dsa_guid = ctx.ntds_guid
+
+            if ctx.RODC:
+                repl_creds = Credentials()
+                repl_creds.guess(ctx.lp)
+                repl_creds.set_kerberos_state(DONT_USE_KERBEROS)
+                repl_creds.set_username(ctx.samname)
+                repl_creds.set_password(ctx.acct_pass)
+            else:
+                repl_creds = ctx.creds
+
+            binding_options = "seal"
+            if int(ctx.lp.get("log level")) >= 5:
+                binding_options += ",print"
+            repl = drs_utils.drs_Replicate(
+                "ncacn_ip_tcp:%s[%s]" % (ctx.server, binding_options),
+                ctx.lp, repl_creds, ctx.local_samdb)
+
+            repl.replicate(ctx.schema_dn, source_dsa_invocation_id,
+                    destination_dsa_guid, schema=True, rodc=ctx.RODC,
+                    replica_flags=ctx.replica_flags)
+            repl.replicate(ctx.config_dn, source_dsa_invocation_id,
+                    destination_dsa_guid, rodc=ctx.RODC,
+                    replica_flags=ctx.replica_flags)
+            repl.replicate(ctx.base_dn, source_dsa_invocation_id,
+                    destination_dsa_guid, rodc=ctx.RODC,
+                    replica_flags=ctx.domain_replica_flags)
+            if ctx.RODC:
+                repl.replicate(ctx.acct_dn, source_dsa_invocation_id,
+                        destination_dsa_guid,
+                        exop=drsuapi.DRSUAPI_EXOP_REPL_SECRET, rodc=True)
+                repl.replicate(ctx.new_krbtgt_dn, source_dsa_invocation_id,
+                        destination_dsa_guid,
+                        exop=drsuapi.DRSUAPI_EXOP_REPL_SECRET, rodc=True)
+
+            print "Committing SAM database"
+        except:
+            ctx.local_samdb.transaction_cancel()
+            raise
         else:
-            repl_creds = ctx.creds
-
-        binding_options = "seal"
-        if ctx.lp.get("debug level") >= 5:
-            binding_options += ",print"
-        repl = drs_utils.drs_Replicate("ncacn_ip_tcp:%s[%s]" % (ctx.server, binding_options),
-                                       ctx.lp, repl_creds, ctx.local_samdb)
-
-        repl.replicate(ctx.schema_dn, source_dsa_invocation_id, destination_dsa_guid,
-                       schema=True, rodc=ctx.RODC,
-                       replica_flags=ctx.replica_flags)
-        repl.replicate(ctx.config_dn, source_dsa_invocation_id, destination_dsa_guid,
-                       rodc=ctx.RODC, replica_flags=ctx.replica_flags)
-        repl.replicate(ctx.base_dn, source_dsa_invocation_id, destination_dsa_guid,
-                       rodc=ctx.RODC, replica_flags=ctx.replica_flags)
-        if ctx.RODC:
-            repl.replicate(ctx.acct_dn, source_dsa_invocation_id, destination_dsa_guid,
-                           exop=drsuapi.DRSUAPI_EXOP_REPL_SECRET, rodc=True)
-            repl.replicate(ctx.new_krbtgt_dn, source_dsa_invocation_id, destination_dsa_guid,
-                           exop=drsuapi.DRSUAPI_EXOP_REPL_SECRET, rodc=True)
-
-        print "Committing SAM database"
-        ctx.local_samdb.transaction_commit()
+            ctx.local_samdb.transaction_commit()
 
 
     def join_finalise(ctx):
         '''finalise the join, mark us synchronised and setup secrets db'''
 
-        print "Setting isSynchronized"
+        print "Setting isSynchronized and dsServiceName"
         m = ldb.Message()
-        m.dn = ldb.Dn(ctx.samdb, '@ROOTDSE')
+        m.dn = ldb.Dn(ctx.local_samdb, '@ROOTDSE')
         m["isSynchronized"] = ldb.MessageElement("TRUE", ldb.FLAG_MOD_REPLACE, "isSynchronized")
-        ctx.samdb.modify(m)
+        m["dsServiceName"] = ldb.MessageElement("<GUID=%s>" % str(ctx.ntds_guid),
+                                                ldb.FLAG_MOD_REPLACE, "dsServiceName")
+        ctx.local_samdb.modify(m)
 
         secrets_ldb = Ldb(ctx.paths.secrets, session_info=system_session(), lp=ctx.lp)
 
@@ -506,14 +527,14 @@ class dc_join:
             ctx.join_provision()
             ctx.join_replicate()
             ctx.join_finalise()
-        except:
+        except Exception:
             print "Join failed - cleaning up"
             ctx.cleanup_old_join()
             raise
 
 
 def join_RODC(server=None, creds=None, lp=None, site=None, netbios_name=None,
-              targetdir=None, domain=None):
+              targetdir=None, domain=None, domain_critical_only=False):
     """join as a RODC"""
 
     ctx = dc_join(server, creds, lp, site, netbios_name, targetdir, domain)
@@ -548,6 +569,10 @@ def join_RODC(server=None, creds=None, lp=None, site=None, netbios_name=None,
                            drsuapi.DRSUAPI_DRS_NEVER_SYNCED |
                            drsuapi.DRSUAPI_DRS_SPECIAL_SECRET_PROCESSING |
                            drsuapi.DRSUAPI_DRS_GET_ALL_GROUP_MEMBERSHIP)
+    ctx.domain_replica_flags = ctx.replica_flags
+    if domain_critical_only:
+        ctx.domain_replica_flags |= drsuapi.DRSUAPI_DRS_CRITICAL_ONLY
+
     ctx.do_join()
 
 
@@ -555,7 +580,7 @@ def join_RODC(server=None, creds=None, lp=None, site=None, netbios_name=None,
 
 
 def join_DC(server=None, creds=None, lp=None, site=None, netbios_name=None,
-            targetdir=None, domain=None):
+            targetdir=None, domain=None, domain_critical_only=False):
     """join as a DC"""
     ctx = dc_join(server, creds, lp, site, netbios_name, targetdir, domain)
 
@@ -569,6 +594,9 @@ def join_DC(server=None, creds=None, lp=None, site=None, netbios_name=None,
                          drsuapi.DRSUAPI_DRS_PER_SYNC |
                          drsuapi.DRSUAPI_DRS_FULL_SYNC_IN_PROGRESS |
                          drsuapi.DRSUAPI_DRS_NEVER_SYNCED)
+    ctx.domain_replica_flags = ctx.replica_flags
+    if domain_critical_only:
+        ctx.domain_replica_flags |= drsuapi.DRSUAPI_DRS_CRITICAL_ONLY
 
     ctx.do_join()
     print "Joined domain %s (SID %s) as a DC" % (ctx.domain_name, ctx.domsid)