python/samba/netcmd: changes for samab.tests.samba_tool.computer
[samba.git] / python / samba / samdb.py
index 5b04c982673343b18e589f6ff251ff83a26652a6..fa3e4b1b33405a41849fcf766bdcae389b8885de 100644 (file)
@@ -27,14 +27,21 @@ import ldb
 import time
 import base64
 import os
+import re
 from samba import dsdb, dsdb_dns
 from samba.ndr import ndr_unpack, ndr_pack
 from samba.dcerpc import drsblobs, misc
 from samba.common import normalise_int32
+from samba.compat import text_type
+from samba.dcerpc import security
 
 __docformat__ = "restructuredText"
 
 
+def get_default_backend_store():
+    return "tdb"
+
+
 class SamDB(samba.Ldb):
     """The SAM database."""
 
@@ -54,8 +61,8 @@ class SamDB(samba.Ldb):
         self.url = url
 
         super(SamDB, self).__init__(url=url, lp=lp, modules_dir=modules_dir,
-            session_info=session_info, credentials=credentials, flags=flags,
-            options=options)
+                                    session_info=session_info, credentials=credentials, flags=flags,
+                                    options=options)
 
         if global_schema:
             dsdb._dsdb_set_global_schema(self)
@@ -70,7 +77,7 @@ class SamDB(samba.Ldb):
         self.url = url
 
         super(SamDB, self).connect(url=url, flags=flags,
-                options=options)
+                                   options=options)
 
     def am_rodc(self):
         '''return True if we are an RODC'''
@@ -84,6 +91,10 @@ class SamDB(samba.Ldb):
         '''return the domain DN'''
         return str(self.get_default_basedn())
 
+    def schema_dn(self):
+        '''return the schema partition dn'''
+        return str(self.get_schema_basedn())
+
     def disable_account(self, search_filter):
         """Disables an account
 
@@ -190,8 +201,8 @@ pwdLastSet: 0
         # The new user record. Note the reliance on the SAMLDB module which
         # fills in the default informations
         ldbmessage = {"dn": group_dn,
-            "sAMAccountName": groupname,
-            "objectClass": "group"}
+                      "sAMAccountName": groupname,
+                      "objectClass": "group"}
 
         if grouptype is not None:
             ldbmessage["groupType"] = normalise_int32(grouptype)
@@ -227,7 +238,7 @@ pwdLastSet: 0
         self.transaction_start()
         try:
             targetgroup = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
-                               expression=groupfilter, attrs=[])
+                                      expression=groupfilter, attrs=[])
             if len(targetgroup) == 0:
                 raise Exception('Unable to find group "%s"' % groupname)
             assert(len(targetgroup) == 1)
@@ -239,7 +250,7 @@ pwdLastSet: 0
             self.transaction_commit()
 
     def add_remove_group_members(self, groupname, members,
-                                  add_members_operation=True):
+                                 add_members_operation=True):
         """Adds or removes group members
 
         :param groupname: Name of the target group
@@ -254,7 +265,7 @@ pwdLastSet: 0
         self.transaction_start()
         try:
             targetgroup = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
-                               expression=groupfilter, attrs=['member'])
+                                      expression=groupfilter, attrs=['member'])
             if len(targetgroup) == 0:
                 raise Exception('Unable to find group "%s"' % groupname)
             assert(len(targetgroup) == 1)
@@ -269,25 +280,40 @@ changetype: modify
             for member in members:
                 filter = ('(&(sAMAccountName=%s)(|(objectclass=user)'
                           '(objectclass=group)))' % ldb.binary_encode(member))
+                foreign_msg = None
+                try:
+                    membersid = security.dom_sid(member)
+                except TypeError as e:
+                    membersid = None
+
+                if membersid is not None:
+                    filter = '(objectSid=%s)' % str(membersid)
+                    dn_str = "<SID=%s>" % str(membersid)
+                    foreign_msg = ldb.Message()
+                    foreign_msg.dn = ldb.Dn(self, dn_str)
+
                 targetmember = self.search(base=self.domain_dn(),
                                            scope=ldb.SCOPE_SUBTREE,
                                            expression="%s" % filter,
                                            attrs=[])
 
