s4-net: allow a username to be displayed in setpassword errors
[amitay/samba.git] / source4 / scripting / python / samba / samdb.py
1 #!/usr/bin/python
2
3 # Unix SMB/CIFS implementation.
4 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2008
5 # Copyright (C) Matthias Dieter Wallnoefer 2009
6 #
7 # Based on the original in EJS:
8 # Copyright (C) Andrew Tridgell <tridge@samba.org> 2005
9 #   
10 # This program is free software; you can redistribute it and/or modify
11 # it under the terms of the GNU General Public License as published by
12 # the Free Software Foundation; either version 3 of the License, or
13 # (at your option) any later version.
14 #   
15 # This program is distributed in the hope that it will be useful,
16 # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18 # GNU General Public License for more details.
19 #   
20 # You should have received a copy of the GNU General Public License
21 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
22 #
23
24 """Convenience functions for using the SAM."""
25
26 import dsdb
27 import samba
28 import ldb
29 from samba.idmap import IDmapDB
30 import pwd
31 import time
32 import base64
33
34 __docformat__ = "restructuredText"
35
36 class SamDB(samba.Ldb):
37     """The SAM database."""
38
39     def __init__(self, url=None, lp=None, modules_dir=None, session_info=None,
40                  credentials=None, flags=0, options=None, global_schema=True):
41         self.lp = lp
42         if url is None:
43             url = lp.get("sam database")
44
45         super(SamDB, self).__init__(url=url, lp=lp, modules_dir=modules_dir,
46                 session_info=session_info, credentials=credentials, flags=flags,
47                 options=options)
48
49         if global_schema:
50             dsdb.dsdb_set_global_schema(self)
51
52     def connect(self, url=None, flags=0, options=None):
53         super(SamDB, self).connect(url=self.lp.private_path(url), flags=flags,
54                 options=options)
55
56     def domain_dn(self):
57         # find the DNs for the domain
58         res = self.search(base="",
59                           scope=ldb.SCOPE_BASE,
60                           expression="(defaultNamingContext=*)",
61                           attrs=["defaultNamingContext"])
62         assert(len(res) == 1 and res[0]["defaultNamingContext"] is not None)
63         return res[0]["defaultNamingContext"][0]
64
65     def enable_account(self, filter):
66         """Enables an account
67         
68         :param filter: LDAP filter to find the user (eg samccountname=name)
69         """
70         res = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
71                           expression=filter, attrs=["userAccountControl"])
72         assert(len(res) == 1)
73         user_dn = res[0].dn
74
75         userAccountControl = int(res[0]["userAccountControl"][0])
76         if (userAccountControl & 0x2):
77             userAccountControl = userAccountControl & ~0x2 # remove disabled bit
78         if (userAccountControl & 0x20):
79             userAccountControl = userAccountControl & ~0x20 # remove 'no password required' bit
80
81         mod = """
82 dn: %s
83 changetype: modify
84 replace: userAccountControl
85 userAccountControl: %u
86 """ % (user_dn, userAccountControl)
87         self.modify_ldif(mod)
88         
89     def force_password_change_at_next_login(self, filter):
90         """Forces a password change at next login
91         
92         :param filter: LDAP filter to find the user (eg samccountname=name)
93         """
94         res = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
95                           expression=filter, attrs=[])
96         assert(len(res) == 1)
97         user_dn = res[0].dn
98
99         mod = """
100 dn: %s
101 changetype: modify
102 replace: pwdLastSet
103 pwdLastSet: 0
104 """ % (user_dn)
105         self.modify_ldif(mod)
106
107     def newuser(self, username, unixname, password,
108                 force_password_change_at_next_login_req=False):
109         """Adds a new user
110
111         Note: This call adds also the ID mapping for winbind; therefore it works
112         *only* on SAMBA 4.
113         
114         :param username: Name of the new user
115         :param unixname: Name of the unix user to map to
116         :param password: Password for the new user
117         :param force_password_change_at_next_login_req: Force password change
118         """
119         self.transaction_start()
120         try:
121             user_dn = "CN=%s,CN=Users,%s" % (username, self.domain_dn())
122
123             # The new user record. Note the reliance on the SAMLDB module which
124             # fills in the default informations
125             self.add({"dn": user_dn, 
126                 "sAMAccountName": username,
127                 "objectClass": "user"})
128
129             # Sets the password for it
130             self.setpassword("(dn=" + user_dn + ")", password,
131               force_password_change_at_next_login_req)
132
133             # Gets the user SID (for the account mapping setup)
134             res = self.search(user_dn, scope=ldb.SCOPE_BASE,
135                               expression="objectclass=*",
136                               attrs=["objectSid"])
137             assert len(res) == 1
138             user_sid = self.schema_format_value("objectSid", res[0]["objectSid"][0])
139             
140             try:
141                 idmap = IDmapDB(lp=self.lp)
142
143                 user = pwd.getpwnam(unixname)
144
145                 # setup ID mapping for this UID
146                 idmap.setup_name_mapping(user_sid, idmap.TYPE_UID, user[2])
147
148             except KeyError:
149                 pass
150         except:
151             self.transaction_cancel()
152             raise
153         else:
154             self.transaction_commit()
155
156     def setpassword(self, filter, password,
157                     force_change_at_next_login=False,
158                     username=None):
159         """Sets the password for a user
160         
161         Note: This call uses the "userPassword" attribute to set the password.
162         This works correctly on SAMBA 4 and on Windows DCs with
163         "2003 Native" or higer domain function level.
164
165         :param filter: LDAP filter to find the user (eg samccountname=name)
166         :param password: Password for the user
167         :param force_change_at_next_login: Force password change
168         """
169         self.transaction_start()
170         try:
171             res = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
172                               expression=filter, attrs=[])
173             if len(res) == 0:
174                 print('Unable to find user "%s"' % (username or filter))
175                 raise
176             assert(len(res) == 1)
177             user_dn = res[0].dn
178
179             setpw = """
180 dn: %s
181 changetype: modify
182 replace: userPassword
183 userPassword:: %s
184 """ % (user_dn, base64.b64encode(password))
185
186             self.modify_ldif(setpw)
187
188             if force_change_at_next_login:
189                 self.force_password_change_at_next_login(
190                   "(dn=" + str(user_dn) + ")")
191
192             #  modify the userAccountControl to remove the disabled bit
193             self.enable_account(filter)
194         except:
195             self.transaction_cancel()
196             raise
197         else:
198             self.transaction_commit()
199
200     def setexpiry(self, filter, expiry_seconds, no_expiry_req=False):
201         """Sets the account expiry for a user
202         
203         :param filter: LDAP filter to find the user (eg samccountname=name)
204         :param expiry_seconds: expiry time from now in seconds
205         :param no_expiry_req: if set, then don't expire password
206         """
207         self.transaction_start()
208         try:
209             res = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
210                           expression=filter,
211                           attrs=["userAccountControl", "accountExpires"])
212             assert(len(res) == 1)
213             user_dn = res[0].dn
214
215             userAccountControl = int(res[0]["userAccountControl"][0])
216             accountExpires     = int(res[0]["accountExpires"][0])
217             if no_expiry_req:
218                 userAccountControl = userAccountControl | 0x10000
219                 accountExpires = 0
220             else:
221                 userAccountControl = userAccountControl & ~0x10000
222                 accountExpires = samba.unix2nttime(expiry_seconds + int(time.time()))
223
224             setexp = """
225 dn: %s
226 changetype: modify
227 replace: userAccountControl
228 userAccountControl: %u
229 replace: accountExpires
230 accountExpires: %u
231 """ % (user_dn, userAccountControl, accountExpires)
232
233             self.modify_ldif(setexp)
234         except:
235             self.transaction_cancel()
236             raise
237         else:
238             self.transaction_commit()
239
240     def set_domain_sid(self, sid):
241         """Change the domain SID used by this LDB.
242
243         :param sid: The new domain sid to use.
244         """
245         dsdb.samdb_set_domain_sid(self, sid)
246
247     def get_domain_sid(self):
248         """Read the domain SID used by this LDB.
249
250         """
251         dsdb.samdb_get_domain_sid(self)
252
253     def set_invocation_id(self, invocation_id):
254         """Set the invocation id for this SamDB handle.
255
256         :param invocation_id: GUID of the invocation id.
257         """
258         dsdb.dsdb_set_ntds_invocation_id(self, invocation_id)
259
260     def get_invocation_id(self):
261         "Get the invocation_id id"
262         return dsdb.samdb_ntds_invocation_id(self)
263
264     invocation_id = property(get_invocation_id, set_invocation_id)
265
266     domain_sid = property(get_domain_sid, set_domain_sid)
267
268     def get_ntds_GUID(self):
269         "Get the NTDS objectGUID"
270         return dsdb.samdb_ntds_objectGUID(self)
271
272     def server_site_name(self):
273         "Get the server site name"
274         return dsdb.samdb_server_site_name(self)
275
276     def load_partition_usn(self, base_dn):
277         return dsdb.dsdb_load_partition_usn(self, base_dn)