790cb2badc01a9e0838a08ef0b886bbfbf0d5a31
[samba.git] / source4 / scripting / python / samba / samdb.py
1 #!/usr/bin/python
2
3 # Unix SMB/CIFS implementation.
4 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2008
5 # Copyright (C) Matthias Dieter Wallnoefer 2009
6 #
7 # Based on the original in EJS:
8 # Copyright (C) Andrew Tridgell <tridge@samba.org> 2005
9 #   
10 # This program is free software; you can redistribute it and/or modify
11 # it under the terms of the GNU General Public License as published by
12 # the Free Software Foundation; either version 3 of the License, or
13 # (at your option) any later version.
14 #   
15 # This program is distributed in the hope that it will be useful,
16 # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18 # GNU General Public License for more details.
19 #   
20 # You should have received a copy of the GNU General Public License
21 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
22 #
23
24 """Convenience functions for using the SAM."""
25
26 import dsdb
27 import samba
28 import ldb
29 from samba.idmap import IDmapDB
30 import pwd
31 import time
32 import base64
33
34 __docformat__ = "restructuredText"
35
36 class SamDB(samba.Ldb):
37     """The SAM database."""
38
39     def __init__(self, url=None, lp=None, modules_dir=None, session_info=None,
40                  credentials=None, flags=0, options=None, global_schema=True):
41         self.lp = lp
42         if url is None:
43             url = lp.get("sam database")
44
45         super(SamDB, self).__init__(url=url, lp=lp, modules_dir=modules_dir,
46                 session_info=session_info, credentials=credentials, flags=flags,
47                 options=options)
48
49         if global_schema:
50             dsdb.dsdb_set_global_schema(self)
51
52     def connect(self, url=None, flags=0, options=None):
53         super(SamDB, self).connect(url=self.lp.private_path(url), flags=flags,
54                 options=options)
55
56     def domain_dn(self):
57         # find the DNs for the domain
58         res = self.search(base="",
59                           scope=ldb.SCOPE_BASE,
60                           expression="(defaultNamingContext=*)",
61                           attrs=["defaultNamingContext"])
62         assert(len(res) == 1 and res[0]["defaultNamingContext"] is not None)
63         return res[0]["defaultNamingContext"][0]
64
65     def enable_account(self, filter):
66         """Enables an account
67         
68         :param filter: LDAP filter to find the user (eg samccountname=name)
69         """
70         res = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
71                           expression=filter, attrs=["userAccountControl"])
72         assert(len(res) == 1)
73         user_dn = res[0].dn
74
75         userAccountControl = int(res[0]["userAccountControl"][0])
76         if (userAccountControl & 0x2):
77             userAccountControl = userAccountControl & ~0x2 # remove disabled bit
78         if (userAccountControl & 0x20):
79             userAccountControl = userAccountControl & ~0x20 # remove 'no password required' bit
80
81         mod = """
82 dn: %s
83 changetype: modify
84 replace: userAccountControl
85 userAccountControl: %u
86 """ % (user_dn, userAccountControl)
87         self.modify_ldif(mod)
88         
89     def force_password_change_at_next_login(self, filter):
90         """Forces a password change at next login
91         
92         :param filter: LDAP filter to find the user (eg samccountname=name)
93         """
94         res = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
95                           expression=filter, attrs=[])
96         assert(len(res) == 1)
97         user_dn = res[0].dn
98
99         mod = """
100 dn: %s
101 changetype: modify
102 replace: pwdLastSet
103 pwdLastSet: 0
104 """ % (user_dn)
105         self.modify_ldif(mod)
106
107     def newuser(self, username, unixname, password,
108                 force_password_change_at_next_login_req=False):
109         """Adds a new user
110
111         Note: This call adds also the ID mapping for winbind; therefore it works
112         *only* on SAMBA 4.
113         
114         :param username: Name of the new user
115         :param unixname: Name of the unix user to map to
116         :param password: Password for the new user
117         :param force_password_change_at_next_login_req: Force password change
118         """
119         self.transaction_start()
120         try:
121             user_dn = "CN=%s,CN=Users,%s" % (username, self.domain_dn())
122
123             # The new user record. Note the reliance on the SAMLDB module which
124             # fills in the default informations
125             self.add({"dn": user_dn, 
126                 "sAMAccountName": username,
127                 "objectClass": "user"})
128
129             # Sets the password for it
130             self.setpassword("(dn=" + user_dn + ")", password,
131               force_password_change_at_next_login_req)
132
133             # Gets the user SID (for the account mapping setup)
134             res = self.search(user_dn, scope=ldb.SCOPE_BASE,
135                               expression="objectclass=*",
136                               attrs=["objectSid"])
137             assert len(res) == 1
138             user_sid = self.schema_format_value("objectSid", res[0]["objectSid"][0])
139             
140             try:
141                 idmap = IDmapDB(lp=self.lp)
142
143                 user = pwd.getpwnam(unixname)
144
145                 # setup ID mapping for this UID
146                 idmap.setup_name_mapping(user_sid, idmap.TYPE_UID, user[2])
147
148             except KeyError:
149                 pass
150         except:
151             self.transaction_cancel()
152             raise
153         else:
154             self.transaction_commit()
155
156     def setpassword(self, filter, password, force_change_at_next_login=False):
157         """Sets the password for a user
158         
159         Note: This call uses the "userPassword" attribute to set the password.
160         This works correctly on SAMBA 4 and on Windows DCs with
161         "2003 Native" or higer domain function level.
162
163         :param filter: LDAP filter to find the user (eg samccountname=name)
164         :param password: Password for the user
165         :param force_change_at_next_login: Force password change
166         """
167         self.transaction_start()
168         try:
169             res = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
170                               expression=filter, attrs=[])
171             assert(len(res) == 1)
172             user_dn = res[0].dn
173
174             setpw = """
175 dn: %s
176 changetype: modify
177 replace: userPassword
178 userPassword:: %s
179 """ % (user_dn, base64.b64encode(password))
180
181             self.modify_ldif(setpw)
182
183             if force_change_at_next_login:
184                 self.force_password_change_at_next_login(
185                   "(dn=" + str(user_dn) + ")")
186
187             #  modify the userAccountControl to remove the disabled bit
188             self.enable_account(filter)
189         except:
190             self.transaction_cancel()
191             raise
192         else:
193             self.transaction_commit()
194
195     def setexpiry(self, filter, expiry_seconds, no_expiry_req=False):
196         """Sets the account expiry for a user
197         
198         :param filter: LDAP filter to find the user (eg samccountname=name)
199         :param expiry_seconds: expiry time from now in seconds
200         :param no_expiry_req: if set, then don't expire password
201         """
202         self.transaction_start()
203         try:
204             res = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
205                           expression=filter,
206                           attrs=["userAccountControl", "accountExpires"])
207             assert(len(res) == 1)
208             user_dn = res[0].dn
209
210             userAccountControl = int(res[0]["userAccountControl"][0])
211             accountExpires     = int(res[0]["accountExpires"][0])
212             if no_expiry_req:
213                 userAccountControl = userAccountControl | 0x10000
214                 accountExpires = 0
215             else:
216                 userAccountControl = userAccountControl & ~0x10000
217                 accountExpires = samba.unix2nttime(expiry_seconds + int(time.time()))
218
219             setexp = """
220 dn: %s
221 changetype: modify
222 replace: userAccountControl
223 userAccountControl: %u
224 replace: accountExpires
225 accountExpires: %u
226 """ % (user_dn, userAccountControl, accountExpires)
227
228             self.modify_ldif(setexp)
229         except:
230             self.transaction_cancel()
231             raise
232         else:
233             self.transaction_commit()
234
235     def set_domain_sid(self, sid):
236         """Change the domain SID used by this LDB.
237
238         :param sid: The new domain sid to use.
239         """
240         dsdb.samdb_set_domain_sid(self, sid)
241
242     def get_domain_sid(self):
243         """Read the domain SID used by this LDB.
244
245         """
246         dsdb.samdb_get_domain_sid(self)
247
248     def set_invocation_id(self, invocation_id):
249         """Set the invocation id for this SamDB handle.
250
251         :param invocation_id: GUID of the invocation id.
252         """
253         dsdb.dsdb_set_ntds_invocation_id(self, invocation_id)
254
255     def get_invocation_id(self):
256         "Get the invocation_id id"
257         return dsdb.samdb_ntds_invocation_id(self)
258
259     invocation_id = property(get_invocation_id, set_invocation_id)
260
261     domain_sid = property(get_domain_sid, set_domain_sid)
262
263     def get_ntds_GUID(self):
264         "Get the NTDS objectGUID"
265         return dsdb.samdb_ntds_objectGUID(self)
266
267     def server_site_name(self):
268         "Get the server site name"
269         return dsdb.samdb_server_site_name(self)
270
271     def load_partition_usn(self, base_dn):
272         return dsdb.dsdb_load_partition_usn(self, base_dn)