+                if len(targetmember) == 0 and foreign_msg is not None:
+                    targetmember = [foreign_msg]
                 if len(targetmember) != 1:
                     raise Exception('Unable to find "%s". Operation cancelled.' % member)
+                targetmember_dn = targetmember[0].dn.extended_str(1)
 
-                if add_members_operation is True and (targetgroup[0].get('member') is None or str(targetmember[0].dn) not in targetgroup[0]['member']):
+                if add_members_operation is True and (targetgroup[0].get('member') is None or str(targetmember_dn) not in targetgroup[0]['member']):
                     modified = True
                     addtargettogroup += """add: member
 member: %s
-""" % (str(targetmember[0].dn))
+""" % (str(targetmember_dn))
 
-                elif add_members_operation is False and (targetgroup[0].get('member') is not None and str(targetmember[0].dn) in targetgroup[0]['member']):
+                elif add_members_operation is False and (targetgroup[0].get('member') is not None and targetmember_dn in targetgroup[0]['member']):
                     modified = True
                     addtargettogroup += """delete: member
 member: %s
-""" % (str(targetmember[0].dn))
+""" % (str(targetmember_dn))
 
             if modified is True:
                 self.modify_ldif(addtargettogroup)
@@ -299,15 +325,15 @@ member: %s
             self.transaction_commit()
 
     def newuser(self, username, password,
-            force_password_change_at_next_login_req=False,
-            useusernameascn=False, userou=None, surname=None, givenname=None,
-            initials=None, profilepath=None, scriptpath=None, homedrive=None,
-            homedirectory=None, jobtitle=None, department=None, company=None,
-            description=None, mailaddress=None, internetaddress=None,
-            telephonenumber=None, physicaldeliveryoffice=None, sd=None,
-            setpassword=True, uidnumber=None, gidnumber=None, gecos=None,
-            loginshell=None, uid=None, nisdomain=None, unixhome=None,
-            smartcard_required=False):
+                force_password_change_at_next_login_req=False,
+                useusernameascn=False, userou=None, surname=None, givenname=None,
+                initials=None, profilepath=None, scriptpath=None, homedrive=None,
+                homedirectory=None, jobtitle=None, department=None, company=None,
+                description=None, mailaddress=None, internetaddress=None,
+                telephonenumber=None, physicaldeliveryoffice=None, sd=None,
+                setpassword=True, uidnumber=None, gidnumber=None, gecos=None,
+                loginshell=None, uid=None, nisdomain=None, unixhome=None,
+                smartcard_required=False):
         """Adds a new user with additional parameters
 
         :param username: Name of the new user
@@ -369,7 +395,8 @@ member: %s
                       "objectClass": "user"}
 
         if smartcard_required:
-            ldbmessage["userAccountControl"] = str(dsdb.UF_NORMAL_ACCOUNT|dsdb.UF_SMARTCARD_REQUIRED)
+            ldbmessage["userAccountControl"] = str(dsdb.UF_NORMAL_ACCOUNT |
+                                                   dsdb.UF_SMARTCARD_REQUIRED)
             setpassword = False
 
         if surname is not None:
@@ -426,7 +453,7 @@ member: %s
 
         ldbmessage2 = None
         if any(map(lambda b: b is not None, (uid, uidnumber, gidnumber, gecos,
-                loginshell, nisdomain, unixhome))):
+                                             loginshell, nisdomain, unixhome))):
             ldbmessage2 = ldb.Message()
             ldbmessage2.dn = ldb.Dn(self, user_dn)
             if uid is not None:
@@ -459,7 +486,9 @@ member: %s
 
             # Sets the password for it
             if setpassword:
-                self.setpassword("(samAccountName=%s)" % ldb.binary_encode(username), password,
+                self.setpassword(("(distinguishedName=%s)" %
+                                  ldb.binary_encode(user_dn)),
+                                 password,
                                  force_password_change_at_next_login_req)
         except:
             self.transaction_cancel()
@@ -467,6 +496,65 @@ member: %s
         else:
             self.transaction_commit()
 
+    def newcomputer(self, computername, computerou=None, description=None,
+                    prepare_oldjoin=False, ip_address_list=None,
+                    service_principal_name_list=None):
+        """Adds a new user with additional parameters
+
+        :param computername: Name of the new computer
+        :param computerou: Object container for new computer
+        :param description: Description of the new computer
+        :param prepare_oldjoin: Preset computer password for oldjoin mechanism
+        :param ip_address_list: ip address list for DNS A or AAAA record
+        :param service_principal_name_list: string list of servicePincipalName
+        """
+
+        cn = re.sub(r"\$$", "", computername)
+        if cn.count('$'):
+            raise Exception('Illegal computername "%s"' % computername)
+        samaccountname = "%s$" % cn
+
+        computercontainer_dn = "CN=Computers,%s" % self.domain_dn()
+        if computerou:
+            computercontainer_dn = self.normalize_dn_in_domain(computerou)
+
+        computer_dn = "CN=%s,%s" % (cn, computercontainer_dn)
+
+        ldbmessage = {"dn": computer_dn,
+                      "sAMAccountName": samaccountname,
+                      "objectClass": "computer",
+                      }
+
+        if description is not None:
+            ldbmessage["description"] = description
+
+        if service_principal_name_list:
+            ldbmessage["servicePrincipalName"] = service_principal_name_list
+
+        accountcontrol = str(dsdb.UF_WORKSTATION_TRUST_ACCOUNT |
+                             dsdb.UF_ACCOUNTDISABLE)
+        if prepare_oldjoin:
+            accountcontrol = str(dsdb.UF_WORKSTATION_TRUST_ACCOUNT)
+        ldbmessage["userAccountControl"] = accountcontrol
+
+        if ip_address_list:
+            ldbmessage['dNSHostName'] = '{}.{}'.format(
+                cn, self.domain_dns_name())
+
+        self.transaction_start()
+        try:
+            self.add(ldbmessage)
+
+            if prepare_oldjoin:
+                password = cn.lower()
+                self.setpassword(("(distinguishedName=%s)" %
+                                  ldb.binary_encode(computer_dn)),
+                                 password, False)
+        except:
+            self.transaction_cancel()
+            raise
+        else:
+            self.transaction_commit()
 
     def deleteuser(self, username):
         """Deletes a user
