s4-dsdb: Added a helper to python SamDB for retrieving and setting minPwdAge.
[amitay/samba.git] / source4 / scripting / python / samba / samdb.py
1 #!/usr/bin/env python
2
3 # Unix SMB/CIFS implementation.
4 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
5 # Copyright (C) Matthias Dieter Wallnoefer 2009
6 #
7 # Based on the original in EJS:
8 # Copyright (C) Andrew Tridgell <tridge@samba.org> 2005
9 #   
10 # This program is free software; you can redistribute it and/or modify
11 # it under the terms of the GNU General Public License as published by
12 # the Free Software Foundation; either version 3 of the License, or
13 # (at your option) any later version.
14 #   
15 # This program is distributed in the hope that it will be useful,
16 # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18 # GNU General Public License for more details.
19 #   
20 # You should have received a copy of the GNU General Public License
21 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
22 #
23
24 """Convenience functions for using the SAM."""
25
26 import samba
27 import ldb
28 import time
29 import base64
30 from samba import dsdb
31 from samba.ndr import ndr_unpack, ndr_pack
32 from samba.dcerpc import drsblobs, misc
33
34 __docformat__ = "restructuredText"
35
class SamDB(samba.Ldb):
    """The SAM database.

    Extends :class:`samba.Ldb` with account-management helpers (users,
    groups, passwords, expiry, minPwdAge) and thin wrappers around the
    ``dsdb`` extension module.
    """

    # Maps attribute OID -> lDAPDisplayName.  This class-level dict is only a
    # placeholder: _populate_oid_attid() binds a fresh per-instance dict.
    hash_oid_name = {}

    def __init__(self, url=None, lp=None, modules_dir=None, session_info=None,
                 credentials=None, flags=0, options=None, global_schema=True,
                 auto_connect=True, am_rodc=None):
        """Create a SamDB handle.

        :param url: Database URL.  If None and ``lp`` is given, the
            "sam database" setting from ``lp`` is used.
        :param lp: Loadparm context; also used by :meth:`connect` to resolve
            the URL with ``lp.private_path``.
        :param auto_connect: If False, ``url`` is forced to None before the
            base class is initialised.
        :param global_schema: If True, call ``dsdb._dsdb_set_global_schema``
            on this database.
        :param am_rodc: If not None, passed to ``dsdb._dsdb_set_am_rodc``.

        The remaining parameters are passed through to :class:`samba.Ldb`.
        """
        # self.lp must be set before the base constructor runs, in case it
        # calls connect(), which reads self.lp.
        self.lp = lp
        if not auto_connect:
            url = None
        elif url is None and lp is not None:
            url = lp.get("sam database")

        super(SamDB, self).__init__(url=url, lp=lp, modules_dir=modules_dir,
                session_info=session_info, credentials=credentials, flags=flags,
                options=options)

        if global_schema:
            dsdb._dsdb_set_global_schema(self)

        if am_rodc is not None:
            dsdb._dsdb_set_am_rodc(self, am_rodc)

    def connect(self, url=None, flags=0, options=None):
        """Connect to ``url``, resolving it with ``lp.private_path`` when a
        loadparm context is available."""
        if self.lp is not None:
            url = self.lp.private_path(url)

        super(SamDB, self).connect(url=url, flags=flags,
                options=options)

    def am_rodc(self):
        """Return the result of ``dsdb._am_rodc`` for this database
        (presumably whether it is a read-only DC)."""
        return dsdb._am_rodc(self)

    def domain_dn(self):
        """Return the default base DN as a string."""
        return str(self.get_default_basedn())

    @staticmethod
    def _escape_filter_value(value):
        """Escape ``value`` for use as an LDAP filter assertion value.

        Per RFC 4515, ``\\``, ``*``, ``(``, ``)`` and NUL are replaced by a
        backslash followed by two hex digits.
        """
        escaped = []
        for ch in value:
            if ch in '\\*()\0':
                escaped.append('\\%02x' % ord(ch))
            else:
                escaped.append(ch)
        return "".join(escaped)

    def _group_filter(self, groupname):
        """Return an LDAP filter matching the group whose sAMAccountName is
        ``groupname`` (escaped)."""
        return "(&(sAMAccountName=%s)(objectCategory=%s,%s))" % (
            self._escape_filter_value(groupname),
            "CN=Group,CN=Schema,CN=Configuration", self.domain_dn())

    def enable_account(self, search_filter):
        """Enables an account

        Clears bit 0x2 (disabled) and bit 0x20 ('no password required') of
        userAccountControl.

        :param search_filter: LDAP filter to find the user (eg samaccountname=name)
        """
        res = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
                          expression=search_filter, attrs=["userAccountControl"])
        assert(len(res) == 1)
        user_dn = res[0].dn

        # Clearing an already-clear bit is a no-op, so no per-bit test is needed.
        userAccountControl = int(res[0]["userAccountControl"][0]) & ~(0x2 | 0x20)

        mod = """
dn: %s
changetype: modify
replace: userAccountControl
userAccountControl: %u
""" % (user_dn, userAccountControl)
        self.modify_ldif(mod)

    def force_password_change_at_next_login(self, search_filter):
        """Forces a password change at next login

        Sets pwdLastSet to 0 on the single matching user.

        :param search_filter: LDAP filter to find the user (eg samaccountname=name)
        """
        res = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
                          expression=search_filter, attrs=[])
        assert(len(res) == 1)
        user_dn = res[0].dn

        mod = """
dn: %s
changetype: modify
replace: pwdLastSet
pwdLastSet: 0
""" % (user_dn)
        self.modify_ldif(mod)

    def newgroup(self, groupname, groupou=None, grouptype=None,
                 description=None, mailaddress=None, notes=None):
        """Adds a new group with additional parameters

        :param groupname: Name of the new group
        :param groupou: Container (without domainDN postfix); defaults to CN=Users
        :param grouptype: Type of the new group (integer)
        :param description: Description of the new group
        :param mailaddress: Email address of the new group
        :param notes: Notes of the new group (stored in ``info``)
        """

        group_dn = "CN=%s,%s,%s" % (groupname, (groupou or "CN=Users"), self.domain_dn())

        # The new group record. Note the reliance on the SAMLDB module which
        # fills in the default information.
        ldbmessage = {"dn": group_dn,
            "sAMAccountName": groupname,
            "objectClass": "group"}

        if grouptype is not None:
            ldbmessage["groupType"] = "%d" % grouptype

        if description is not None:
            ldbmessage["description"] = description

        if mailaddress is not None:
            ldbmessage["mail"] = mailaddress

        if notes is not None:
            ldbmessage["info"] = notes

        self.add(ldbmessage)

    def deletegroup(self, groupname):
        """Deletes a group

        :param groupname: Name of the target group
        :raises Exception: if no group with that name exists
        """

        groupfilter = self._group_filter(groupname)
        self.transaction_start()
        try:
            targetgroup = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
                               expression=groupfilter, attrs=[])
            if len(targetgroup) == 0:
                raise Exception('Unable to find group "%s"' % groupname)
            assert(len(targetgroup) == 1)
            self.delete(targetgroup[0].dn)
        except:
            self.transaction_cancel()
            raise
        else:
            self.transaction_commit()

    def add_remove_group_members(self, groupname, listofmembers,
                                  add_members_operation=True):
        """Adds or removes group members

        Each member is looked up by sAMAccountName or CN.  Names that do not
        match exactly one object are skipped silently, as are members that
        are already present when adding, or absent when removing.

        :param groupname: Name of the target group
        :param listofmembers: Comma-separated list of group members
            (surrounding whitespace and empty entries are ignored)
        :param add_members_operation: True to add members, False to remove them
        :raises Exception: if the group does not exist
        """

        groupfilter = self._group_filter(groupname)
        groupmembers = [m.strip() for m in listofmembers.split(',') if m.strip()]

        self.transaction_start()
        try:
            targetgroup = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
                               expression=groupfilter, attrs=['member'])
            if len(targetgroup) == 0:
                raise Exception('Unable to find group "%s"' % groupname)
            assert(len(targetgroup) == 1)

            current = targetgroup[0].get('member')
            # DN strings currently on the group; kept up to date below so a
            # name listed twice does not produce a duplicate add/delete.
            existing = set()
            if current is not None:
                existing = set(str(v) for v in current)

            changes = []
            for member in groupmembers:
                escaped = self._escape_filter_value(member)
                targetmember = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
                                    expression="(|(sAMAccountName=%s)(CN=%s))" % (escaped, escaped), attrs=[])

                if len(targetmember) != 1:
                    continue

                member_dn = str(targetmember[0].dn)
                if add_members_operation:
                    if member_dn not in existing:
                        existing.add(member_dn)
                        changes.append("""add: member
member: %s
""" % member_dn)
                elif member_dn in existing:
                    existing.discard(member_dn)
                    changes.append("""delete: member
member: %s
""" % member_dn)

            if changes:
                addtargettogroup = """
dn: %s
changetype: modify
""" % (str(targetgroup[0].dn))
                self.modify_ldif(addtargettogroup + "".join(changes))

        except:
            self.transaction_cancel()
            raise
        else:
            self.transaction_commit()

    def newuser(self, username, password,
                force_password_change_at_next_login_req=False,
                useusernameascn=False, userou=None, surname=None, givenname=None, initials=None,
                profilepath=None, scriptpath=None, homedrive=None, homedirectory=None,
                jobtitle=None, department=None, company=None, description=None,
                mailaddress=None, internetaddress=None, telephonenumber=None,
                physicaldeliveryoffice=None):
        """Adds a new user with additional parameters

        The add and the password set run in one transaction.

        :param username: Name of the new user
        :param password: Password for the new user
        :param force_password_change_at_next_login_req: Force password change
        :param useusernameascn: Use username as cn rather than firstname + initials + lastname
        :param userou: Object container (without domainDN postfix) for new user
        :param surname: Surname of the new user
        :param givenname: First name of the new user
        :param initials: Initials of the new user
        :param profilepath: Profile path of the new user
        :param scriptpath: Logon script path of the new user
        :param homedrive: Home drive of the new user
        :param homedirectory: Home directory of the new user
        :param jobtitle: Job title of the new user
        :param department: Department of the new user
        :param company: Company of the new user
        :param description: Description of the new user
        :param mailaddress: Email address of the new user
        :param internetaddress: Home page of the new user
        :param telephonenumber: Phone number of the new user
        :param physicaldeliveryoffice: Office location of the new user
        """

        # Join only the parts that were supplied, so a missing given name does
        # not leave a leading space in the display name.
        name_parts = []
        if givenname is not None:
            name_parts.append(givenname)
        if initials is not None:
            name_parts.append('%s.' % initials)
        if surname is not None:
            name_parts.append(surname)
        displayname = " ".join(name_parts)

        cn = username
        # NOTE(review): only an explicit None, not the default False, makes the
        # display name the CN.  Kept as-is because callers may depend on it.
        if useusernameascn is None and displayname != "":
            cn = displayname

        user_dn = "CN=%s,%s,%s" % (cn, (userou or "CN=Users"), self.domain_dn())

        dnsdomain = ldb.Dn(self, self.domain_dn()).canonical_str().replace("/", "")
        user_principal_name = "%s@%s" % (username, dnsdomain)
        # The new user record. Note the reliance on the SAMLDB module which
        # fills in the default information.
        ldbmessage = {"dn": user_dn,
                      "sAMAccountName": username,
                      "userPrincipalName": user_principal_name,
                      "objectClass": "user"}

        if displayname != "":
            ldbmessage["displayName"] = displayname
            ldbmessage["name"] = displayname

        if initials is not None:
            ldbmessage["initials"] = '%s.' % initials

        # Optional attributes copied verbatim when the caller supplied them.
        optional_attrs = (("sn", surname),
                          ("givenName", givenname),
                          ("profilePath", profilepath),
                          ("scriptPath", scriptpath),
                          ("homeDrive", homedrive),
                          ("homeDirectory", homedirectory),
                          ("title", jobtitle),
                          ("department", department),
                          ("company", company),
                          ("description", description),
                          ("mail", mailaddress),
                          ("wWWHomePage", internetaddress),
                          ("telephoneNumber", telephonenumber),
                          ("physicalDeliveryOfficeName", physicaldeliveryoffice))
        for attr_name, attr_value in optional_attrs:
            if attr_value is not None:
                ldbmessage[attr_name] = attr_value

        self.transaction_start()
        try:
            self.add(ldbmessage)

            # Sets the password for it
            self.setpassword("(dn=" + user_dn + ")", password,
              force_password_change_at_next_login_req)
        except:
            self.transaction_cancel()
            raise
        else:
            self.transaction_commit()

    def setpassword(self, search_filter, password,
                    force_change_at_next_login=False,
                    username=None):
        """Sets the password for a user

        Replaces unicodePwd, optionally forces a change at next login, and
        then enables the account (see :meth:`enable_account`).

        :param search_filter: LDAP filter to find the user (eg samaccountname=name)
        :param password: Password for the user
        :param force_change_at_next_login: Force password change
        :param username: Name used in the "not found" error message instead
            of the filter
        :raises Exception: if no user matches ``search_filter``
        """
        self.transaction_start()
        try:
            res = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
                              expression=search_filter, attrs=[])
            if len(res) == 0:
                raise Exception('Unable to find user "%s"' % (username or search_filter))
            assert(len(res) == 1)
            user_dn = res[0].dn

            # unicodePwd is the quoted password encoded as UTF-16-LE, sent
            # base64-encoded in the LDIF.
            setpw = """
dn: %s
changetype: modify
replace: unicodePwd
unicodePwd:: %s
""" % (user_dn, base64.b64encode(("\"" + password + "\"").encode('utf-16-le')))

            self.modify_ldif(setpw)

            if force_change_at_next_login:
                self.force_password_change_at_next_login(
                  "(dn=" + str(user_dn) + ")")

            #  modify the userAccountControl to remove the disabled bit
            self.enable_account(search_filter)
        except:
            self.transaction_cancel()
            raise
        else:
            self.transaction_commit()

    def setexpiry(self, search_filter, expiry_seconds, no_expiry_req=False):
        """Sets the account expiry for a user

        :param search_filter: LDAP filter to find the user (eg samaccountname=name)
        :param expiry_seconds: expiry time from now in seconds
        :param no_expiry_req: if set, then don't expire password; sets bit
            0x10000 of userAccountControl and accountExpires to 0
        """
        self.transaction_start()
        try:
            res = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
                          expression=search_filter,
                          attrs=["userAccountControl"])
            assert(len(res) == 1)
            user_dn = res[0].dn

            # The previous accountExpires value is not needed: it is always
            # overwritten below.
            userAccountControl = int(res[0]["userAccountControl"][0])
            if no_expiry_req:
                userAccountControl = userAccountControl | 0x10000
                accountExpires = 0
            else:
                userAccountControl = userAccountControl & ~0x10000
                accountExpires = samba.unix2nttime(expiry_seconds + int(time.time()))

            setexp = """
dn: %s
changetype: modify
replace: userAccountControl
userAccountControl: %u
replace: accountExpires
accountExpires: %u
""" % (user_dn, userAccountControl, accountExpires)

            self.modify_ldif(setexp)
        except:
            self.transaction_cancel()
            raise
        else:
            self.transaction_commit()

    def set_domain_sid(self, sid):
        """Change the domain SID used by this LDB.

        :param sid: The new domain sid to use.
        """
        dsdb._samdb_set_domain_sid(self, sid)

    def get_domain_sid(self):
        """Read the domain SID used by this LDB.

        """
        return dsdb._samdb_get_domain_sid(self)

    def set_invocation_id(self, invocation_id):
        """Set the invocation id for this SamDB handle.

        :param invocation_id: GUID of the invocation id.
        """
        dsdb._dsdb_set_ntds_invocation_id(self, invocation_id)

    def get_oid_from_attid(self, attid):
        """Return the OID for ``attid`` via ``dsdb._dsdb_get_oid_from_attid``."""
        return dsdb._dsdb_get_oid_from_attid(self, attid)

    def get_attid_from_lDAPDisplayName(self, ldap_display_name, is_schema_nc=False):
        """Return the attid for ``ldap_display_name`` via
        ``dsdb._dsdb_get_attid_from_lDAPDisplayName``."""
        return dsdb._dsdb_get_attid_from_lDAPDisplayName(self, ldap_display_name, is_schema_nc)

    def get_invocation_id(self):
        "Get the invocation_id id"
        return dsdb._samdb_ntds_invocation_id(self)

    def set_ntds_settings_dn(self, ntds_settings_dn):
        """Set the NTDS Settings DN, as would be returned on the dsServiceName rootDSE attribute

        This allows the DN to be set before the database fully exists

        :param ntds_settings_dn: The new DN to use
        """
        dsdb._samdb_set_ntds_settings_dn(self, ntds_settings_dn)

    invocation_id = property(get_invocation_id, set_invocation_id)

    domain_sid = property(get_domain_sid, set_domain_sid)

    def get_ntds_GUID(self):
        "Get the NTDS objectGUID"
        return dsdb._samdb_ntds_objectGUID(self)

    def server_site_name(self):
        "Get the server site name"
        return dsdb._samdb_server_site_name(self)

    def load_partition_usn(self, base_dn):
        """Return ``dsdb._dsdb_load_partition_usn`` for ``base_dn``."""
        return dsdb._dsdb_load_partition_usn(self, base_dn)

    def set_schema(self, schema):
        """Set the schema from ``schema.ldb``."""
        self.set_schema_from_ldb(schema.ldb)

    def set_schema_from_ldb(self, ldb_conn):
        """Set the schema from the ldb connection ``ldb_conn``."""
        dsdb._dsdb_set_schema_from_ldb(self, ldb_conn)

    def dsdb_DsReplicaAttribute(self, ldb, ldap_display_name, ldif_elements):
        """Wrapper for ``dsdb._dsdb_DsReplicaAttribute``.

        Note: the ``ldb`` parameter (an ldb handle) shadows the ``ldb`` module
        within this method only.
        """
        return dsdb._dsdb_DsReplicaAttribute(ldb, ldap_display_name, ldif_elements)

    def get_attribute_from_attid(self, attid):
        """ Get from an attid the associated attribute

           :param attid: The attribute id for searched attribute
           :return: The name of the attribute associated with this id, or
               None if it is unknown
        """
        if len(self.hash_oid_name) == 0:
            self._populate_oid_attid()
        return self.hash_oid_name.get(self.get_oid_from_attid(attid))

    def _populate_oid_attid(self):
        """Populate the hash hash_oid_name

           This hash contains the oid of the attribute as a key and
           its display name as a value
        """
        self.hash_oid_name = {}
        res = self.search(expression="objectClass=attributeSchema",
                           controls=["search_options:1:2"],
                           attrs=["attributeID",
                           "lDAPDisplayName"])
        for e in res:
            self.hash_oid_name[str(e.get("attributeID"))] = str(e.get("lDAPDisplayName"))

    def _attid_matches_name(self, attid, att):
        """Return True if ``attid`` names attribute ``att`` (case-insensitive).

        Expects hash_oid_name to be populated already.
        """
        name = self.hash_oid_name.get(self.get_oid_from_attid(attid))
        return name is not None and name.lower() == att.lower()

    def _load_replmetadata(self, dn):
        """Return ``(message, replPropertyMetaDataBlob)`` for ``dn``, or None
        if the object is not found."""
        res = self.search(expression="dn=%s" % dn,
                            scope=ldb.SCOPE_SUBTREE,
                            controls=["search_options:1:2"],
                            attrs=["replPropertyMetaData"])
        if len(res) == 0:
            return None

        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                            str(res[0]["replPropertyMetaData"]))
        return res[0], repl

    def get_attribute_replmetadata_version(self, dn, att):
        """ Get the version field from the replPropertyMetaData for
            the given field

           :param dn: The DN of the object on which we want to get the version
           :param att: The name of the attribute
           :return: The value of the version field in the replPropertyMetaData
             for the given attribute. None if the attribute is not replicated
        """
        loaded = self._load_replmetadata(dn)
        if loaded is None:
            return None

        repl = loaded[1]
        if len(self.hash_oid_name) == 0:
            self._populate_oid_attid()
        for o in repl.ctr.array:
            if self._attid_matches_name(o.attid, att):
                return o.version
        return None

    def set_attribute_replmetadata_version(self, dn, att, value, addifnotexist=False):
        """Set the replPropertyMetaData version of attribute ``att`` on ``dn``.

        The matching entry also gets a new originating time, this database's
        invocation id and the next sequence number as its USNs.

        :param dn: DN of the object to update
        :param att: lDAPDisplayName of the attribute
        :param value: New version number
        :param addifnotexist: If True and no entry for ``att`` exists, append
            one (only when the object already has some metadata entries)
        :return: None
        """
        loaded = self._load_replmetadata(dn)
        if loaded is None:
            return None

        msg_found, repl = loaded
        ctr = repl.ctr
        now = samba.unix2nttime(int(time.time()))
        found = False
        if len(self.hash_oid_name) == 0:
            self._populate_oid_attid()
        for o in ctr.array:
            if self._attid_matches_name(o.attid, att):
                found = True
                seq = self.sequence_number(ldb.SEQ_NEXT)
                o.version = value
                o.originating_change_time = now
                o.originating_invocation_id = misc.GUID(self.get_invocation_id())
                o.originating_usn = seq
                o.local_usn = seq

        if not found and addifnotexist and len(ctr.array) > 0:
            o2 = drsblobs.replPropertyMetaData1()
            # Resolve the attid from ``att``; the old code used a fixed 589914
            # regardless of the attribute requested.
            o2.attid = self.get_attid_from_lDAPDisplayName(att)
            seq = self.sequence_number(ldb.SEQ_NEXT)
            o2.version = value
            o2.originating_change_time = now
            o2.originating_invocation_id = misc.GUID(self.get_invocation_id())
            o2.originating_usn = seq
            o2.local_usn = seq
            found = True
            # Reassign the list: the array attribute is presumably returned as
            # a copy, so appending to it alone may not update ctr.
            tab = ctr.array
            tab.append(o2)
            ctr.count = ctr.count + 1
            ctr.array = tab

        if found:
            replBlob = ndr_pack(repl)
            msg = ldb.Message()
            msg.dn = msg_found.dn
            msg["replPropertyMetaData"] = ldb.MessageElement(replBlob,
                                                ldb.FLAG_MOD_REPLACE,
                                                "replPropertyMetaData")
            self.modify(msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def write_prefixes_from_schema(self):
        """Write the schema prefix map into this database via
        ``dsdb._dsdb_write_prefixes_from_schema_to_ldb``."""
        dsdb._dsdb_write_prefixes_from_schema_to_ldb(self)

    def get_partitions_dn(self):
        """Return ``dsdb._dsdb_get_partitions_dn`` for this database."""
        return dsdb._dsdb_get_partitions_dn(self)

    def set_minPwdAge(self, value):
        """Replace minPwdAge on the domain object.

        :param value: New minPwdAge; integers are converted to their decimal
            string form, other values are passed through unchanged.
        """
        import numbers
        if isinstance(value, numbers.Integral):
            value = str(value)
        m = ldb.Message()
        m.dn = ldb.Dn(self, self.domain_dn())
        m["minPwdAge"] = ldb.MessageElement(value, ldb.FLAG_MOD_REPLACE, "minPwdAge")
        self.modify(m)

    def get_minPwdAge(self):
        """Return the domain object's first minPwdAge value, or None if the
        domain object or the attribute is missing."""
        res = self.search(self.domain_dn(), scope=ldb.SCOPE_BASE, attrs=["minPwdAge"])
        if len(res) == 0 or "minPwdAge" not in res[0]:
            return None
        return res[0]["minPwdAge"][0]
618         else:
619             return res[0]["minPwdAge"][0]