d9d12126cf875edac46c111cd1a4bdda2a570bba
[ira/wip.git] / source4 / scripting / python / samba / samdb.py
1 #!/usr/bin/python
2
3 # Unix SMB/CIFS implementation.
4 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2008
5 #
6 # Based on the original in EJS:
7 # Copyright (C) Andrew Tridgell <tridge@samba.org> 2005
8 #   
9 # This program is free software; you can redistribute it and/or modify
10 # it under the terms of the GNU General Public License as published by
11 # the Free Software Foundation; either version 3 of the License, or
12 # (at your option) any later version.
13 #   
14 # This program is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17 # GNU General Public License for more details.
18 #   
19 # You should have received a copy of the GNU General Public License
20 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
21 #
22
23 """Convenience functions for using the SAM."""
24
25 import samba
26 import glue
27 import ldb
28 from samba.idmap import IDmapDB
29 import pwd
30 import time
31 import base64
32
33 __docformat__ = "restructuredText"
34
35 class SamDB(samba.Ldb):
36     """The SAM database."""
37
38     def __init__(self, url=None, session_info=None, credentials=None, 
39                  modules_dir=None, lp=None, options=None):
40         """Open the Sam Database.
41
42         :param url: URL of the database.
43         """
44         self.lp = lp
45         super(SamDB, self).__init__(session_info=session_info, credentials=credentials,
46                                     modules_dir=modules_dir, lp=lp, options=options)
47         glue.dsdb_set_global_schema(self)
48         if url:
49             self.connect(url)
50         else:
51             self.connect(lp.get("sam database"))
52
53     def connect(self, url):
54         super(SamDB, self).connect(self.lp.private_path(url))
55
56     def enable_account(self, user_dn):
57         """Enable an account.
58         
59         :param user_dn: Dn of the account to enable.
60         """
61         res = self.search(user_dn, ldb.SCOPE_BASE, None, ["userAccountControl"])
62         assert len(res) == 1
63         userAccountControl = int(res[0]["userAccountControl"][0])
64         if (userAccountControl & 0x2):
65             userAccountControl = userAccountControl & ~0x2 # remove disabled bit
66         if (userAccountControl & 0x20):
67             userAccountControl = userAccountControl & ~0x20 # remove 'no password required' bit
68
69         mod = """
70 dn: %s
71 changetype: modify
72 replace: userAccountControl
73 userAccountControl: %u
74 """ % (user_dn, userAccountControl)
75         self.modify_ldif(mod)
76
77         
78     def force_password_change_at_next_login(self, user_dn):
79         """Force a password change at next login
80         
81         :param user_dn: Dn of the account to force password change on
82         """
83         mod = """
84 dn: %s
85 changetype: modify
86 replace: pwdLastSet
87 pwdLastSet: 0
88 """ % (user_dn)
89         self.modify_ldif(mod)
90
91     def domain_dn(self):
92         # find the DNs for the domain and the domain users group
93         res = self.search("", scope=ldb.SCOPE_BASE, 
94                           expression="(defaultNamingContext=*)", 
95                           attrs=["defaultNamingContext"])
96         assert(len(res) == 1 and res[0]["defaultNamingContext"] is not None)
97         return res[0]["defaultNamingContext"][0]
98
99     def newuser(self, username, unixname, password, force_password_change_at_next_login=False):
100         """add a new user record.
101         
102         :param username: Name of the new user.
103         :param unixname: Name of the unix user to map to.
104         :param password: Password for the new user
105         """
106         # connect to the sam 
107         self.transaction_start()
108         try:
109             domain_dn = self.domain_dn()
110             assert(domain_dn is not None)
111             user_dn = "CN=%s,CN=Users,%s" % (username, domain_dn)
112
113             #
114             #  the new user record. note the reliance on the samdb module to 
115             #  fill in a sid, guid etc
116             #
117             #  now the real work
118             self.add({"dn": user_dn, 
119                 "sAMAccountName": username,
120                 "userPassword": password,
121                 "objectClass": "user"})
122
123             res = self.search(user_dn, scope=ldb.SCOPE_BASE,
124                               expression="objectclass=*",
125                               attrs=["objectSid"])
126             assert len(res) == 1
127             user_sid = self.schema_format_value("objectSid", res[0]["objectSid"][0])
128             
129             try:
130                 idmap = IDmapDB(lp=self.lp)
131
132                 user = pwd.getpwnam(unixname)
133                 # setup ID mapping for this UID
134                 
135                 idmap.setup_name_mapping(user_sid, idmap.TYPE_UID, user[2])
136
137             except KeyError:
138                 pass
139
140             if force_password_change_at_next_login:
141                 self.force_password_change_at_next_login(user_dn)
142
143             #  modify the userAccountControl to remove the disabled bit
144             self.enable_account(user_dn)
145         except:
146             self.transaction_cancel()
147             raise
148         self.transaction_commit()
149
150     def setpassword(self, filter, password, force_password_change_at_next_login=False):
151         """Set a password on a user record
152         
153         :param filter: LDAP filter to find the user (eg samccountname=name)
154         :param password: Password for the user
155         """
156         # connect to the sam 
157         self.transaction_start()
158         try:
159             # find the DNs for the domain
160             res = self.search("", scope=ldb.SCOPE_BASE, 
161                               expression="(defaultNamingContext=*)", 
162                               attrs=["defaultNamingContext"])
163             assert(len(res) == 1 and res[0]["defaultNamingContext"] is not None)
164             domain_dn = res[0]["defaultNamingContext"][0]
165             assert(domain_dn is not None)
166
167             res = self.search(domain_dn, scope=ldb.SCOPE_SUBTREE, 
168                               expression=filter,
169                               attrs=[])
170             assert(len(res) == 1)
171             user_dn = res[0].dn
172
173             setpw = """
174 dn: %s
175 changetype: modify
176 replace: userPassword
177 userPassword:: %s
178 """ % (user_dn, base64.b64encode(password))
179
180             self.modify_ldif(setpw)
181
182             if force_password_change_at_next_login:
183                 self.force_password_change_at_next_login(user_dn)
184
185             #  modify the userAccountControl to remove the disabled bit
186             self.enable_account(user_dn)
187         except:
188             self.transaction_cancel()
189             raise
190         self.transaction_commit()
191
192     def setexpiry(self, user, expiry_seconds, noexpiry):
193         """Set the account expiry for a user
194         
195         :param expiry_seconds: expiry time from now in seconds
196         :param noexpiry: if set, then don't expire password
197         """
198         self.transaction_start()
199         try:
200             res = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
201                               expression=("(samAccountName=%s)" % user),
202                               attrs=["userAccountControl", "accountExpires"])
203             assert len(res) == 1
204             userAccountControl = int(res[0]["userAccountControl"][0])
205             accountExpires     = int(res[0]["accountExpires"][0])
206             if noexpiry:
207                 userAccountControl = userAccountControl | 0x10000
208                 accountExpires = 0
209             else:
210                 userAccountControl = userAccountControl & ~0x10000
211                 accountExpires = glue.unix2nttime(expiry_seconds + int(time.time()))
212
213             mod = """
214 dn: %s
215 changetype: modify
216 replace: userAccountControl
217 userAccountControl: %u
218 replace: accountExpires
219 accountExpires: %u
220 """ % (res[0].dn, userAccountControl, accountExpires)
221             # now change the database
222             self.modify_ldif(mod)
223         except:
224             self.transaction_cancel()
225             raise
226         self.transaction_commit();
227