@@ -490,7 +578,7 @@ member: %s
             self.transaction_commit()
 
     def setpassword(self, search_filter, password,
-            force_change_at_next_login=False, username=None):
+                    force_change_at_next_login=False, username=None):
         """Sets the password for a user
 
         :param search_filter: LDAP filter to find the user (eg
@@ -507,19 +595,23 @@ member: %s
             if len(res) > 1:
                 raise Exception('Matched %u multiple users with filter "%s"' % (len(res), search_filter))
             user_dn = res[0].dn
-            pw = unicode('"' + password.encode('utf-8') + '"', 'utf-8').encode('utf-16-le')
+            if not isinstance(password, text_type):
+                pw = password.decode('utf-8')
+            else:
+                pw = password
+            pw = ('"' + pw + '"').encode('utf-16-le')
             setpw = """
 dn: %s
 changetype: modify
 replace: unicodePwd
 unicodePwd:: %s
-""" % (user_dn, base64.b64encode(pw))
+""" % (user_dn, base64.b64encode(pw).decode('utf-8'))
 
             self.modify_ldif(setpw)
 
             if force_change_at_next_login:
                 self.force_password_change_at_next_login(
-                  "(distinguishedName=" + str(user_dn) + ")")
+                    "(distinguishedName=" + str(user_dn) + ")")
 
             #  modify the userAccountControl to remove the disabled bit
             self.enable_account(search_filter)
@@ -540,8 +632,8 @@ unicodePwd:: %s
         self.transaction_start()
         try:
             res = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
-                          expression=search_filter,
-                          attrs=["userAccountControl", "accountExpires"])
+                              expression=search_filter,
+                              attrs=["userAccountControl", "accountExpires"])
             if len(res) == 0:
                 raise Exception('Unable to find user "%s"' % search_filter)
             assert(len(res) == 1)
@@ -584,7 +676,7 @@ accountExpires: %u
         return dsdb._samdb_get_domain_sid(self)
 
     domain_sid = property(get_domain_sid, set_domain_sid,
-        "SID for the domain")
+                          "SID for the domain")
 
     def set_invocation_id(self, invocation_id):
         """Set the invocation id for this SamDB handle.
@@ -598,16 +690,16 @@ accountExpires: %u
         return dsdb._samdb_ntds_invocation_id(self)
 
     invocation_id = property(get_invocation_id, set_invocation_id,
-        "Invocation ID GUID")
+                             "Invocation ID GUID")
 
     def get_oid_from_attid(self, attid):
         return dsdb._dsdb_get_oid_from_attid(self, attid)
 
     def get_attid_from_lDAPDisplayName(self, ldap_display_name,
-            is_schema_nc=False):
+                                       is_schema_nc=False):
         '''return the attribute ID for a LDAP attribute as an integer as found in DRSUAPI'''
         return dsdb._dsdb_get_attid_from_lDAPDisplayName(self,
-            ldap_display_name, is_schema_nc)
+                                                         ldap_display_name, is_schema_nc)
 
     def get_syntax_oid_from_lDAPDisplayName(self, ldap_display_name):
         '''return the syntax OID for a LDAP attribute as a string'''
@@ -651,7 +743,7 @@ accountExpires: %u
     def host_dns_name(self):
         """return the DNS name of this host"""
         res = self.search(base='', scope=ldb.SCOPE_BASE, attrs=['dNSHostName'])
-        return res[0]['dNSHostName'][0]
+        return str(res[0]['dNSHostName'][0])
 
     def domain_dns_name(self):
         """return the DNS name of the domain root"""
@@ -672,6 +764,15 @@ accountExpires: %u
     def set_schema_from_ldb(self, ldb_conn, write_indices_and_attributes=True):
         dsdb._dsdb_set_schema_from_ldb(self, ldb_conn, write_indices_and_attributes)
 
+    def set_schema_update_now(self):
+        ldif = """
+dn:
+changetype: modify
+add: schemaUpdateNow
+schemaUpdateNow: 1
+"""
+        self.modify_ldif(ldif)
+
     def dsdb_DsReplicaAttribute(self, ldb, ldap_display_name, ldif_elements):
         '''convert a list of attribute values to a DRSUAPI DsReplicaAttribute'''
         return dsdb._dsdb_DsReplicaAttribute(ldb, ldap_display_name, ldif_elements)
@@ -688,7 +789,7 @@ accountExpires: %u
         """
         if len(self.hash_oid_name.keys()) == 0:
             self._populate_oid_attid()
-        if self.hash_oid_name.has_key(self.get_oid_from_attid(attid)):
+        if self.get_oid_from_attid(attid) in self.hash_oid_name:
             return self.hash_oid_name[self.get_oid_from_attid(attid)]
         else:
             return None
@@ -701,9 +802,9 @@ accountExpires: %u
         """
         self.hash_oid_name = {}
         res = self.search(expression="objectClass=attributeSchema",
-                           controls=["search_options:1:2"],
-                           attrs=["attributeID",
-                           "lDAPDisplayName"])
+                          controls=["search_options:1:2"],
+                          attrs=["attributeID",
+                                 "lDAPDisplayName"])
         if len(res) > 0:
             for e in res:
                 strDisplay = str(e.get("lDAPDisplayName"))
@@ -720,36 +821,36 @@ accountExpires: %u
         """
 
         res = self.search(expression="distinguishedName=%s" % dn,
-                            scope=ldb.SCOPE_SUBTREE,
-                            controls=["search_options:1:2"],
-                            attrs=["replPropertyMetaData"])
+                          scope=ldb.SCOPE_SUBTREE,
+                          controls=["search_options:1:2"],
+                          attrs=["replPropertyMetaData"])
         if len(res) == 0:
             return None
 
         repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
-                            str(res[0]["replPropertyMetaData"]))
+                          res[0]["replPropertyMetaData"][0])
         ctr = repl.ctr
         if len(self.hash_oid_name.keys()) == 0:
             self._populate_oid_attid()
         for o in ctr.array:
             # Search for Description
             att_oid = self.get_oid_from_attid(o.attid)
-            if self.hash_oid_name.has_key(att_oid) and\
+            if att_oid in self.hash_oid_name and\
                att.lower() == self.hash_oid_name[att_oid].lower():
                 return o.version
         return None
 
     def set_attribute_replmetadata_version(self, dn, att, value,
-            addifnotexist=False):
+                                           addifnotexist=False):
         res = self.search(expression="distinguishedName=%s" % dn,
-                            scope=ldb.SCOPE_SUBTREE,
-                            controls=["search_options:1:2"],
-                            attrs=["replPropertyMetaData"])
+                          scope=ldb.SCOPE_SUBTREE,
+                          controls=["search_options:1:2"],
+                          attrs=["replPropertyMetaData"])
         if len(res) == 0:
             return None
 
         repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
-                            str(res[0]["replPropertyMetaData"]))
+                          res[0]["replPropertyMetaData"][0])
         ctr = repl.ctr
         now = samba.unix2nttime(int(time.time()))
         found = False
@@ -758,7 +859,7 @@ accountExpires: %u
         for o in ctr.array:
             # Search for Description
             att_oid = self.get_oid_from_attid(o.attid)
-            if self.hash_oid_name.has_key(att_oid) and\
+            if att_oid in self.hash_oid_name and\
                att.lower() == self.hash_oid_name[att_oid].lower():
                 found = True
                 seq = self.sequence_number(ldb.SEQ_NEXT)
@@ -768,7 +869,7 @@ accountExpires: %u
                 o.originating_usn = seq
                 o.local_usn = seq
 
-        if not found and addifnotexist and len(ctr.array) >0:
+        if not found and addifnotexist and len(ctr.array) > 0:
             o2 = drsblobs.replPropertyMetaData1()
             o2.attid = 589914
             att_oid = self.get_oid_from_attid(o2.attid)
@@ -784,13 +885,14 @@ accountExpires: %u
             ctr.count = ctr.count + 1
             ctr.array = tab
 
-        if found :
+        if found:
             replBlob = ndr_pack(repl)
             msg = ldb.Message()
             msg.dn = res[0].dn
-            msg["replPropertyMetaData"] = ldb.MessageElement(replBlob,
-                                                ldb.FLAG_MOD_REPLACE,
-                                                "replPropertyMetaData")
+            msg["replPropertyMetaData"] = \
+                ldb.MessageElement(replBlob,
+                                   ldb.FLAG_MOD_REPLACE,
+                                   "replPropertyMetaData")
             self.modify(msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])
 
     def write_prefixes_from_schema(self):
@@ -818,6 +920,7 @@ accountExpires: %u
         return dn
 
     def set_minPwdAge(self, value):
+        value = str(value).encode('utf8')
         m = ldb.Message()
         m.dn = ldb.Dn(self, self.domain_dn())
         m["minPwdAge"] = ldb.MessageElement(value, ldb.FLAG_MOD_REPLACE, "minPwdAge")
@@ -827,30 +930,29 @@ accountExpires: %u
         res = self.search(self.domain_dn(), scope=ldb.SCOPE_BASE, attrs=["minPwdAge"])
         if len(res) == 0:
             return None
-        elif not "minPwdAge" in res[0]:
+        elif "minPwdAge" not in res[0]:
             return None
         else:
-            return res[0]["minPwdAge"][0]
+            return int(res[0]["minPwdAge"][0])
 
     def set_maxPwdAge(self, value):
+        value = str(value).encode('utf8')
         m = ldb.Message()
         m.dn = ldb.Dn(self, self.domain_dn())
         m["maxPwdAge"] = ldb.MessageElement(value, ldb.FLAG_MOD_REPLACE, "maxPwdAge")
         self.modify(m)
 
-
     def get_maxPwdAge(self):
         res = self.search(self.domain_dn(), scope=ldb.SCOPE_BASE, attrs=["maxPwdAge"])
         if len(res) == 0:
             return None
-        elif not "maxPwdAge" in res[0]:
+        elif "maxPwdAge" not in res[0]:
             return None
         else:
-            return res[0]["maxPwdAge"][0]
-
-
+            return int(res[0]["maxPwdAge"][0])
 
     def set_minPwdLength(self, value):
+        value = str(value).encode('utf8')
         m = ldb.Message()
         m.dn = ldb.Dn(self, self.domain_dn())
         m["minPwdLength"] = ldb.MessageElement(value, ldb.FLAG_MOD_REPLACE, "minPwdLength")
@@ -860,12 +962,13 @@ accountExpires: %u
         res = self.search(self.domain_dn(), scope=ldb.SCOPE_BASE, attrs=["minPwdLength"])
         if len(res) == 0:
             return None
-        elif not "minPwdLength" in res[0]:
+        elif "minPwdLength" not in res[0]:
             return None
         else:
-            return res[0]["minPwdLength"][0]
+            return int(res[0]["minPwdLength"][0])
 
     def set_pwdProperties(self, value):
+        value = str(value).encode('utf8')
         m = ldb.Message()
         m.dn = ldb.Dn(self, self.domain_dn())
         m["pwdProperties"] = ldb.MessageElement(value, ldb.FLAG_MOD_REPLACE, "pwdProperties")
@@ -875,21 +978,24 @@ accountExpires: %u
         res = self.search(self.domain_dn(), scope=ldb.SCOPE_BASE, attrs=["pwdProperties"])
         if len(res) == 0:
             return None
-        elif not "pwdProperties" in res[0]:
+        elif "pwdProperties" not in res[0]:
             return None
         else:
-            return res[0]["pwdProperties"][0]
+            return int(res[0]["pwdProperties"][0])
 
     def set_dsheuristics(self, dsheuristics):
         m = ldb.Message()
         m.dn = ldb.Dn(self, "CN=Directory Service,CN=Windows NT,CN=Services,%s"
                       % self.get_config_basedn().get_linearized())
         if dsheuristics is not None:
-            m["dSHeuristics"] = ldb.MessageElement(dsheuristics,
-                ldb.FLAG_MOD_REPLACE, "dSHeuristics")
+            m["dSHeuristics"] = \
+                ldb.MessageElement(dsheuristics,
+                                   ldb.FLAG_MOD_REPLACE,
+                                   "dSHeuristics")
         else:
-            m["dSHeuristics"] = ldb.MessageElement([], ldb.FLAG_MOD_DELETE,
-                "dSHeuristics")
+            m["dSHeuristics"] = \
+                ldb.MessageElement([], ldb.FLAG_MOD_DELETE,
+                                   "dSHeuristics")
         self.modify(m)
 
     def get_dsheuristics(self):
@@ -942,12 +1048,12 @@ accountExpires: %u
     def get_dsServiceName(self):
         '''get the NTDS DN from the rootDSE'''
         res = self.search(base="", scope=ldb.SCOPE_BASE, attrs=["dsServiceName"])
-        return res[0]["dsServiceName"][0]
+        return str(res[0]["dsServiceName"][0])
 
     def get_serverName(self):
         '''get the server DN from the rootDSE'''
         res = self.search(base="", scope=ldb.SCOPE_BASE, attrs=["serverName"])
-        return res[0]["serverName"][0]
+        return str(res[0]["serverName"][0])
 
     def dns_lookup(self, dns_name, dns_partition=None):
         '''Do a DNS lookup in the database, returns the NDR database structures'''
@@ -981,7 +1087,6 @@ accountExpires: %u
         '''garbage_collect_tombstones(lp, samdb, [dn], current_time, tombstone_lifetime)
         -> (num_objects_expunged, num_links_expunged)'''
 
-
         if tombstone_lifetime is None:
             return dsdb._dsdb_garbage_collect_tombstones(self, dn,
                                                          current_time)
@@ -999,11 +1104,18 @@ accountExpires: %u
         return dsdb._dsdb_allocate_rid(self)
 
     def normalize_dn_in_domain(self, dn):
-        """return full dn of an relative dn
+        '''return a new DN expanded by adding the domain DN
+
+        If the dn is already a child of the domain DN, just
+        return it as-is.
 
         :param dn: relative dn
-        """
+        '''
         domain_dn = ldb.Dn(self, self.domain_dn())
+
+        if isinstance(dn, ldb.Dn):
+            dn = str(dn)
+
         full_dn = ldb.Dn(self, dn)
         if not full_dn.is_child_of(domain_dn):
             full_dn.add_base(domain_dn)